Skip to content

Commit

Permalink
Math::RNG now uses C++11 functions, no longer needs GSL
Browse files Browse the repository at this point in the history
This introduces quite wide-ranging changes across the codebase. Will
need thorough checking.
  • Loading branch information
jdtournier committed Apr 10, 2015
1 parent 81e4ad4 commit f501679
Show file tree
Hide file tree
Showing 32 changed files with 188 additions and 201 deletions.
6 changes: 4 additions & 2 deletions cmd/dirflip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class Processor {
public:
Processor (Shared& shared) :
shared (shared),
signs (shared.get_init_signs()) { }
signs (shared.get_init_signs()),
uniform (0, signs.size()-1) { }

void execute () {
while (eval());
Expand All @@ -126,7 +127,7 @@ class Processor {

void next_permutation ()
{
signs[rng.uniform_int (signs.size())] *= -1;
signs[uniform(rng)] *= -1;
}

bool eval ()
Expand All @@ -145,6 +146,7 @@ class Processor {
Shared& shared;
std::vector<int> signs;
Math::RNG rng;
std::uniform_int_distribution<int> uniform;
};


Expand Down
5 changes: 3 additions & 2 deletions cmd/dirgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ void run () {
bipolar = false;

Math::RNG rng;
std::uniform_real_distribution<double> uniform (0.0, 1.0);
Math::Vector<double> v (2*ndirs);

for (size_t n = 0; n < 2*ndirs; n+=2) {
v[n] = Math::pi * (2.0 * rng.uniform() - 1.0);
v[n+1] = std::asin (2.0 * rng.uniform() - 1.0);
v[n] = Math::pi * (2.0 * uniform(rng) - 1.0);
v[n+1] = std::asin (2.0 * uniform(rng) - 1.0);
}

gsl_multimin_function_fdf fdf;
Expand Down
9 changes: 5 additions & 4 deletions cmd/dirsplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,14 @@ class EnergyCalculator {
void next_permutation ()
{
size_t i,j;
std::uniform_int_distribution<size_t> dist(0, subset.size()-1);
do {
i = rng.uniform_int (subset.size());
j = rng.uniform_int (subset.size());
i = dist (rng);
j = dist (rng);
} while (i == j);

size_t n_i = rng.uniform_int (subset[i].size());
size_t n_j = rng.uniform_int (subset[j].size());
size_t n_i = std::uniform_int_distribution<size_t> (0, subset[i].size()-1) (rng);
size_t n_j = std::uniform_int_distribution<size_t> (0, subset[j].size()-1) (rng);

std::swap (subset[i][n_i], subset[j][n_j]);
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/label2colour.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ void run ()

node_map.insert (std::make_pair (0, Node_info ("None", 0, 0, 0, 0)));
Math::RNG rng;
std::uniform_int_distribution<uint8_t> dist;

for (node_t i = 1; i <= max_index; ++i) {
Point<uint8_t> colour;
do {
colour[0] = rng.uniform_int (255);
colour[1] = rng.uniform_int (255);
colour[2] = rng.uniform_int (255);
colour[0] = dist (rng);
colour[1] = dist (rng);
colour[2] = dist (rng);
} while (int(colour[0]) + int(colour[1]) + int(colour[2]) < 100);
node_map.insert (std::make_pair (i, Node_info (str(i), colour)));
}
Expand Down
12 changes: 10 additions & 2 deletions cmd/mrcalc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,16 @@ inline Chunk& StackEntry::evaluate (ThreadLocalStorage& storage) const
if (evaluator) return evaluator->evaluate (storage);
if (rng) {
Chunk& chunk = storage.next();
for (size_t n = 0; n < chunk.size(); ++n)
chunk[n] = rng_gausssian ? rng->normal() : rng->uniform();
if (rng_gausssian) {
std::normal_distribution<real_type> dis (0.0, 1.0);
for (size_t n = 0; n < chunk.size(); ++n)
chunk[n] = dis (*rng);
}
else {
std::uniform_real_distribution<real_type> dis (0.0, 1.0);
for (size_t n = 0; n < chunk.size(); ++n)
chunk[n] = dis (*rng);
}
return chunk;
}
return storage.next();
Expand Down
3 changes: 2 additions & 1 deletion cmd/testing_gen_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ void run ()

struct fill {
Math::RNG rng;
void operator() (decltype(vox)& v) { v.value() = rng.normal(); }
std::normal_distribution<float> normal;
void operator() (decltype(vox)& v) { v.value() = normal(rng); }
};
Image::ThreadedLoop ("generating random data...", vox).run (fill(), vox);
}
Expand Down
114 changes: 52 additions & 62 deletions lib/math/rng.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,87 +23,77 @@
#ifndef __math_simulation_h__
#define __math_simulation_h__

#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <random>
#ifdef MRTRIX_WINDOWS
#include <sys/time.h>
#endif

#include "math/vector.h"
#include "mrtrix.h"

namespace MR
{
namespace Math
{

class RNG
//! random number generator
/*! this is a thin wrapper around the standard C++11 std::mt19937 random
* number generator. It can be used in combination with the standard C++11
* distributions. It differs from the standard in its constructors: the
* default constructor will seed using std::random_device, unless a seed
* has been expicitly passed using the MRTRIX_RNG_SEED environment
* variable. The copy constructor will seed itself using 1 + the last seed
* used - this ensures the seeds are unique across instances in
* multi-threading. */
class RNG : public std::mt19937
{
public:
RNG () {
generator = gsl_rng_alloc (gsl_rng_mt19937);
struct timeval tv;
gettimeofday (&tv, NULL);
set (tv.tv_sec ^ tv.tv_usec);
}
RNG (size_t seed) {
generator = gsl_rng_alloc (gsl_rng_mt19937);
set (seed);
}
RNG (const RNG& rng) {
generator = gsl_rng_alloc (gsl_rng_mt19937);
set (rng.get()+1);
}
~RNG () {
gsl_rng_free (generator);
}



void set (size_t seed) {
gsl_rng_set (generator, seed);
}

size_t get () const {
return gsl_rng_get (generator);
RNG () : std::mt19937 (get_seed()) { }
RNG (std::mt19937::result_type seed) : std::mt19937 (seed) { }
RNG (const RNG& rng) : std::mt19937 (get_seed()) { }
template <typename ValueType> class Uniform;
template <typename ValueType> class Normal;

static std::mt19937::result_type get_seed () {
static std::mt19937::result_type current_seed = get_seed_private();
return current_seed++;
}

private:
static std::mt19937::result_type get_seed_private () {
const char* from_env = getenv ("MRTRIX_RNG_SEED");
if (from_env)
return to<std::mt19937::result_type> (from_env);

gsl_rng* operator() () {
return generator;
}

float uniform () {
return gsl_rng_uniform (generator);
}
size_t uniform_int (size_t max) {
return gsl_rng_uniform_int (generator, max);
}
float normal (float SD = 1.0) {
return gsl_ran_gaussian (generator, SD);
}
float rician (float amplitude, float SD) {
amplitude += gsl_ran_gaussian_ratio_method (generator, SD);
float imag = gsl_ran_gaussian_ratio_method (generator, SD);
return sqrt (amplitude*amplitude + imag*imag);
}

template <typename T> void shuffle (Vector<T>& V) {
gsl_ran_shuffle (generator, V->ptr(), V.size(), sizeof (T));
}
template <class T> void shuffle (std::vector<T>& V) {
gsl_ran_shuffle (generator, &V[0], V.size(), sizeof (T));
#ifdef MRTRIX_WINDOWS
struct timeval tv;
gettimeofday (&tv, nullptr);
return tv.tv_sec ^ tv.tv_usec;
#else
// TODO check whether this does in fact work on Windows...
std::random_device rd;
return rd();
#endif
}


protected:
gsl_rng* generator;
};


inline float cauchy (float x, float s)
{
x /= s;
return (1.0 / (1.0 + x*x));
}

template <typename ValueType>
class RNG::Uniform {
public:
RNG rng;
std::uniform_real_distribution<ValueType> dist;
ValueType operator() () { return dist (rng); }
};

template <typename ValueType>
class RNG::Normal {
public:
RNG rng;
std::normal_distribution<ValueType> dist;
ValueType operator() () { return dist (rng); }
};

}
}
Expand Down
4 changes: 2 additions & 2 deletions src/dwi/directions/set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,13 @@ namespace MR {

void FastLookupSet::test_lookup() const
{

Math::RNG rng;
std::normal_distribution<float> normal (0.0, 1.0);

size_t error_count = 0;
const size_t checks = 1000000;
for (size_t i = 0; i != checks; ++i) {
Point<float> p (rng.normal(), rng.normal(), rng.normal());
Point<float> p (normal(rng), normal(rng), normal(rng));
p.normalise();
if (select_direction (p) != select_direction_slow (p))
++error_count;
Expand Down
2 changes: 1 addition & 1 deletion src/dwi/sdeconv/rf_estimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ Math::Matrix<float> ResponseEstimator::gen_rotation_matrix (const Point<float>&
// Here the other two axes are determined at random (but both are orthogonal to the FOD peak direction)
Math::Matrix<float> R (3, 3);
R (2, 0) = dir[0]; R (2, 1) = dir[1]; R (2, 2) = dir[2];
Point<float> vec2 (rng.uniform(), rng.uniform(), rng.uniform());
Point<float> vec2 (rng(), rng(), rng());
vec2 = dir.cross (vec2);
vec2.normalise();
R (0, 0) = vec2[0]; R (0, 1) = vec2[1]; R (0, 2) = vec2[2];
Expand Down
3 changes: 1 addition & 2 deletions src/dwi/sdeconv/rf_estimation.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ class ResponseEstimator
shared (csd_shared),
lmax (lmax),
output (output),
rng (),
mutex (new std::mutex()) { }

ResponseEstimator (const ResponseEstimator& that) :
Expand All @@ -297,7 +296,7 @@ class ResponseEstimator
const size_t lmax;
Response& output;

mutable Math::RNG rng;
mutable Math::RNG::Uniform<float> rng;

std::shared_ptr<std::mutex> mutex;

Expand Down
4 changes: 2 additions & 2 deletions src/dwi/tractography/SIFT/sifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ namespace MR
void SIFTer::test_sorting_block_size (const size_t num_tracks) const
{

Math::RNG rng;
Math::RNG::Normal<float> rng;

std::vector<Cost_fn_gradient_sort> gradient_vector;
gradient_vector.assign (num_tracks, Cost_fn_gradient_sort (num_tracks, 0.0, 0.0));
// Fill the gradient vector with random Gaussian data
for (track_t index = 0; index != num_tracks; ++index) {
const float value = rng.normal();
const float value = rng();
gradient_vector[index].set (index, value, value);
}

Expand Down
6 changes: 3 additions & 3 deletions src/dwi/tractography/SIFT/sifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

#include <vector>

#include "math/rng.h"
#include "image/buffer.h"
#include "image/header.h"
#include "dwi/fixel_map.h"
#include "dwi/directions/set.h"
#include "dwi/tractography/SIFT/fixel.h"
Expand All @@ -37,9 +40,6 @@
#include "dwi/tractography/SIFT/track_index_range.h"
#include "dwi/tractography/SIFT/types.h"

#include "image/buffer.h"
#include "image/header.h"



namespace MR
Expand Down
7 changes: 4 additions & 3 deletions src/dwi/tractography/algorithms/fact.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ namespace MR
if (!get_data (source)) return false;
if (!S.init_dir) {
if (!dir.valid())
dir.set (rng.normal(), rng.normal(), rng.normal());
} else {
dir = random_direction();
}
else
dir = S.init_dir;
}

return do_next (dir) >= S.threshold;
}

Expand Down
13 changes: 4 additions & 9 deletions src/dwi/tractography/algorithms/iFOD1.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,15 @@ namespace MR
const Point<Tracking::value_type> init_dir (dir);

for (size_t n = 0; n < S.max_seed_attempts; n++) {
if (init_dir.valid()) {
dir = rand_dir (init_dir);
} else {
dir.set (rng.normal(), rng.normal(), rng.normal());
dir.normalise();
}
dir = init_dir.valid() ? rand_dir (init_dir) : random_direction();
value_type val = FOD (dir);
if (std::isfinite (val))
if (val > S.init_threshold)
return true;
}

} else {

}
else {
dir = S.init_dir;
value_type val = FOD (dir);
if (std::isfinite (val))
Expand Down Expand Up @@ -212,7 +207,7 @@ namespace MR
max_truncation = val/max_val;
}

if (rng.uniform() < val/max_val) {
if (uniform_rng() < val/max_val) {
dir = new_dir;
dir.normalise();
pos += S.step_size * dir;
Expand Down
9 changes: 2 additions & 7 deletions src/dwi/tractography/algorithms/iFOD2.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,7 @@ namespace MR
const Point<float> init_dir (dir);

for (size_t n = 0; n < S.max_seed_attempts; n++) {
if (init_dir.valid()) {
dir = rand_dir (init_dir);
} else {
dir.set (rng.normal(), rng.normal(), rng.normal());
dir.normalise();
}
dir = init_dir.valid() ? rand_dir (init_dir) : random_direction();
half_log_prob0 = FOD (dir);
if (std::isfinite (half_log_prob0) && (half_log_prob0 > S.init_threshold))
goto end_init;
Expand Down Expand Up @@ -273,7 +268,7 @@ namespace MR
max_truncation = val/max_val;
}

if (rng.uniform() < val/max_val) {
if (uniform_rng() < val/max_val) {
mean_sample_num += n;
half_log_prob0 = last_half_log_probN;
pos = positions[0];
Expand Down
Loading

0 comments on commit f501679

Please sign in to comment.