forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
libcore: add N(0,1) and Exp(1) distributions to core::rand.
Sample from the normal and exponential distributions using the Ziggurat algorithm.
- Loading branch information
Showing
4 changed files
with
687 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
#!/usr/bin/env python | ||
# xfail-license | ||
|
||
# This creates the tables used for distributions implemented using the | ||
# ziggurat algorithm in `core::rand::distributions;`. They are | ||
# (basically) the tables as used in the ZIGNOR variant (Doornik 2005). | ||
# They are changed rarely, so the generated file should be checked in | ||
# to git. | ||
# | ||
# It creates 3 tables: X as in the paper, F which is f(x_i), and | ||
# F_DIFF which is f(x_i) - f(x_{i-1}). The latter two are just cached | ||
# values which is not done in that paper (but is done in other | ||
# variants). Note that the adZigR table is unnecessary because of | ||
# algebra. | ||
# | ||
# It is designed to be compatible with Python 2 and 3. | ||
|
||
from math import exp, sqrt, log, floor | ||
import random | ||
|
||
# The order should match the return value of `tables` | ||
TABLE_NAMES = ['X', 'F', 'F_DIFF'] | ||
|
||
# The actual length of the table is 1 more, to stop | ||
# index-out-of-bounds errors. This should match the bitwise operation | ||
# to find `i` in `zigurrat` in `libstd/rand/mod.rs`. Also the *_R and | ||
# *_V constants below depend on this value. | ||
TABLE_LEN = 256 | ||
|
||
# equivalent to `zigNorInit` in Doornik2005, but generalised to any | ||
# distribution. r = dR, v = dV, f = probability density function, | ||
# f_inv = inverse of f | ||
def tables(r, v, f, f_inv): | ||
# compute the x_i | ||
xvec = [0]*(TABLE_LEN+1) | ||
|
||
xvec[0] = v / f(r) | ||
xvec[1] = r | ||
|
||
for i in range(2, TABLE_LEN): | ||
last = xvec[i-1] | ||
xvec[i] = f_inv(v / last + f(last)) | ||
|
||
# cache the f's | ||
fvec = [0]*(TABLE_LEN+1) | ||
fdiff = [0]*(TABLE_LEN+1) | ||
for i in range(TABLE_LEN+1): | ||
fvec[i] = f(xvec[i]) | ||
if i > 0: | ||
fdiff[i] = fvec[i] - fvec[i-1] | ||
|
||
return xvec, fvec, fdiff | ||
|
||
# Distributions | ||
# N(0, 1) | ||
def norm_f(x): | ||
return exp(-x*x/2.0) | ||
def norm_f_inv(y): | ||
return sqrt(-2.0*log(y)) | ||
|
||
NORM_R = 3.6541528853610088 | ||
NORM_V = 0.00492867323399 | ||
|
||
NORM = tables(NORM_R, NORM_V, | ||
norm_f, norm_f_inv) | ||
|
||
# Exp(1) | ||
def exp_f(x): | ||
return exp(-x) | ||
def exp_f_inv(y): | ||
return -log(y) | ||
|
||
EXP_R = 7.69711747013104972 | ||
EXP_V = 0.0039496598225815571993 | ||
|
||
EXP = tables(EXP_R, EXP_V, | ||
exp_f, exp_f_inv) | ||
|
||
|
||
# Output the tables/constants/types | ||
|
||
def render_static(name, type, value): | ||
# no space or | ||
return 'pub static %s: %s =%s;\n' % (name, type, value) | ||
|
||
# static `name`: [`type`, .. `len(values)`] = | ||
# [values[0], ..., values[3], | ||
# values[4], ..., values[7], | ||
# ... ]; | ||
def render_table(name, values): | ||
rows = [] | ||
# 4 values on each row | ||
for i in range(0, len(values), 4): | ||
row = values[i:i+4] | ||
rows.append(', '.join('%.18f' % f for f in row)) | ||
|
||
rendered = '\n [%s]' % ',\n '.join(rows) | ||
return render_static(name, '[f64, .. %d]' % len(values), rendered) | ||
|
||
|
||
with open('ziggurat_tables.rs', 'w') as f: | ||
f.write('''// Copyright 2013 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// http://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
// Tables for distributions which are sampled using the ziggurat | ||
// algorithm. Autogenerated by `ziggurat_tables.py`. | ||
pub type ZigTable = &\'static [f64, .. %d]; | ||
''' % (TABLE_LEN + 1)) | ||
for name, tables, r in [('NORM', NORM, NORM_R), | ||
('EXP', EXP, EXP_R)]: | ||
f.write(render_static('ZIG_%s_R' % name, 'f64', ' %.18f' % r)) | ||
for (tabname, table) in zip(TABLE_NAMES, tables): | ||
f.write(render_table('ZIG_%s_%s' % (name, tabname), table)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// http://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! Sampling from random distributions | ||
// Some implementations use the Ziggurat method | ||
// https://en.wikipedia.org/wiki/Ziggurat_algorithm | ||
// | ||
// The version used here is ZIGNOR [Doornik 2005, "An Improved | ||
// Ziggurat Method to Generate Normal Random Samples"] which is slower | ||
// (about double, it generates an extra random number) than the | ||
// canonical version [Marsaglia & Tsang 2000, "The Ziggurat Method for | ||
// Generating Random Variables"], but more robust. If one wanted, one | ||
// could implement VIZIGNOR the ZIGNOR paper for more speed. | ||
|
||
use prelude::*; | ||
use rand::{Rng,Rand}; | ||
|
||
mod ziggurat_tables; | ||
|
||
// inlining should mean there is no performance penalty for this | ||
#[inline(always)] | ||
fn ziggurat<R:Rng>(rng: &R, | ||
center_u: bool, | ||
X: ziggurat_tables::ZigTable, | ||
F: ziggurat_tables::ZigTable, | ||
F_DIFF: ziggurat_tables::ZigTable, | ||
pdf: &'static fn(f64) -> f64, // probability density function | ||
zero_case: &'static fn(&R, f64) -> f64) -> f64 { | ||
loop { | ||
let u = if center_u {2.0 * rng.gen() - 1.0} else {rng.gen()}; | ||
let i: uint = rng.gen::<uint>() & 0xff; | ||
let x = u * X[i]; | ||
|
||
let test_x = if center_u {f64::abs(x)} else {x}; | ||
|
||
// algebraically equivalent to |u| < X[i+1]/X[i] (or u < X[i+1]/X[i]) | ||
if test_x < X[i + 1] { | ||
return x; | ||
} | ||
if i == 0 { | ||
return zero_case(rng, u); | ||
} | ||
// algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 | ||
if F[i+1] + F_DIFF[i+1] * rng.gen() < pdf(x) { | ||
return x; | ||
} | ||
} | ||
} | ||
|
||
/// A wrapper around an `f64` to generate N(0, 1) random numbers (a.k.a. a | ||
/// standard normal, or Gaussian). Multiplying the generated values by the | ||
/// desired standard deviation `sigma` then adding the desired mean `mu` will | ||
/// give N(mu, sigma^2) distributed random numbers. | ||
/// | ||
/// Note that this has to be unwrapped before use as an `f64` (using either | ||
/// `*` or `cast::transmute` is safe). | ||
/// | ||
/// # Example | ||
/// | ||
/// ~~~ | ||
/// use core::rand::distributions::StandardNormal; | ||
/// | ||
/// fn main() { | ||
/// let normal = 2.0 + (*rand::random::<StandardNormal>()) * 3.0; | ||
/// println(fmt!("%f is from a N(2, 9) distribution", normal)) | ||
/// } | ||
/// ~~~ | ||
pub struct StandardNormal(f64); | ||
|
||
impl Rand for StandardNormal { | ||
fn rand<R:Rng>(rng: &R) -> StandardNormal { | ||
#[inline(always)] | ||
fn pdf(x: f64) -> f64 { | ||
f64::exp((-x*x/2.0) as f64) as f64 | ||
} | ||
#[inline(always)] | ||
fn zero_case<R:Rng>(rng: &R, u: f64) -> f64 { | ||
// compute a random number in the tail by hand | ||
|
||
// strange initial conditions, because the loop is not | ||
// do-while, so the condition should be true on the first | ||
// run, they get overwritten anyway (0 < 1, so these are | ||
// good). | ||
let mut x = 1.0, y = 0.0; | ||
|
||
// XXX infinities? | ||
while -2.0*y < x * x { | ||
x = f64::ln(rng.gen()) / ziggurat_tables::ZIG_NORM_R; | ||
y = f64::ln(rng.gen()); | ||
} | ||
if u < 0.0 {x-ziggurat_tables::ZIG_NORM_R} else {ziggurat_tables::ZIG_NORM_R-x} | ||
} | ||
|
||
StandardNormal(ziggurat( | ||
rng, | ||
true, // this is symmetric | ||
&ziggurat_tables::ZIG_NORM_X, | ||
&ziggurat_tables::ZIG_NORM_F, &ziggurat_tables::ZIG_NORM_F_DIFF, | ||
pdf, zero_case)) | ||
} | ||
} | ||
|
||
/// A wrapper around an `f64` to generate Exp(1) random numbers. Dividing by | ||
/// the desired rate `lambda` will give Exp(lambda) distributed random | ||
/// numbers. | ||
/// | ||
/// Note that this has to be unwrapped before use as an `f64` (using either | ||
/// `*` or `cast::transmute` is safe). | ||
/// | ||
/// # Example | ||
/// | ||
/// ~~~ | ||
/// use core::rand::distributions::Exp1; | ||
/// | ||
/// fn main() { | ||
/// let exp2 = (*rand::random::<Exp1>()) * 0.5; | ||
/// println(fmt!("%f is from a Exp(2) distribution", exp2)); | ||
/// } | ||
/// ~~~ | ||
pub struct Exp1(f64); | ||
|
||
// This could be done via `-f64::ln(rng.gen::<f64>())` but that is slower. | ||
impl Rand for Exp1 { | ||
#[inline] | ||
fn rand<R:Rng>(rng: &R) -> Exp1 { | ||
#[inline(always)] | ||
fn pdf(x: f64) -> f64 { | ||
f64::exp(-x) | ||
} | ||
#[inline(always)] | ||
fn zero_case<R:Rng>(rng: &R, _u: f64) -> f64 { | ||
ziggurat_tables::ZIG_EXP_R - f64::ln(rng.gen()) | ||
} | ||
|
||
Exp1(ziggurat(rng, false, | ||
&ziggurat_tables::ZIG_EXP_X, | ||
&ziggurat_tables::ZIG_EXP_F, &ziggurat_tables::ZIG_EXP_F_DIFF, | ||
pdf, zero_case)) | ||
} | ||
} |
Oops, something went wrong.