Line data Source code
1 : #ifndef SPWORKER_HPP 2 : #define SPWORKER_HPP 3 : 4 : #include <mutex> 5 : #include <atomic> 6 : #include <condition_variable> 7 : #include <thread> 8 : 9 : #include "Data/SpDataAccessMode.hpp" 10 : #include "Utils/SpUtils.hpp" 11 : #include "Task/SpAbstractTask.hpp" 12 : #include "Utils/small_vector.hpp" 13 : 14 : class SpComputeEngine; 15 : class SpAbstractTaskGraph; 16 : 17 : class SpWorker { 18 : public: 19 : enum class SpWorkerType { 20 : CPU_WORKER, 21 : GPU_WORKER 22 : }; 23 : 24 : static std::atomic<long int> totalNbThreadsCreated; 25 : 26 50 : static auto createATeamOfNCpuWorkers(const int nbCpuWorkers) { 27 50 : small_vector<std::unique_ptr<SpWorker>> res; 28 50 : res.reserve(nbCpuWorkers); 29 : 30 198 : for(int i = 0; i < nbCpuWorkers; i++) { 31 148 : res.emplace_back(std::make_unique<SpWorker>(SpWorker::SpWorkerType::CPU_WORKER)); 32 : } 33 : 34 50 : return res; 35 : } 36 : 37 1 : static auto createDefaultWorkerTeam() { 38 1 : return createATeamOfNCpuWorkers(SpUtils::DefaultNumThreads()); 39 : } 40 : 41 : static void setWorkerForThread(SpWorker *w); 42 : static SpWorker* getWorkerForThread(); 43 : 44 : private: 45 : const SpWorkerType wt; 46 : std::mutex workerMutex; 47 : std::condition_variable workerConditionVariable; 48 : std::atomic<bool> stopFlag; 49 : std::atomic<SpComputeEngine*> ce; 50 : long int threadId; 51 : std::thread t; 52 : 53 : private: 54 152 : void setStopFlag(const bool inStopFlag) { 55 152 : stopFlag.store(inStopFlag, std::memory_order_relaxed); 56 152 : } 57 : 58 1413 : bool hasBeenStopped() const { 59 1413 : return stopFlag.load(std::memory_order_relaxed); 60 : } 61 : 62 5 : SpComputeEngine* getComputeEngine() const { 63 5 : return ce.load(std::memory_order_relaxed); 64 : } 65 : 66 1073 : void execute(SpAbstractTask *task) { 67 1073 : switch(this->getType()) { 68 1073 : case SpWorkerType::CPU_WORKER: 69 1073 : task->execute(SpCallableType::CPU); 70 1073 : break; 71 0 : case SpWorkerType::GPU_WORKER: 72 0 : task->execute(SpCallableType::GPU); 73 0 : break; 74 0 : default: 75 0 : assert(false && "Worker is of unknown type."); 76 : } 77 1073 : } 78 : 79 152 : void waitForThread() { 80 152 : t.join(); 81 152 : } 82 : 83 152 : void stop() { 84 152 : if(t.joinable()) { 85 0 : if(stopFlag.load(std::memory_order_relaxed)) { 86 : { 87 0 : std::unique_lock<std::mutex> workerLock(workerMutex); 88 0 : stopFlag.store(true, std::memory_order_relaxed); 89 : } 90 0 : workerConditionVariable.notify_one(); 91 : } 92 0 : waitForThread(); 93 : } 94 152 : } 95 : 96 153 : void bindTo(SpComputeEngine* inCe) { 97 153 : if(inCe) { 98 : { 99 306 : std::unique_lock workerLock(workerMutex); 100 153 : ce.store(inCe, std::memory_order_release); 101 : } 102 153 : workerConditionVariable.notify_one(); 103 : } 104 153 : } 105 : 106 4 : void idleWait() { 107 8 : std::unique_lock<std::mutex> workerLock(workerMutex); 108 12 : workerConditionVariable.wait(workerLock, [&]() { return stopFlag.load(std::memory_order_relaxed) || ce.load(std::memory_order_relaxed); }); 109 4 : } 110 : 111 : void waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg); 112 : 113 : friend class SpComputeEngine; 114 : 115 : public: 116 : 117 152 : explicit SpWorker(const SpWorkerType inWt) : 118 : wt(inWt), workerMutex(), workerConditionVariable(), 119 152 : stopFlag(false), ce(nullptr), threadId(0), t() { 120 152 : threadId = totalNbThreadsCreated.fetch_add(1, std::memory_order_relaxed); 121 152 : } 122 : 123 : SpWorker(const SpWorker& other) = delete; 124 : SpWorker(SpWorker&& other) = delete; 125 : SpWorker& operator=(const SpWorker& other) = delete; 126 : SpWorker& operator=(SpWorker&& other) = delete; 127 : 128 152 : ~SpWorker() { 129 152 : stop(); 130 152 : } 131 : 132 2103 : SpWorkerType getType() const { 133 2103 : return wt; 134 : } 135 : 136 : void start(); 137 : 138 : void doLoop(SpAbstractTaskGraph* inAtg); 139 : }; 140 : 141 : #endif