LCOV - code coverage report
Current view: top level - Src/Random - SpPhiloxGenerator.hpp (source / functions) Hit Total Coverage
Test: Coverage example Lines: 79 85 92.9 %
Date: 2021-12-02 17:21:05 Functions: 13 13 100.0 %

          Line data    Source code
       1             : ///////////////////////////////////////////////////////////////////////////
       2             : // Thomas Millot (c), Unistra, 2020
       3             : // Under LGPL Licence, please you must read the LICENCE file.
       4             : ///////////////////////////////////////////////////////////////////////////
       5             : #ifndef SPPHILOXGENERATOR_HPP
       6             : #define SPPHILOXGENERATOR_HPP
       7             : #include <iostream>
       8             : #include <random>
       9             : #include <array>
      10             : 
      11             : // Implementation of the Philox algorithm to generate random numbers in parallel.
      12             : // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
      13             : // Also based on an implementation of the algorithm by Tensorflow.
      14             : // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/random/philox_random.h
      15             : // It's the Philox-4x32 version, meaning it produces four 32-bit random numbers at a time.
      16             : // The number of cycles can be user defined, by default it's 10.
      17             : // The engine provides the min()/max()/operator() interface of a UniformRandomBitGenerator,
      18             : // so it can be used with std::uniform_real_distribution for example.
      19             : 
      20             : // Exactly the same interface as SpMTGenerator, so the two generators are interchangeable.
      21             : template<class RealType = double>
      22             : class SpPhiloxGenerator {
      23             :     class philox4x32 {
      24             :         typedef uint_fast32_t uint32;
      25             :         typedef uint_fast64_t uint64;
      26             : 
      27             :         static constexpr int DEFAULT_CYCLES = 10;
      28             :     public:
      29             : 
      30             :         // An array of four uint32, the results of the philox4 engine
      31             :         using Result = std::array<uint32, 4>;
      32             : 
      33             :         // 64-bit seed stored in two uint32
      34             :         using Key = std::array<uint32, 2>;
      35             : 
      36             :         philox4x32() = default;
      37             : 
      38         404 :         explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES)
      39             :         : counter_(), temp_results_(), key_(), temp_counter_(0), cycles_(cycles),
      40        3232 :         force_computation_(true), operatorPPcounter(0)
      41             :         {
      42             :             // Splitting the seed in two
      43         404 :             key_[0] = static_cast<uint32>(seed);
      44         404 :             key_[1] = static_cast<uint32>(seed >> 32);
      45             : 
      46         404 :             counter_.fill(0);
      47         404 :             temp_results_.fill(0);
      48         404 :         }
      49             : 
      50             :         // Returns the minimum value producible by the engine
      51      287284 :         static constexpr uint32 min() { return _Min; }
      52             : 
      53             :         // Returns the maximum value producible by the engine
      54             :         static constexpr uint32 max() { return _Max; }
      55             : 
      56             :         // Skip the specified number of steps
      57      167284 :         void Skip(uint64 count) {
      58      167284 :             if(count > 0) {
      59             :                 
      60      167284 :                 const auto nbStepsToNextMultipleOf4 = 4 - temp_counter_;
      61             :                 
      62      167284 :                 if(count <= nbStepsToNextMultipleOf4) {
      63       36668 :                     temp_counter_ += count;
      64       36668 :                     return; 
      65             :                 }
      66             :                 
      67      130616 :                 count -= nbStepsToNextMultipleOf4;
      68             :                 
      69             :                 // We need to add 1 to the counter because we have moved past
      70             :                 // all the 4 results from the current temp_results_ array. This 
      71             :                 // also includes the special case where we already are on the edge
      72             :                 // (temp_counter_ == 4) but we haven't triggered a counter increment yet.
      73             :                 // We can safely add 1 here (instead of calling SkipOne). It won't cause any
      74             :                 // overflow since we are dividing the value of count by 4 and count has a
      75             :                 // width of 64 bits.
      76      130616 :                 const auto nbOfCounterIncrements = count / 4 + 1;
      77             :                 
      78      130616 :                 temp_counter_ = count % 4;
      79             :                 
      80      130616 :                 const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
      81      130616 :                 auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);
      82             :                 
      83             :                 // 128 bit add
      84             :                 
      85      130616 :                 counter_[0] += count_lo;
      86      130616 :                 if (counter_[0] < count_lo) {
      87           0 :                     ++count_hi;
      88             :                 }
      89             : 
      90      130616 :                 counter_[1] += count_hi;
      91      130616 :                 if (counter_[1] < count_hi) {
      92           0 :                     if (++counter_[2] == 0) {
      93           0 :                         ++counter_[3];
      94             :                     }
      95             :                 }
      96             :                 
      97      130616 :                 force_computation_ = true;
      98             :             }
      99             :         }
     100             : 
     101             :         // Returns a random number using the philox engine
     102      287284 :         uint32 operator()() {
     103      287284 :             operatorPPcounter++;
     104             : 
     105      287284 :             if(temp_counter_ == 4) {
     106       53312 :                 temp_counter_ = 0;
     107       53312 :                 SkipOne();
     108       53312 :                 force_computation_ = true;
     109             :             }
     110             :             
     111      287284 :             if(force_computation_) {
     112      183948 :                 force_computation_ = false;
     113      183948 :                 temp_results_ = counter_;
     114      183948 :                 ExecuteRounds();
     115             :             }
     116             : 
     117      287284 :             uint32 value = temp_results_[temp_counter_];
     118      287284 :             temp_counter_++;
     119             : 
     120      287284 :             return value;
     121             :         }
     122             : 
     123      574568 :         auto getOperatorPPCounter() const{
     124      574568 :             return operatorPPcounter;
     125             :         }
     126             : 
     127             :     private:
     128             : 
     129             :         // Using the same constants as recommended in the original paper.
     130             :         static constexpr uint32 kPhiloxW32A = 0x9E3779B9;
     131             :         static constexpr uint32 kPhiloxW32B = 0xBB67AE85;
     132             :         static constexpr uint32 kPhiloxM4x32A = 0xD2511F53;
     133             :         static constexpr uint32 kPhiloxM4x32B = 0xCD9E8D57;
     134             : 
     135             :         // The minimum return value
     136             :         static constexpr uint32 _Min = 0;
     137             :         // The maximum return value
     138             :         static constexpr uint32 _Max = UINT_FAST32_MAX;
     139             : 
     140             :         // The counter for the current state of the engine
     141             :         Result counter_;
     142             : 
     143             :         // Keeps the last computed results to improve performance during consecutive calls
     144             :         Result temp_results_;
     145             : 
     146             :         // The split seed
     147             :         Key key_;
     148             : 
     149             :         // To iterate through the temp_results_
     150             :         uint64 temp_counter_;
     151             : 
     152             :         // The number of cycles used to generate randomness
     153             :         int cycles_;
     154             : 
     155             :         // To force the engine to compute the rounds that populate temp_results_
     156             :         bool force_computation_;
     157             : 
     158             :         // The number of times operator() has been called, used to check that the STL
     159             :         // distribution calls it exactly once per generated value
     160             :         uint32 operatorPPcounter;
     161             : 
     162             :         // Skip one step
     163       53312 :         void SkipOne() {
     164             :             // 128 bit increment
     165       53312 :             if (++counter_[0] == 0) {
     166           0 :                 if (++counter_[1] == 0) {
     167           0 :                     if (++counter_[2] == 0) {
     168           0 :                         ++counter_[3];
     169             :                     }
     170             :                 }
     171             :             }
     172       53312 :         }
     173             : 
     174             :         // Helper function returning the lower and higher 32 bits of the product of two 32-bit integers.
     175     3678960 :         static void MultiplyHighLow(uint32 a, uint32 b, uint32 *result_low, uint32 *result_high) {
     176             : 
     177     3678960 :             const uint64 product = static_cast<uint64>(a) * b;
     178     3678960 :             *result_low = static_cast<uint32>(product);
     179     3678960 :             *result_high = static_cast<uint32>(product >> 32);
     180             : 
     181     3678960 :         }
     182             : 
     183      183948 :         void ExecuteRounds() {
     184             : 
     185      183948 :             Key key = key_;
     186             : 
     187             :             // Run the single round cycles_ times (10 by default).
     188     2023428 :             for (int i = 0; i < cycles_; ++i) {
     189     1839480 :                 temp_results_ = ComputeSingleRound(temp_results_, key);
     190     1839480 :                 RaiseKey(&key);
     191             :             }
     192      183948 :         }
     193             : 
     194             :         // Helper function for a single round of the underlying Philox algorithm.
     195     1839480 :         static Result ComputeSingleRound(const Result &counter, const Key &key) {
     196             :             uint32 lo0;
     197             :             uint32 hi0;
     198     1839480 :             MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
     199             : 
     200             :             uint32 lo1;
     201             :             uint32 hi1;
     202     1839480 :             MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
     203             : 
     204             :             Result result;
     205     1839480 :             result[0] = hi1 ^ counter[1] ^ key[0];
     206     1839480 :             result[1] = lo1;
     207     1839480 :             result[2] = hi0 ^ counter[3] ^ key[1];
     208     1839480 :             result[3] = lo0;
     209     1839480 :             return result;
     210             :         }
     211             : 
     212     1839480 :         void RaiseKey(Key *key) {
     213     1839480 :             (*key)[0] += kPhiloxW32A;
     214     1839480 :             (*key)[1] += kPhiloxW32B;
     215     1839480 :         }
     216             :     };
     217             : 
     218             :     philox4x32 phEngine;
     219             :     std::uniform_real_distribution<RealType> dis01;
     220             :     std::size_t nbValuesGenerated;
     221             : 
     222             : public:
     223             :     explicit SpPhiloxGenerator() : phEngine(std::random_device()()), dis01(0, 1), nbValuesGenerated(0) {}
     224             : 
     225         404 :     explicit SpPhiloxGenerator(const size_t inSeed) : phEngine(inSeed), dis01(0, 1), nbValuesGenerated(0) {}
     226             : 
     227             :     SpPhiloxGenerator(const SpPhiloxGenerator &) = default;
     228             : 
     229             :     SpPhiloxGenerator(SpPhiloxGenerator &&) = default;
     230             : 
     231             :     SpPhiloxGenerator &operator=(const SpPhiloxGenerator &) = default;
     232             : 
     233             :     SpPhiloxGenerator &operator=(SpPhiloxGenerator &&) = default;
     234             : 
     235      207284 :     SpPhiloxGenerator &skip(const size_t inNbToSkip) {
     236      207284 :         if(inNbToSkip == 0){
     237       40000 :             return *this;
     238             :         }
     239             : 
     240      167284 :         phEngine.Skip(inNbToSkip);
     241             :         
     242      167284 :         nbValuesGenerated += inNbToSkip;
     243             :         
     244      167284 :         return *this;
     245             :     }
     246             : 
     247      287284 :     RealType getRand01() {
     248      287284 :         nbValuesGenerated++;
     249      287284 :         [[maybe_unused]] const auto counterOperatorPPBefore = phEngine.getOperatorPPCounter();
     250      287284 :         const RealType number = dis01(phEngine);
     251      287284 :         [[maybe_unused]] const auto counterOperatorPPAfter = phEngine.getOperatorPPCounter();
     252      287284 :         assert(counterOperatorPPAfter == counterOperatorPPBefore+1);
     253      287284 :         return number;
     254             :     }
     255             : 
     256             :     size_t getNbValuesGenerated() const {
     257             :         return nbValuesGenerated;
     258             :     }
     259             : };
     260             : 
     261             : 
     262             : #endif

Generated by: LCOV version 1.14