Skip to content

Commit

Permalink
Merge pull request scipy#476 from pv/special-binom
Browse files Browse the repository at this point in the history
BUG: special: fix up corner case accuracy in binom()
  • Loading branch information
rgommers committed May 5, 2013
2 parents a054445 + 5e25807 commit c5b9533
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 27 deletions.
2 changes: 1 addition & 1 deletion scipy/special/_testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def assert_func_equal(func, results, points, rtol=None, atol=None,
result_columns=result_columns, result_func=result_func,
rtol=rtol, atol=atol, param_filter=param_filter,
knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
ignore_inf_sign=False)
ignore_inf_sign=ignore_inf_sign)
fdata.check()

class FuncData(object):
Expand Down
81 changes: 71 additions & 10 deletions scipy/special/cephes/beta.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ extern int sgngam;
#define ASYMP_FACTOR 1e6

static double lbeta_asymp(double a, double b, int *sgn);
static double lbeta_negint(int a, double b);
static double beta_negint(int a, double b);
double gammasgn(double x);

double beta(a, b)
double a, b;
Expand All @@ -80,13 +83,25 @@ double a, b;
sign = 1;

if (a <= 0.0) {
if (a == floor(a))
goto over;
if (a == floor(a)) {
if (a == (int)a) {
return beta_negint((int)a, b);
}
else {
goto over;
}
}
}

if (b <= 0.0) {
if (b == floor(b))
goto over;
if (b == floor(b)) {
if (b == (int)b) {
return beta_negint((int)b, a);
}
else {
goto over;
}
}
}

if (fabs(a) < fabs(b)) {
Expand All @@ -100,7 +115,7 @@ double a, b;
}

y = a + b;
if (fabs(y) > MAXGAM) {
if (fabs(y) > MAXGAM || fabs(a) > MAXGAM || fabs(b) > MAXGAM) {
y = lgam(y);
sign *= sgngam; /* keep track of the sign */
y = lgam(b) - y;
Expand Down Expand Up @@ -143,13 +158,25 @@ double a, b;
sign = 1;

if (a <= 0.0) {
if (a == floor(a))
goto over;
if (a == floor(a)) {
if (a == (int)a) {
return lbeta_negint((int)a, b);
}
else {
goto over;
}
}
}

if (b <= 0.0) {
if (b == floor(b))
goto over;
if (b == floor(b)) {
if (b == (int)b) {
return lbeta_negint((int)b, a);
}
else {
goto over;
}
}
}

if (fabs(a) < fabs(b)) {
Expand All @@ -164,7 +191,7 @@ double a, b;
}

y = a + b;
if (fabs(y) > MAXGAM) {
if (fabs(y) > MAXGAM || fabs(a) > MAXGAM || fabs(b) > MAXGAM) {
y = lgam(y);
sign *= sgngam; /* keep track of the sign */
y = lgam(b) - y;
Expand Down Expand Up @@ -218,3 +245,37 @@ static double lbeta_asymp(double a, double b, int *sgn)

return r;
}


/*
* Special case for a negative integer argument
*/

static double beta_negint(int a, double b)
{
int sgn;
if (b == (int)b && 1 - a - b > 0) {
sgn = ((int)b % 2 == 0) ? 1 : -1;
return sgn * beta(1 - a - b, b);
}
else {
mtherr("lbeta", OVERFLOW);
return sgn*NPY_INFINITY;
}
}

static double lbeta_negint(int a, double b)
{
double r;
int sgn;
if (b == (int)b && 1 - a - b > 0) {
sgn = ((int)b % 2 == 0) ? 1 : -1;
r = lbeta(1 - a - b, b);
sgngam *= sgn;
return r;
}
else {
mtherr("lbeta", OVERFLOW);
return NPY_INFINITY;
}
}
37 changes: 32 additions & 5 deletions scipy/special/orthogonal_eval.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ References
# Direct evaluation of polynomials
#------------------------------------------------------------------------------
cimport cython
from libc.math cimport sqrt, exp, floor, fabs
from libc.math cimport sqrt, exp, floor, fabs, log, sin, M_PI as pi

from numpy cimport npy_cdouble
from _complexstuff cimport nan, inf, number_t
Expand All @@ -32,6 +32,7 @@ cdef extern from "cephes.h":
double Gamma(double x) nogil
double lgam(double x) nogil
double beta (double a, double b) nogil
double lbeta (double a, double b) nogil
double hyp2f1_wrap "hyp2f1" (double a, double b, double c, double x) nogil

cdef extern from "specfun_wrappers.h":
Expand Down Expand Up @@ -66,7 +67,7 @@ cdef inline number_t hyp1f1(double a, double b, number_t z) nogil:

@cython.cdivision(True)
cdef inline double binom(double n, double k) nogil:
cdef double kx, nx, num, den
cdef double kx, nx, num, den, dk, sgn
cdef int i

if n < 0:
Expand All @@ -76,16 +77,19 @@ cdef inline double binom(double n, double k) nogil:
return nan

kx = floor(k)
if k == kx:
if k == kx and (fabs(n) > 1e-8 or n == 0):
# Integer case: use multiplication formula for less rounding error
# for cases where the result is an integer.
#
# This cannot be used for small nonzero n due to loss of
# precision.

nx = floor(n)
if nx == n and kx > nx/2 and nx > 0:
# Reduce kx by symmetry
kx = nx - kx

if kx >= 1 and kx < 20:
if kx >= 0 and kx < 20:
num = 1.0
den = 1.0
for i in range(1, 1 + <int>kx):
Expand All @@ -97,7 +101,30 @@ cdef inline double binom(double n, double k) nogil:
return num/den

# general case:
return 1/beta(1 + n - k, 1 + k)/(n + 1)
if n >= 1e10*k and k > 0:
# avoid under/overflows in intermediate results
return exp(-lbeta(1 + n - k, 1 + k) - log(n + 1))
elif k > 1e8*fabs(n):
# avoid loss of precision
num = Gamma(1 + n) / fabs(k) + Gamma(1 + n) * n / (2*k**2) # + ...
num /= pi * fabs(k)**n
if k > 0:
kx = floor(k)
if <int>kx == kx:
dk = k - kx
sgn = 1 if (<int>kx) % 2 == 0 else -1
else:
dk = k
sgn = 1
return num * sin((dk-n)*pi) * sgn
else:
kx = floor(k)
if <int>kx == kx:
return 0
else:
return num * sin(k*pi)
else:
return 1/beta(1 + n - k, 1 + k)/(n + 1)

#-----------------------------------------------------------------------------
# Jacobi
Expand Down
1 change: 0 additions & 1 deletion scipy/special/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_binom(self):
nk,
atol=1e-10, rtol=1e-10)

@dec.knownfailureif(True, "beta function overflow bug for a >> b")
def test_binom_2(self):
# Test branches in implementation
np.random.seed(1234)
Expand Down
42 changes: 32 additions & 10 deletions scipy/special/tests/test_mpmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def test_beta():
b = np.r_[np.logspace(-200, 200, 4),
np.logspace(-10, 10, 4),
np.logspace(-1, 1, 4),
np.arange(-10, 11, 1),
np.arange(-10, 11, 1) + 0.5,
-1, -2.3, -3, -100.3, -10003.4]
a = b

Expand All @@ -275,7 +277,8 @@ def test_beta():
lambda a, b: float(mpmath.beta(a, b)),
ab,
vectorized=False,
rtol=1e-10)
rtol=1e-10,
ignore_inf_sign=True)

