Line data Source code
///////////////////////////////////////////////////////////////////////////
// Thomas Millot (c), Unistra, 2020
// Under LGPL Licence, please you must read the LICENCE file.
///////////////////////////////////////////////////////////////////////////
#ifndef SPPHILOXGENERATOR_HPP
#define SPPHILOXGENERATOR_HPP

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <random>

// Implementation of the Philox algorithm to generate random numbers in parallel.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
// Also based on an implementation of the algorithm by Tensorflow.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/random/philox_random.h
// It's the Philox-4x32 version, meaning four words are produced per counter value.
// The number of cycles can be user defined, by default it's 10.
// The engine exposes result_type, min(), max() and operator(), i.e. the
// interface of the C++ named requirement UniformRandomBitGenerator, so it can
// be used with std::uniform_real_distribution for example.

// Exactly the same interface as SpMTGenerator, so it can be interchangeable.
template<class RealType = double>
class SpPhiloxGenerator {
    // Counter-based engine: the state is a 4-word counter plus a 2-word key.
    // Each counter value is turned into four output words by ExecuteRounds();
    // operator() hands these words out one at a time.
    class philox4x32 {
        // NOTE(review): uint_fast32_t is only guaranteed to be *at least* 32 bits
        // wide. Where it is wider, the static_casts to uint32 below do not
        // truncate, so the produced stream depends on the platform width.
        // Switching to exact-width words would also change max() and therefore
        // how many engine calls std::uniform_real_distribution performs per
        // value - check that against the assert in getRand01() before changing.
        typedef uint_fast32_t uint32;
        typedef uint_fast64_t uint64;

        static constexpr int DEFAULT_CYCLES = 10;

    public:
        // Type returned by operator(), required by UniformRandomBitGenerator.
        using result_type = uint32;

        // An array of four uint32, the results of the philox4 engine.
        using Result = std::array<uint32, 4>;

        // 64-bit seed stored in two uint32.
        using Key = std::array<uint32, 2>;

        // Default engine: thanks to the default member initializers this is
        // the same state as philox4x32(0).
        philox4x32() = default;

        // Builds an engine from a seed.
        //  seed   : split into the two key words (low part, then seed >> 32).
        //  cycles : number of rounds applied to the counter for each block.
        explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES)
            : cycles_(cycles)
        {
            // Splitting the seed in two; every other member keeps its
            // default member initializer (zeroed counter, empty cache).
            key_[0] = static_cast<uint32>(seed);
            key_[1] = static_cast<uint32>(seed >> 32);
        }

        // Returns the minimum value producible by the engine.
        static constexpr uint32 min() { return kMin; }

        // Returns the maximum value producible by the engine.
        static constexpr uint32 max() { return kMax; }

        // Skips the specified number of output values, as if operator() had
        // been called `count` times (the call counter is not affected).
        void Skip(uint64 count) {
            if (count == 0) {
                return;
            }

            // Values still unread in the current block of four.
            const auto nbStepsToNextMultipleOf4 = 4 - temp_counter_;

            if (count <= nbStepsToNextMultipleOf4) {
                // We stay inside the current block (possibly on its edge,
                // temp_counter_ == 4): only the read position moves.
                temp_counter_ += count;
                return;
            }

            count -= nbStepsToNextMultipleOf4;

            // We need to add 1 to the counter because we have moved past
            // all the 4 results from the current temp_results_ array. This
            // also includes the special case where we already are on the edge
            // (temp_counter_ == 4) but we haven't triggered a counter increment yet.
            // We can safely add 1 here (instead of calling SkipOne). It won't
            // cause any overflow since we are dividing the value of count by 4
            // and count has a width of 64 bits.
            const auto nbOfCounterIncrements = count / 4 + 1;

            temp_counter_ = count % 4;

            const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
            auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);

            // 128 bit add, with the carry detected through unsigned wrap-around.
            counter_[0] += count_lo;
            if (counter_[0] < count_lo) {
                ++count_hi;
            }

            counter_[1] += count_hi;
            if (counter_[1] < count_hi) {
                if (++counter_[2] == 0) {
                    ++counter_[3];
                }
            }

