Skip to content

Commit

Permalink
MAINT: adjust tolerance for validating the sum of probs in random.choice
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbrc committed Jul 29, 2015
1 parent a92c4a1 commit b6d0263
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
8 changes: 7 additions & 1 deletion numpy/random/mtrand/mtrand.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,12 @@ cdef class RandomState:

if p is not None:
d = len(p)

atol = np.sqrt(np.finfo(np.float64).eps)
if isinstance(p, np.ndarray):
if np.issubdtype(p.dtype, np.floating):
atol = max(atol, np.sqrt(np.finfo(p.dtype).eps))

p = <ndarray>PyArray_ContiguousFromObject(p, NPY_DOUBLE, 1, 1)
pix = <double*>PyArray_DATA(p)

Expand All @@ -1093,7 +1099,7 @@ cdef class RandomState:
raise ValueError("a and p must have same size")
if np.logical_or.reduce(p < 0):
raise ValueError("probabilities are not non-negative")
if abs(kahan_sum(pix, d) - 1.) > 1e-8:
if abs(kahan_sum(pix, d) - 1.) > atol:
raise ValueError("probabilities do not sum to 1")

shape = size
Expand Down
14 changes: 13 additions & 1 deletion numpy/random/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from numpy.testing import (TestCase, run_module_suite, assert_,
assert_array_equal)
assert_array_equal, assert_raises)
from numpy import random
from numpy.compat import long
import numpy as np
Expand Down Expand Up @@ -100,6 +100,18 @@ def test_beta_small_parameters(self):
x = np.random.beta(0.0001, 0.0001, size=100)
assert_(not np.any(np.isnan(x)), 'Nans in np.random.beta')

def test_choice_sum_of_probs_tolerance(self):
# The sum of probs should be 1.0 with some tolerance.
# For low precision dtypes the tolerance was too tight.
# See numpy github issue 6123.
np.random.seed(1234)
a = [1, 2, 3]
counts = [4, 4, 2]
for dt in np.float16, np.float32, np.float64:
probs = np.array(counts, dtype=dt) / sum(counts)
c = np.random.choice(a, p=probs)
assert_(c in a)
assert_raises(ValueError, np.random.choice, a, p=probs*0.9)

if __name__ == "__main__":
run_module_suite()

0 comments on commit b6d0263

Please sign in to comment.