Program Listing for File rand_distributions.h

Return to documentation for file (src/rand_distributions.h)

#ifndef PROJECT_RANDOM_H
#define PROJECT_RANDOM_H

#include <gmpxx.h>

#include <cassert>
#include <chrono> // NOLINT (build/c++11)
#include <climits>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>

#include "solver_config.h"
#include "primitive_types.h"
#include "sampler_tools.h"


class Random{
 public:
  static void init(SolverConfiguration * config) {
    master_config_ = config;
    gmp_randinit_mt(Random::Mpz::rand_state_);

    // create the generator seed for the random engine to reference
    long long int seed;
    if (master_config_->debug_mode) {
      if (!master_config_->quiet)
        std::cout << "WARNING: Debug mode has a fixed seed for the MPZ class.\n";
      seed = 0;
    } else {
      seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
    }
    mpz_class rand_seed;
    mpz_import(rand_seed.get_mpz_t(), 1, -1, sizeof(seed), 0, 0, &seed);
    gmp_randseed(Random::Mpz::rand_state_, rand_seed.get_mpz_t());

    if (master_config_->debug_mode) {
      if (!master_config_->quiet)
        std::cout << "WARNING: Debug mode uses a fixed random number seed." << std::endl;
    } else {
      Random::rng_ = std::mt19937(Random::rd_());
    }
  }
  template <typename T> inline static T uniform(T min, T max) {
    std::uniform_int_distribution<T> distribution(min, max);
    return distribution(rng_);
  }
  inline static int uniform(int min = INT_MIN, int max = INT_MAX) {
    if (min == INT_MIN && max == INT_MAX) {
      return uni_(rng_);
    } else {
      std::uniform_int_distribution<int> distribution(min, max);
      return distribution(rng_);
    }
  }

  static SampleSize binom(SampleSize n, double p) {
    std::binomial_distribution<> d(n, p);
    return static_cast<SampleSize>(d(rng_));
  }
  template<typename ListType, typename CountType>
  static void DownsampleList(CountType target_size, std::vector<ListType> &oversampled_vec,
                             bool resize = true) {
    assert(oversampled_vec.size() >= target_size);

    CountType end_point = oversampled_vec.size() - 1;
    while (end_point > target_size) {
      auto id_loc = Random::uniform<CountType>(0, end_point);
      oversampled_vec[id_loc] = oversampled_vec[end_point];  // Copy back & overwrite the used value
      end_point--;
    }
    if (resize)
      oversampled_vec.resize(target_size);
  }
  template<typename T>
  static void shuffle(std::vector<T> &vec) {
    std::shuffle(vec.begin(), vec.end(), rng_);
  }

  class Mpz {
    // Allow the Random class to access this class' private methods/fields
    friend class Random;
   public:
    inline static void uniform(mpz_class max_z, mpz_class &rand_val) {
      mpz_urandomm(rand_val.get_mpz_t(), Random::Mpz::rand_state_, max_z.get_mpz_t());
    }
    inline static SampleSize binom(SampleSize n, const mpz_class &t,
                                   const mpz_class &a) {
      assert(t > 0 && t >= a);
      // Handle the edge cases
      if (a == 0)
        return 0;
      else if (a == t)
        return n;

      if (use_approx_binom_)
        return Random::Mpz::binom_approx(n, t, a);
      else
        return Random::Mpz::binom_exact(n, t, a);
    }

   private:
    static bool use_approx_binom_;
    static gmp_randstate_t rand_state_;
    inline static SampleSize binom_approx(SampleSize n, const mpz_class &t,
                                          const mpz_class &a) {
      mpf_class a_mpf = a, t_mpf = t;
      mpf_class p_mpf = a_mpf / t_mpf;
      return Random::binom(n, p_mpf.get_d());
//      double p = mpz_get_d(a.get_mpz_t()) / mpz_get_d(t.get_mpz_t());
//      return Random::binom(n,p);
    }
    inline static SampleSize binom_exact(SampleSize n, const mpz_class &t,
                                         const mpz_class &a) {
      mpz_class rand_mpz = 0;
      SampleSize num_success = 0;
      for (SampleSize i = 0; i < n; i++) {
        Random::Mpz::uniform(t, rand_mpz);
        if (rand_mpz < a)
          num_success++;
      }
      return num_success;
    }
  };
  static void SelectRangeInts(SampleSize max_id, SampleSize num_elements,
                              std::vector<SampleSize> &samples_to_replace) {
    assert(max_id >= num_elements);
    samples_to_replace.clear();
    samples_to_replace.reserve(max_id);
    for (SampleSize i = 0; i < max_id; i++)
      samples_to_replace.emplace_back(i);
    //ToDo Modify this function to only ever do a maximum of max_id/2 random numbers
    Random::DownsampleList<SampleSize, SampleSize>(num_elements, samples_to_replace);
  }
  template<typename T>
  static void shuffle(typename std::vector<T>::iterator begin, // Dependent types so add the
                      typename std::vector<T>::iterator end) {  // "typename" keyword.
    std::shuffle(begin, end, rng_);
  }

 private:
  Random() = default;
  static std::uniform_int_distribution<int> uni_;
  static SolverConfiguration * master_config_;
  static std::random_device rd_;
  static std::mt19937 rng_;
};

#endif //PROJECT_RANDOM_H