LCOV - code coverage report
Current view: top level - Src/Compute - SpComputeEngine.hpp (source / functions) Hit Total Coverage
Test: Coverage example Lines: 103 114 90.4 %
Date: 2021-12-02 17:21:05 Functions: 28 28 100.0 %

          Line data    Source code
       1             : #ifndef SPCOMPUTEENGINE_HPP
       2             : #define SPCOMPUTEENGINE_HPP
       3             : 
#include <memory>
#include <optional>
#include <utility>
#include <algorithm>
#include <iterator>
#include <atomic>
#include <mutex>
#include <condition_variable>
#include <limits>
#include <cstddef>

#include "Compute/SpWorker.hpp"
#include "Scheduler/SpPrioScheduler.hpp"
#include "Utils/small_vector.hpp"
      14             : 
      15             : class SpAbstractTaskGraph;
      16             : 
      17             : class SpComputeEngine {
      18             : 
      19             : private:
      20             :     small_vector<std::unique_ptr<SpWorker>> workers;
      21             :     std::mutex ceMutex;
      22             :     std::condition_variable ceCondVar;
      23             :     std::mutex migrationMutex;
      24             :     std::condition_variable migrationCondVar;
      25             :     SpPrioScheduler prioSched;
      26             :     std::atomic<long int> nbWorkersToMigrate;
      27             :     std::atomic<long int> migrationSignalingCounter;
      28             :     SpWorker::SpWorkerType workerTypeToMigrate;
      29             :     SpComputeEngine* ceToMigrateTo;
      30             :     long int nbAvailableCpuWorkers;
      31             :     long int nbAvailableGpuWorkers;
      32             :     long int totalNbCpuWorkers;
      33             :     long int totalNbGpuWorkers;
      34             :     bool hasBeenStopped;
      35             : 
      36             : private:
      37             :     
      38           2 :     auto sendWorkersToInternal(SpComputeEngine *otherCe, const SpWorker::SpWorkerType wt, const long int maxCount, const bool allowBusyWorkersToBeDetached) {
      39           2 :         small_vector<std::unique_ptr<SpWorker>> res;
      40             :         using iter_t = small_vector<std::unique_ptr<SpWorker>>::iterator;
      41             :         
      42             :         auto computeNbWorkersToDetach =
      43           2 :         [&]() {
      44             :             auto compute = 
      45           2 :             [](long int nbTotal, long int nbWaiting, const bool allowBusyWorToBeDeta, const long int max) {
      46           2 :                 if(allowBusyWorToBeDeta) {
      47           2 :                     return std::min(nbTotal, max);
      48             :                 } else {
      49           0 :                     return std::min(nbWaiting, max);
      50             :                 }
      51             :             };
      52           2 :             switch(wt) {
      53           2 :                 case SpWorker::SpWorkerType::CPU_WORKER:
      54           2 :                     return compute(totalNbCpuWorkers, nbAvailableCpuWorkers, allowBusyWorkersToBeDetached, maxCount);
      55           0 :                 case SpWorker::SpWorkerType::GPU_WORKER:
      56           0 :                     return compute(totalNbGpuWorkers, nbAvailableGpuWorkers, allowBusyWorkersToBeDetached, maxCount);
      57           0 :                 default:
      58           0 :                     return static_cast<long int>(0);
      59             :             }
      60           2 :         };
      61             :         
      62             :         const auto nbWorkersToDetach =
      63           2 :         [&]()    
      64             :         {
      65           2 :             std::unique_lock<std::mutex> computeEngineLock(ceMutex);
      66             :             
      67           2 :             auto result = computeNbWorkersToDetach();
      68             :             
      69           2 :             if(result > 0) {
      70           4 :                 workerTypeToMigrate = wt;
      71           4 :                 ceToMigrateTo = otherCe;
      72           2 :                 migrationSignalingCounter.store(result, std::memory_order_relaxed);
      73           2 :                 nbWorkersToMigrate.store(result, std::memory_order_release);
      74             :             }
      75             :             
      76           4 :             return result;
      77           2 :         }();
      78             :         
      79           2 :         if(nbWorkersToDetach > 0) {
      80           2 :             ceCondVar.notify_all();
      81             :             
      82             :             {
      83           4 :                 std::unique_lock<std::mutex> migrationLock(migrationMutex);
      84          10 :                 migrationCondVar.wait(migrationLock, [&](){ return !(migrationSignalingCounter.load(std::memory_order_acquire) > 0); });
      85             :             }
      86             :             
      87           2 :             auto startIt = std::move_iterator<iter_t>(workers.begin());
      88           2 :             auto endIt = std::move_iterator<iter_t>(workers.end());
      89             :             
      90             :             auto eraseStartPosIt = std::remove_if(startIt, endIt,
      91           5 :                                                 [&](std::unique_ptr<SpWorker>&& wPtr) {
      92           5 :                                                     if(wPtr->getComputeEngine() != this) {
      93           4 :                                                         res.push_back(std::move(wPtr));
      94           4 :                                                         return true;
      95             :                                                     } else {
      96           1 :                                                         return false;
      97             :                                                     }
      98           2 :                                                 });
      99             :             
     100           2 :             workers.erase(eraseStartPosIt.base(), workers.end());
     101             :             
     102           4 :             std::unique_lock<std::mutex> computeEngineLock(ceMutex);
     103           2 :             updateWorkerCounters<true, true>(wt, -nbWorkersToDetach);
     104             :         
     105             :         }
     106             :         
     107           4 :         return res;
     108             :     }
     109             :     
     110             :     template <const bool bindAndStartWorkers>
     111          54 :     void addWorkersInternal(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers) {
     112         210 :         for(auto& w : inWorkers) {
     113         156 :             updateWorkerCounters<true,false>(w->getType(), +1);
     114             :             if constexpr(bindAndStartWorkers) {
     115         153 :                 w->bindTo(this);
     116         153 :                 w->start();
     117             :             }
     118             :         }
     119             :         
     120          54 :         if(workers.empty()) {
     121          52 :             workers = std::move(inWorkers);
     122             :         } else {
     123           2 :             workers.reserve(workers.size() + inWorkers.size());
     124           2 :             std::move(std::begin(inWorkers), std::end(inWorkers), std::back_inserter(workers));
     125             :         }
     126          54 :     }
     127             :     
     128        2779 :     bool areThereAnyWorkersToMigrate() const {
     129        5558 :         return nbWorkersToMigrate.load(std::memory_order_acquire) > 0;
     130             :     }
     131             :     
     132        2772 :     bool areThereAnyReadyTasks() const {
     133        2772 :         return prioSched.getNbTasks() > 0;
     134             :     }
     135             :     
     136           4 :     bool areWorkersToMigrateOfType(SpWorker::SpWorkerType inWt) {
     137           4 :         return workerTypeToMigrate == inWt;
     138             :     }
     139             :     
     140        1073 :     SpAbstractTask* getTask() {
     141        1073 :         return prioSched.pop();
     142             :     }
     143             :     
     144             :     template <const bool updateTotalCounter, const bool updateAvailableCounter>
     145        1032 :     void updateWorkerCounters(const SpWorker::SpWorkerType inWt, const long int addend) {
     146        1032 :         switch(inWt) {
     147        1032 :             case SpWorker::SpWorkerType::CPU_WORKER:
     148             :                 if constexpr(updateTotalCounter) {
     149         158 :                     totalNbCpuWorkers += addend;
     150             :                 }
     151             :                 if constexpr(updateAvailableCounter) {
     152         876 :                     nbAvailableGpuWorkers += addend;
     153             :                 }
     154        1032 :                 break;
     155           0 :             case SpWorker::SpWorkerType::GPU_WORKER:
     156             :                 if constexpr(updateTotalCounter) {
     157           0 :                     totalNbGpuWorkers += addend;
     158             :                 }
     159             :                 
     160             :                 if constexpr(updateAvailableCounter) {
     161           0 :                     nbAvailableGpuWorkers += addend;
     162             :                 }
     163           0 :                 break;
     164           0 :             default:
     165           0 :                 break;
     166             :         }
     167        1032 :     }
     168             :     
     169             :     void wait(SpWorker& worker, SpAbstractTaskGraph* atg);
     170             :     
     171           4 :     auto getCeToMigrateTo() {
     172           4 :         return ceToMigrateTo;
     173             :     }
     174             :     
     175           4 :     auto fetchDecNbOfWorkersToMigrate() {
     176           4 :         return nbWorkersToMigrate.fetch_sub(1, std::memory_order_relaxed);
     177             :     }
     178             :     
     179           2 :     void notifyMigrationFinished() {
     180             :         { 
     181           2 :             std::unique_lock<std::mutex> migrationLock(migrationMutex);
     182             :         }
     183           2 :         migrationCondVar.notify_one();
     184           2 :     }
     185             :     
     186           4 :     auto fetchDecMigrationSignalingCounter() {
     187           4 :         return migrationSignalingCounter.fetch_sub(1, std::memory_order_release);
     188             :     }
     189             :     
     190             :     friend void SpWorker::waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg);
     191             :     friend void SpWorker::doLoop(SpAbstractTaskGraph* atg);
     192             : 
     193             : public:
     194          52 :     explicit SpComputeEngine(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers)
     195          52 :     : workers(), ceMutex(), ceCondVar(), migrationMutex(), migrationCondVar(), prioSched(), nbWorkersToMigrate(0),
     196             :       migrationSignalingCounter(0),  workerTypeToMigrate(SpWorker::SpWorkerType::CPU_WORKER), ceToMigrateTo(nullptr), nbAvailableCpuWorkers(0),
     197          52 :       nbAvailableGpuWorkers(0), totalNbCpuWorkers(0), totalNbGpuWorkers(0), hasBeenStopped(false) {
     198          52 :         addWorkers(std::move(inWorkers));
     199          52 :     }
     200             :     
     201             :     explicit SpComputeEngine() : SpComputeEngine(small_vector<std::unique_ptr<SpWorker>, 0>{}) {}
     202             :     
     203          52 :     ~SpComputeEngine() {
     204          52 :         stopIfNotAlreadyStopped();
     205          52 :     }
     206             :     
     207        1069 :     void pushTask(SpAbstractTask* t) {
     208        1069 :         prioSched.push(t);
     209        1069 :         wakeUpWaitingWorkers();
     210        1069 :     }
     211             :     
     212          52 :     void pushTasks(small_vector_base<SpAbstractTask*>& tasks) {
     213          52 :         prioSched.pushTasks(tasks);
     214          52 :         wakeUpWaitingWorkers();
     215          52 :     }
     216             :     
     217          77 :     size_t getCurrentNbOfWorkers() const {
     218          77 :         return workers.size();
     219             :     }
     220             :     
     221          53 :     void addWorkers(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers) {
     222          53 :         addWorkersInternal<true>(std::move(inWorkers));
     223          53 :     }
     224             :     
     225           1 :     void sendWorkersTo(SpComputeEngine& otherCe, const SpWorker::SpWorkerType wt, const size_t maxCount, const bool allowBusyWorkersToBeDetached) {
     226           1 :         SpComputeEngine* otherCePtr = std::addressof(otherCe);
     227             :         
     228           1 :         if(otherCePtr && otherCePtr != this) {
     229           1 :             otherCePtr->addWorkersInternal<false>(sendWorkersToInternal(otherCePtr, wt, maxCount, allowBusyWorkersToBeDetached));
     230             :         }
     231           1 :     }
     232             :     
     233           1 :     auto detachWorkers(const SpWorker::SpWorkerType wt, const size_t maxCount, const bool allowBusyWorkersToBeDetached) {
     234           1 :         return sendWorkersToInternal(nullptr, wt, maxCount, allowBusyWorkersToBeDetached);
     235             :     }
     236             :     
     237             :     void stopIfNotAlreadyStopped();
     238             :     
     239        1219 :     void wakeUpWaitingWorkers() {
     240             :         {
     241        1219 :             std::unique_lock<std::mutex> ceLock(ceMutex);
     242             :         }
     243        1219 :         ceCondVar.notify_all();
     244        1219 :     }
     245             : };
     246             : 
     247             : #endif

Generated by: LCOV version 1.14