LCOV - code coverage report
Current view: top level - Src/Compute - SpWorker.hpp (source / functions) Hit Total Coverage
Test: Coverage example Lines: 46 56 82.1 %
Date: 2021-12-02 17:21:05 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #ifndef SPWORKER_HPP
       2             : #define SPWORKER_HPP
       3             : 
       4             : #include <mutex>
       5             : #include <atomic>
       6             : #include <condition_variable>
       7             : #include <thread>
       8             : 
       9             : #include "Data/SpDataAccessMode.hpp"
      10             : #include "Utils/SpUtils.hpp"
      11             : #include "Task/SpAbstractTask.hpp"
      12             : #include "Utils/small_vector.hpp"
      13             : 
      14             : class SpComputeEngine;
      15             : class SpAbstractTaskGraph;
      16             : 
      17             : class SpWorker {
      18             : public:
      19             :     enum class SpWorkerType {
      20             :         CPU_WORKER,
      21             :         GPU_WORKER
      22             :     };
      23             :     
      24             :     static std::atomic<long int> totalNbThreadsCreated;
      25             :     
      26          50 :     static auto createATeamOfNCpuWorkers(const int nbCpuWorkers) {
      27          50 :         small_vector<std::unique_ptr<SpWorker>> res;
      28          50 :         res.reserve(nbCpuWorkers);
      29             :         
      30         198 :         for(int i = 0; i < nbCpuWorkers; i++) {
      31         148 :             res.emplace_back(std::make_unique<SpWorker>(SpWorker::SpWorkerType::CPU_WORKER));
      32             :         }
      33             :         
      34          50 :         return res;
      35             :     }
      36             :     
      37           1 :     static auto createDefaultWorkerTeam() {
      38           1 :         return createATeamOfNCpuWorkers(SpUtils::DefaultNumThreads());
      39             :     }
      40             :     
      41             :     static void setWorkerForThread(SpWorker *w);
      42             :     static SpWorker* getWorkerForThread();
      43             : 
      44             : private:
      45             :     const SpWorkerType wt;
      46             :     std::mutex workerMutex;
      47             :     std::condition_variable workerConditionVariable;
      48             :     std::atomic<bool> stopFlag;
      49             :     std::atomic<SpComputeEngine*> ce;
      50             :     long int threadId;
      51             :     std::thread t;
      52             :     
      53             : private:
      54         152 :     void setStopFlag(const bool inStopFlag) {
      55         152 :         stopFlag.store(inStopFlag, std::memory_order_relaxed);
      56         152 :     }
      57             :     
      58        1413 :     bool hasBeenStopped() const {
      59        1413 :         return stopFlag.load(std::memory_order_relaxed);
      60             :     }
      61             :     
      62           5 :     SpComputeEngine* getComputeEngine() const {
      63           5 :         return ce.load(std::memory_order_relaxed);
      64             :     }
      65             :     
      66        1073 :     void execute(SpAbstractTask *task) {
      67        1073 :         switch(this->getType()) {
      68        1073 :             case SpWorkerType::CPU_WORKER:
      69        1073 :                 task->execute(SpCallableType::CPU);
      70        1073 :                 break;
      71           0 :             case SpWorkerType::GPU_WORKER:
      72           0 :                 task->execute(SpCallableType::GPU);
      73           0 :                 break;
      74           0 :             default:
      75           0 :                 assert(false && "Worker is of unknown type.");
      76             :         }
      77        1073 :     }
      78             :     
      79         152 :     void waitForThread() {
      80         152 :         t.join();
      81         152 :     }
      82             :     
      83         152 :     void stop() {
      84         152 :         if(t.joinable()) {
      85           0 :             if(stopFlag.load(std::memory_order_relaxed)) {
      86             :                 {
      87           0 :                     std::unique_lock<std::mutex> workerLock(workerMutex);
      88           0 :                     stopFlag.store(true, std::memory_order_relaxed);
      89             :                 }
      90           0 :                 workerConditionVariable.notify_one();
      91             :             }
      92           0 :             waitForThread();
      93             :         }
      94         152 :     }
      95             :     
      96         153 :     void bindTo(SpComputeEngine* inCe) {
      97         153 :         if(inCe) {
      98             :             {
      99         306 :                 std::unique_lock workerLock(workerMutex);
     100         153 :                 ce.store(inCe, std::memory_order_release);
     101             :             }
     102         153 :             workerConditionVariable.notify_one();
     103             :         }
     104         153 :     }
     105             :     
     106           4 :     void idleWait() {
     107           8 :         std::unique_lock<std::mutex> workerLock(workerMutex);
     108          12 :         workerConditionVariable.wait(workerLock, [&]() { return stopFlag.load(std::memory_order_relaxed) || ce.load(std::memory_order_relaxed); });
     109           4 :     }
     110             :     
     111             :     void waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg);
     112             :     
     113             :     friend class SpComputeEngine;
     114             : 
     115             : public:
     116             : 
     117         152 :     explicit SpWorker(const SpWorkerType inWt) :
     118             :     wt(inWt), workerMutex(), workerConditionVariable(),
     119         152 :     stopFlag(false), ce(nullptr), threadId(0), t() {
     120         152 :         threadId = totalNbThreadsCreated.fetch_add(1, std::memory_order_relaxed);
     121         152 :     }
     122             : 
     123             :     SpWorker(const SpWorker& other) = delete;
     124             :     SpWorker(SpWorker&& other) = delete;
     125             :     SpWorker& operator=(const SpWorker& other) = delete;
     126             :     SpWorker& operator=(SpWorker&& other) = delete;
     127             :     
     128         152 :     ~SpWorker() {
     129         152 :         stop();
     130         152 :     }
     131             :     
     132        2103 :     SpWorkerType getType() const {
     133        2103 :         return wt;
     134             :     }
     135             :     
     136             :     void start();
     137             :     
     138             :     void doLoop(SpAbstractTaskGraph* inAtg);
     139             : };
     140             : 
     141             : #endif

Generated by: LCOV version 1.14