            // The cached block no longer matches the counter.
            force_computation_ = true;
        }

        // Returns a random number using the philox engine.
        uint32 operator()() {
            operatorPPcounter++;

            // Current block exhausted: move on to the next counter value.
            if (temp_counter_ == 4) {
                temp_counter_ = 0;
                SkipOne();
                force_computation_ = true;
            }

            // (Re)compute the block of four values for the current counter.
            if (force_computation_) {
                force_computation_ = false;
                temp_results_ = counter_;
                ExecuteRounds();
            }

            const uint32 value = temp_results_[temp_counter_];
            temp_counter_++;

            return value;
        }

        // Number of times operator() has been called on this engine.
        auto getOperatorPPCounter() const {
            return operatorPPcounter;
        }

    private:
        // Using the same constants as recommended in the original paper.
        static constexpr uint32 kPhiloxW32A = 0x9E3779B9;
        static constexpr uint32 kPhiloxW32B = 0xBB67AE85;
        static constexpr uint32 kPhiloxM4x32A = 0xD2511F53;
        static constexpr uint32 kPhiloxM4x32B = 0xCD9E8D57;

        // The minimum return value.
        static constexpr uint32 kMin = 0;
        // The maximum return value.
        static constexpr uint32 kMax = UINT_FAST32_MAX;

        // The counter for the current state of the engine.
        Result counter_{};

        // Keeping the last four results to improve performance during
        // consecutive calls.
        Result temp_results_{};

        // The split seed.
        Key key_{};

        // To iterate through temp_results_ (4 means "block exhausted").
        uint64 temp_counter_ = 0;

        // The number of cycles used to generate randomness.
        int cycles_ = DEFAULT_CYCLES;

        // To force the engine to compute the rounds that populate temp_results_.
        bool force_computation_ = true;

        // The number of times operator() is called, to ensure that the STL
        // always calls it once per generated value.
        uint32 operatorPPcounter = 0;

        // Skip one step: 128 bit increment of the counter.
        void SkipOne() {
            if (++counter_[0] == 0) {
                if (++counter_[1] == 0) {
                    if (++counter_[2] == 0) {
                        ++counter_[3];
                    }
                }
            }
        }

        // Helper function to return the lower and higher 32-bits from two
        // 32-bit integer multiplications.
        static void MultiplyHighLow(uint32 a, uint32 b, uint32 *result_low, uint32 *result_high) {
            const uint64 product = static_cast<uint64>(a) * b;
            *result_low = static_cast<uint32>(product);
            *result_high = static_cast<uint32>(product >> 32);
        }

        // Transforms temp_results_ (which holds a copy of the counter) into
        // four output values. key_ itself is left untouched: the rounds work
        // on a local copy that is bumped after every round.
        void ExecuteRounds() {
            Key key = key_;

            // Run the single round cycles_ times.
            for (int i = 0; i < cycles_; ++i) {
                temp_results_ = ComputeSingleRound(temp_results_, key);
                RaiseKey(&key);
            }
        }

        // Helper function for a single round of the underlying Philox algorithm.
        static Result ComputeSingleRound(const Result &counter, const Key &key) {
            uint32 lo0;
            uint32 hi0;
            MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);

            uint32 lo1;
            uint32 hi1;
            MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);

            Result result;
            result[0] = hi1 ^ counter[1] ^ key[0];
            result[1] = lo1;
            result[2] = hi0 ^ counter[3] ^ key[1];
            result[3] = lo0;
            return result;
        }

        // Advances the round key by the two Weyl constants.
        static void RaiseKey(Key *key) {
            (*key)[0] += kPhiloxW32A;
            (*key)[1] += kPhiloxW32B;
        }
    };

    philox4x32 phEngine;
    std::uniform_real_distribution<RealType> dis01;
    std::size_t nbValuesGenerated;

public:
    // Seeds the engine from std::random_device.
    explicit SpPhiloxGenerator() : phEngine(std::random_device()()), dis01(0, 1), nbValuesGenerated(0) {}

    // Seeds the engine with inSeed; two generators built from the same seed
    // go through the same engine states.
    explicit SpPhiloxGenerator(const std::size_t inSeed) : phEngine(inSeed), dis01(0, 1), nbValuesGenerated(0) {}

    SpPhiloxGenerator(const SpPhiloxGenerator &) = default;

    SpPhiloxGenerator(SpPhiloxGenerator &&) = default;

    SpPhiloxGenerator &operator=(const SpPhiloxGenerator &) = default;

    SpPhiloxGenerator &operator=(SpPhiloxGenerator &&) = default;

    // Advances the generator by inNbToSkip values without producing them and
    // adds inNbToSkip to the generated-values count. Returns *this.
    SpPhiloxGenerator &skip(const std::size_t inNbToSkip) {
        if (inNbToSkip == 0) {
            return *this;
        }

        phEngine.Skip(inNbToSkip);

        nbValuesGenerated += inNbToSkip;

        return *this;
    }

    // Draws one value from the (0, 1) uniform distribution.
    // The assert checks that the distribution consumed exactly one engine
    // value, which is what keeps skip(n) equivalent to n calls of getRand01().
    RealType getRand01() {
        nbValuesGenerated++;
        [[maybe_unused]] const auto counterOperatorPPBefore = phEngine.getOperatorPPCounter();
        const RealType number = dis01(phEngine);
        [[maybe_unused]] const auto counterOperatorPPAfter = phEngine.getOperatorPPCounter();
        assert(counterOperatorPPAfter == counterOperatorPPBefore + 1);
        return number;
    }

    // Number of values generated or skipped so far.
    std::size_t getNbValuesGenerated() const {
        return nbValuesGenerated;
    }
};


#endif