assert_func_equal(
sc.betaln,
Expand Down Expand Up @@ -314,7 +317,8 @@ def values(self, n):
n3 = max(2, n - n1 - n2)

v1 = np.linspace(-1, 1, n1)
v2 = np.linspace(-10, 10, n2)
v2 = np.r_[np.linspace(-10, 10, max(0, n2-4)),
-9, -5.5, 5.5, 9]
if self.a >= 0 and self.b > 0:
v3 = np.logspace(-30, np.log10(self.b), n3//2)
v4 = np.logspace(-30, 5, n3//2)
Expand Down Expand Up @@ -370,7 +374,8 @@ def values(self, n):

class MpmathData(object):
def __init__(self, scipy_func, mpmath_func, arg_spec, name=None,
dps=None, prec=None, n=5000, rtol=1e-7, atol=1e-300):
dps=None, prec=None, n=5000, rtol=1e-7, atol=1e-300,
ignore_inf_sign=False):
self.scipy_func = scipy_func
self.mpmath_func = mpmath_func
self.arg_spec = arg_spec
Expand All @@ -380,6 +385,7 @@ def __init__(self, scipy_func, mpmath_func, arg_spec, name=None,
self.rtol = rtol
self.atol = atol
self.is_complex = any([isinstance(arg, ComplexArg) for arg in self.arg_spec])
self.ignore_inf_sign = ignore_inf_sign
if not name or name == '<lambda>':
name = getattr(scipy_func, '__name__', None)
if not name or name == '<lambda>':
Expand Down Expand Up @@ -433,7 +439,8 @@ def pytype(x):
argarr,
vectorized=False,
rtol=self.rtol, atol=self.atol,
nan_ok=True)
nan_ok=True,
ignore_inf_sign=self.ignore_inf_sign)
break
except AssertionError:
if j >= len(dps_list)-1:
Expand Down Expand Up @@ -648,7 +655,8 @@ def test_besseli(self):
assert_mpmath_equal(sc.iv,
_exception_to_nan(lambda v, z: mpmath.besseli(v, z, **HYPERKW)),
[Arg(-1e100, 1e100), Arg()],
n=1000)
n=1000,
atol=1e-270)

def test_besseli_complex(self):
assert_mpmath_equal(lambda v, z: sc.iv(v.real, z),
Expand Down Expand Up @@ -694,28 +702,42 @@ def test_bessely(self):
[Arg(-1e100, 1e100), Arg()],
n=1000)

@knownfailure_overridable("sin(pi k) != sin_pi(k) at negative half-integer orders")
def test_bessely_complex(self):
assert_mpmath_equal(lambda v, z: sc.yv(v.real, z),
lambda v, z: _exception_to_nan(mpmath.bessely)(v, z, **HYPERKW),
[Arg(), ComplexArg()],
n=2000)

@knownfailure_overridable()
def test_beta(self):
def beta(a, b):
if a < -1e12 or b < -1e12:
# Function is defined here only at integers, but due
# to loss of precision this is numerically
# ill-defined. Don't compare values here.
return np.nan
return mpmath.beta(a, b)
assert_mpmath_equal(sc.beta,
mpmath.beta,
beta,
[Arg(), Arg()],
dps=400)
dps=400,
ignore_inf_sign=True)

def test_betainc(self):
assert_mpmath_equal(sc.betainc,
_exception_to_nan(lambda a, b, x: mpmath.betainc(a, b, 0, x, regularized=True)),
[Arg(), Arg(), Arg()])

@knownfailure_overridable()
def test_binom(self):
def binomial(n, k):
if abs(k) > 1e8*(abs(n) + 1):
# The binomial is rapidly oscillating in this region,
# and the function is numerically ill-defined. Don't
# compare values here.
return np.nan
return mpmath.binomial(n, k)
assert_mpmath_equal(sc.binom,
mpmath.binomial,
binomial,
[Arg(), Arg()],
dps=400)

Expand Down

0 comments on commit c5b9533

Please sign in to comment.