Line data Source code
1 : /////////////////////////////////////////////////////////////////////////// 2 : // Spetabaru - Berenger Bramas MPCDF - 2017 3 : // Under LGPL Licence, please you must read the LICENCE file. 4 : /////////////////////////////////////////////////////////////////////////// 5 : 6 : #include <chrono> 7 : #include <future> 8 : #include <algorithm> 9 : #include <array> 10 : #include <memory> 11 : 12 : #include "UTester.hpp" 13 : #include "Data/SpDataAccessMode.hpp" 14 : #include "Utils/SpUtils.hpp" 15 : #include "Task/SpTask.hpp" 16 : #include "TaskGraph/SpTaskGraph.hpp" 17 : #include "Compute/SpComputeEngine.hpp" 18 : #include "Speculation/SpSpeculativeModel.hpp" 19 : #include "Compute/SpWorker.hpp" 20 : #include "Utils/small_vector.hpp" 21 : 22 : class ComputeEngineTest : public UTester< ComputeEngineTest > { 23 : using Parent = UTester< ComputeEngineTest >; 24 : 25 1 : void Test(){ 26 : 27 2 : SpTaskGraph<SpSpeculativeModel::SP_MODEL_1> tg1, tg2; 28 : 29 2 : std::array<small_vector<std::unique_ptr<SpWorker>, 2>, 2> workerVectors; 30 : 31 : auto generateFunc = 32 4 : []() { 33 4 : return std::make_unique<SpWorker>(SpWorker::SpWorkerType::CPU_WORKER); 34 : }; 35 : 36 3 : for(auto& workerVector : workerVectors) { 37 2 : workerVector.resize(2); 38 2 : std::generate(std::begin(workerVector), std::end(workerVector), generateFunc); 39 6 : for(auto& w : workerVector) { 40 4 : w->start(); 41 : } 42 : } 43 : 44 2 : SpComputeEngine ce1(std::move(workerVectors[0])), ce2(std::move(workerVectors[1])); 45 : 46 1 : int a = 0, b = 0; 47 : 48 2 : std::promise<bool> tg1Promise; 49 2 : std::promise<bool> mainThreadPromise; 50 : 51 0 : tg1.task(SpWrite(a), 52 1 : [&](int& inA) { 53 1 : mainThreadPromise.set_value(true); 54 1 : tg1Promise.get_future().get(); 55 1 : inA = 1; 56 4 : }); 57 : 58 1 : tg1.task(SpRead(a), SpWrite(b), 59 1 : [](const int& inA, int& inB) { 60 1 : inB = inA; 61 2 : }); 62 : 63 2 : std::array<std::promise<bool>, 4> promises; 64 : 65 4 : for(size_t i = 1; i < promises.size(); i++) { 66 : tg2.task( 67 6 : [&promises, i]() { 68 3 : promises[i].get_future().get(); 69 3 : promises[i-1].set_value(true); 70 3 : } 71 3 : ); 72 : } 73 : 74 1 : tg1.computeOn(ce1); 75 : 76 1 : mainThreadPromise.get_future().get(); 77 : 78 1 : auto workers = ce1.detachWorkers(SpWorker::SpWorkerType::CPU_WORKER, 1, true); 79 : 80 1 : tg1Promise.set_value(true); 81 : 82 1 : UASSERTEEQUAL(static_cast<int>(workers.size()), 1); 83 : 84 1 : tg2.computeOn(ce2); 85 : 86 1 : promises[promises.size()-1].set_value(true); 87 : 88 1 : ce2.addWorkers(std::move(workers)); 89 : 90 1 : tg2.waitAllTasks(); 91 : 92 1 : ce2.sendWorkersTo(ce1, SpWorker::SpWorkerType::CPU_WORKER, 3, true); 93 : 94 1 : tg1.waitAllTasks(); 95 : 96 1 : UASSERTEEQUAL(static_cast<int>(ce1.getCurrentNbOfWorkers()), 4); 97 1 : UASSERTEEQUAL(static_cast<int>(ce2.getCurrentNbOfWorkers()), 0); 98 : 99 1 : tg1.generateTrace("/tmp/taskgraph1.svg"); 100 1 : tg2.generateTrace("/tmp/taskgraph2.svg"); 101 : 102 1 : } 103 : 104 1 : void SetTests() { 105 1 : Parent::AddTest(&ComputeEngineTest::Test, "Compute engine test"); 106 1 : } 107 : }; 108 : 109 : // You must do this 110 1 : TestClass(ComputeEngineTest) 111 : 112 :