Line data Source code
1 : #ifndef SPCOMPUTEENGINE_HPP
2 : #define SPCOMPUTEENGINE_HPP
3 :
4 : #include <memory>
5 : #include <optional>
6 : #include <utility>
7 : #include <algorithm>
8 : #include <iterator>
9 : #include <atomic>
10 :
11 : #include "Compute/SpWorker.hpp"
12 : #include "Scheduler/SpPrioScheduler.hpp"
13 : #include "Utils/small_vector.hpp"
14 :
15 : class SpAbstractTaskGraph;
16 :
17 : class SpComputeEngine {
18 :
19 : private:
// Workers currently owned by this engine (ownership is held through unique_ptr).
20 : small_vector<std::unique_ptr<SpWorker>> workers;
// ceMutex is taken while the worker counters are read/updated and while a
// migration request is published in sendWorkersToInternal; ceCondVar is
// notified by wakeUpWaitingWorkers and when a migration request is posted.
21 : std::mutex ceMutex;
22 : std::condition_variable ceCondVar;
// Pair used by sendWorkersToInternal to block until migrationSignalingCounter
// is no longer positive; notifyMigrationFinished signals the condition variable.
23 : std::mutex migrationMutex;
24 : std::condition_variable migrationCondVar;
// Scheduler holding the tasks pushed through pushTask/pushTasks and popped by getTask.
25 : SpPrioScheduler prioSched;
// Set to the number of workers to detach when a migration request is posted;
// decremented through fetchDecNbOfWorkersToMigrate.
26 : std::atomic<long int> nbWorkersToMigrate;
// Set to the same count as nbWorkersToMigrate when a request is posted;
// decremented through fetchDecMigrationSignalingCounter. The requester waits
// until it is no longer positive.
27 : std::atomic<long int> migrationSignalingCounter;
// Worker type targeted by the current migration request.
28 : SpWorker::SpWorkerType workerTypeToMigrate;
// Destination engine of the current migration request (nullptr when the
// request comes from detachWorkers).
29 : SpComputeEngine* ceToMigrateTo;
// Per-type "available" worker counts — presumably workers currently waiting
// for work (the detach computation treats them as such); TODO confirm against wait().
30 : long int nbAvailableCpuWorkers;
31 : long int nbAvailableGpuWorkers;
// Per-type total worker counts, incremented when workers are added and
// decremented when workers are detached.
32 : long int totalNbCpuWorkers;
33 : long int totalNbGpuWorkers;
// Initialized to false; presumably set by stopIfNotAlreadyStopped (defined elsewhere).
34 : bool hasBeenStopped;
35 :
36 : private:
37 :
38 2 : auto sendWorkersToInternal(SpComputeEngine *otherCe, const SpWorker::SpWorkerType wt, const long int maxCount, const bool allowBusyWorkersToBeDetached) {
39 2 : small_vector<std::unique_ptr<SpWorker>> res;
40 : using iter_t = small_vector<std::unique_ptr<SpWorker>>::iterator;
41 :
42 : auto computeNbWorkersToDetach =
43 2 : [&]() {
44 : auto compute =
45 2 : [](long int nbTotal, long int nbWaiting, const bool allowBusyWorToBeDeta, const long int max) {
46 2 : if(allowBusyWorToBeDeta) {
47 2 : return std::min(nbTotal, max);
48 : } else {
49 0 : return std::min(nbWaiting, max);
50 : }
51 : };
52 2 : switch(wt) {
53 2 : case SpWorker::SpWorkerType::CPU_WORKER:
54 2 : return compute(totalNbCpuWorkers, nbAvailableCpuWorkers, allowBusyWorkersToBeDetached, maxCount);
55 0 : case SpWorker::SpWorkerType::GPU_WORKER:
56 0 : return compute(totalNbGpuWorkers, nbAvailableGpuWorkers, allowBusyWorkersToBeDetached, maxCount);
57 0 : default:
58 0 : return static_cast<long int>(0);
59 : }
60 2 : };
61 :
62 : const auto nbWorkersToDetach =
63 2 : [&]()
64 : {
65 2 : std::unique_lock<std::mutex> computeEngineLock(ceMutex);
66 :
67 2 : auto result = computeNbWorkersToDetach();
68 :
69 2 : if(result > 0) {
70 4 : workerTypeToMigrate = wt;
71 4 : ceToMigrateTo = otherCe;
72 2 : migrationSignalingCounter.store(result, std::memory_order_relaxed);
73 2 : nbWorkersToMigrate.store(result, std::memory_order_release);
74 : }
75 :
76 4 : return result;
77 2 : }();
78 :
79 2 : if(nbWorkersToDetach > 0) {
80 2 : ceCondVar.notify_all();
81 :
82 : {
83 4 : std::unique_lock<std::mutex> migrationLock(migrationMutex);
84 10 : migrationCondVar.wait(migrationLock, [&](){ return !(migrationSignalingCounter.load(std::memory_order_acquire) > 0); });
85 : }
86 :
87 2 : auto startIt = std::move_iterator<iter_t>(workers.begin());
88 2 : auto endIt = std::move_iterator<iter_t>(workers.end());
89 :
90 : auto eraseStartPosIt = std::remove_if(startIt, endIt,
91 5 : [&](std::unique_ptr<SpWorker>&& wPtr) {
92 5 : if(wPtr->getComputeEngine() != this) {
93 4 : res.push_back(std::move(wPtr));
94 4 : return true;
95 : } else {
96 1 : return false;
97 : }
98 2 : });
99 :
100 2 : workers.erase(eraseStartPosIt.base(), workers.end());
101 :
102 4 : std::unique_lock<std::mutex> computeEngineLock(ceMutex);
103 2 : updateWorkerCounters<true, true>(wt, -nbWorkersToDetach);
104 :
105 : }
106 :
107 4 : return res;
108 : }
109 :
110 : template <const bool bindAndStartWorkers>
111 54 : void addWorkersInternal(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers) {
112 210 : for(auto& w : inWorkers) {
113 156 : updateWorkerCounters<true,false>(w->getType(), +1);
114 : if constexpr(bindAndStartWorkers) {
115 153 : w->bindTo(this);
116 153 : w->start();
117 : }
118 : }
119 :
120 54 : if(workers.empty()) {
121 52 : workers = std::move(inWorkers);
122 : } else {
123 2 : workers.reserve(workers.size() + inWorkers.size());
124 2 : std::move(std::begin(inWorkers), std::end(inWorkers), std::back_inserter(workers));
125 : }
126 54 : }
127 :
128 2779 : bool areThereAnyWorkersToMigrate() const {
129 5558 : return nbWorkersToMigrate.load(std::memory_order_acquire) > 0;
130 : }
131 :
132 2772 : bool areThereAnyReadyTasks() const {
133 2772 : return prioSched.getNbTasks() > 0;
134 : }
135 :
136 4 : bool areWorkersToMigrateOfType(SpWorker::SpWorkerType inWt) {
137 4 : return workerTypeToMigrate == inWt;
138 : }
139 :
140 1073 : SpAbstractTask* getTask() {
141 1073 : return prioSched.pop();
142 : }
143 :
144 : template <const bool updateTotalCounter, const bool updateAvailableCounter>
145 1032 : void updateWorkerCounters(const SpWorker::SpWorkerType inWt, const long int addend) {
146 1032 : switch(inWt) {
147 1032 : case SpWorker::SpWorkerType::CPU_WORKER:
148 : if constexpr(updateTotalCounter) {
149 158 : totalNbCpuWorkers += addend;
150 : }
151 : if constexpr(updateAvailableCounter) {
152 876 : nbAvailableGpuWorkers += addend;
153 : }
154 1032 : break;
155 0 : case SpWorker::SpWorkerType::GPU_WORKER:
156 : if constexpr(updateTotalCounter) {
157 0 : totalNbGpuWorkers += addend;
158 : }
159 :
160 : if constexpr(updateAvailableCounter) {
161 0 : nbAvailableGpuWorkers += addend;
162 : }
163 0 : break;
164 0 : default:
165 0 : break;
166 : }
167 1032 : }
168 :
169 : void wait(SpWorker& worker, SpAbstractTaskGraph* atg);
170 :
171 4 : auto getCeToMigrateTo() {
172 4 : return ceToMigrateTo;
173 : }
174 :
175 4 : auto fetchDecNbOfWorkersToMigrate() {
176 4 : return nbWorkersToMigrate.fetch_sub(1, std::memory_order_relaxed);
177 : }
178 :
179 2 : void notifyMigrationFinished() {
180 : {
181 2 : std::unique_lock<std::mutex> migrationLock(migrationMutex);
182 : }
183 2 : migrationCondVar.notify_one();
184 2 : }
185 :
186 4 : auto fetchDecMigrationSignalingCounter() {
187 4 : return migrationSignalingCounter.fetch_sub(1, std::memory_order_release);
188 : }
189 :
190 : friend void SpWorker::waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg);
191 : friend void SpWorker::doLoop(SpAbstractTaskGraph* atg);
192 :
193 : public:
194 52 : explicit SpComputeEngine(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers)
195 52 : : workers(), ceMutex(), ceCondVar(), migrationMutex(), migrationCondVar(), prioSched(), nbWorkersToMigrate(0),
196 : migrationSignalingCounter(0), workerTypeToMigrate(SpWorker::SpWorkerType::CPU_WORKER), ceToMigrateTo(nullptr), nbAvailableCpuWorkers(0),
197 52 : nbAvailableGpuWorkers(0), totalNbCpuWorkers(0), totalNbGpuWorkers(0), hasBeenStopped(false) {
198 52 : addWorkers(std::move(inWorkers));
199 52 : }
200 :
201 : explicit SpComputeEngine() : SpComputeEngine(small_vector<std::unique_ptr<SpWorker>, 0>{}) {}
202 :
203 52 : ~SpComputeEngine() {
204 52 : stopIfNotAlreadyStopped();
205 52 : }
206 :
207 1069 : void pushTask(SpAbstractTask* t) {
208 1069 : prioSched.push(t);
209 1069 : wakeUpWaitingWorkers();
210 1069 : }
211 :
212 52 : void pushTasks(small_vector_base<SpAbstractTask*>& tasks) {
213 52 : prioSched.pushTasks(tasks);
214 52 : wakeUpWaitingWorkers();
215 52 : }
216 :
217 77 : size_t getCurrentNbOfWorkers() const {
218 77 : return workers.size();
219 : }
220 :
221 53 : void addWorkers(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers) {
222 53 : addWorkersInternal<true>(std::move(inWorkers));
223 53 : }
224 :
225 1 : void sendWorkersTo(SpComputeEngine& otherCe, const SpWorker::SpWorkerType wt, const size_t maxCount, const bool allowBusyWorkersToBeDetached) {
226 1 : SpComputeEngine* otherCePtr = std::addressof(otherCe);
227 :
228 1 : if(otherCePtr && otherCePtr != this) {
229 1 : otherCePtr->addWorkersInternal<false>(sendWorkersToInternal(otherCePtr, wt, maxCount, allowBusyWorkersToBeDetached));
230 : }
231 1 : }
232 :
233 1 : auto detachWorkers(const SpWorker::SpWorkerType wt, const size_t maxCount, const bool allowBusyWorkersToBeDetached) {
234 1 : return sendWorkersToInternal(nullptr, wt, maxCount, allowBusyWorkersToBeDetached);
235 : }
236 :
237 : void stopIfNotAlreadyStopped();
238 :
239 1219 : void wakeUpWaitingWorkers() {
240 : {
241 1219 : std::unique_lock<std::mutex> ceLock(ceMutex);
242 : }
243 1219 : ceCondVar.notify_all();
244 1219 : }
245 : };
246 :
247 : #endif
|