Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/include/rng.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,51 @@
#include <random>
#include <chrono>
#include <functional>
#include <memory>
#include "mpi.h"

namespace rng
{

namespace detail
{
using Generator = std::default_random_engine;
using ProcessGeneratorPtr = std::shared_ptr<Generator>;
ProcessGeneratorPtr process_generator;

int get_process_seed()
{
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return 10000 * (rank + 1);
}

ProcessGeneratorPtr get_process_generator()
{
if (!process_generator) {
process_generator = std::make_shared<Generator>(get_process_seed());
}
return process_generator;
}

} // namespace detail

// ======================================================================
// Uniform

template <typename Real>
struct Uniform
{
Uniform(Real min, Real max)
: dist(min, max),
gen(std::chrono::system_clock::now().time_since_epoch().count())
: dist(min, max), gen(detail::get_process_generator())
{}

Uniform() : Uniform(0, 1) {}

Real get() { return dist(gen); }
Real get() { return dist(*gen); }

private:
std::default_random_engine gen;
detail::ProcessGeneratorPtr gen;
std::uniform_real_distribution<Real> dist;
};

Expand All @@ -37,21 +61,20 @@ template <typename Real>
struct Normal
{
Normal(Real mean, Real stdev)
: dist(mean, stdev),
gen(std::chrono::system_clock::now().time_since_epoch().count())
: dist(mean, stdev), gen(detail::get_process_generator())
{}
Normal() : Normal(0, 1) {}

Real get() { return dist(gen); }
Real get() { return dist(*gen); }
// FIXME remove me, or make standalone func
Real get(Real mean, Real stdev)
{
// should be possible to pass params to existing dist
return std::normal_distribution<Real>(mean, stdev)(gen);
return std::normal_distribution<Real>(mean, stdev)(*gen);
}

private:
std::default_random_engine gen;
detail::ProcessGeneratorPtr gen;
std::normal_distribution<Real> dist;
};

Expand Down