Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/include/rng.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace rng
namespace detail
{
using Generator = std::default_random_engine;
using ProcessGeneratorPtr = std::shared_ptr<Generator>;
ProcessGeneratorPtr process_generator;
using GeneratorPtr = std::shared_ptr<Generator>;
GeneratorPtr process_generator;

int get_process_seed()
{
Expand All @@ -25,14 +25,20 @@ int get_process_seed()
return 10000 * (rank + 1);
}

ProcessGeneratorPtr get_process_generator()
GeneratorPtr get_process_generator()
{
if (!process_generator) {
process_generator = std::make_shared<Generator>(get_process_seed());
}
return process_generator;
}

GeneratorPtr get_constant_generator(int seed)
{
// FIXME this doesn't need to be a shared pointer
return std::make_shared<Generator>(seed);
}

} // namespace detail

// ======================================================================
Expand All @@ -41,6 +47,10 @@ ProcessGeneratorPtr get_process_generator()
template <typename Real>
struct Uniform
{
Uniform(Real min, Real max, int seed)
: dist(min, max), gen(detail::get_constant_generator(seed))
{}

Uniform(Real min, Real max)
: dist(min, max), gen(detail::get_process_generator())
{}
Expand All @@ -50,7 +60,7 @@ struct Uniform
Real get() { return dist(*gen); }

private:
detail::ProcessGeneratorPtr gen;
detail::GeneratorPtr gen;
std::uniform_real_distribution<Real> dist;
};

Expand All @@ -60,9 +70,14 @@ private:
template <typename Real>
struct Normal
{
Normal(Real mean, Real stdev, int seed)
: dist(mean, stdev), gen(detail::get_constant_generator(seed))
{}

Normal(Real mean, Real stdev)
: dist(mean, stdev), gen(detail::get_process_generator())
{}

Normal() : Normal(0, 1) {}

Real get() { return dist(*gen); }
Expand All @@ -74,7 +89,7 @@ struct Normal
}

private:
detail::ProcessGeneratorPtr gen;
detail::GeneratorPtr gen;
std::normal_distribution<Real> dist;
};

Expand Down