Add betainc to JAX (jax-ml#1998)
Adds betainc, a wrapper for the regularized incomplete beta function (scipy.special.betainc). Because the new primitive takes three arguments, the binary-operator (binop) helpers in lax.py and masking.py are generalized into n-ary (naryop) versions.
srvasude authored and hawkinsp committed Jan 15, 2020
1 parent 12975bb commit 80b35dd
Showing 9 changed files with 98 additions and 53 deletions.
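For context, a minimal usage sketch of the new function (not part of the commit), assuming a JAX build that includes this change; values should agree with scipy.special.betainc up to float32 tolerance:

import numpy as np
import scipy.special
import jax.numpy as jnp
from jax.scipy import special as jsp_special

a, b = 2.0, 3.0
x = np.linspace(0.1, 0.9, 5)

# Regularized incomplete beta I_x(a, b), elementwise over x.
print(jsp_special.betainc(a, b, jnp.asarray(x)))
print(scipy.special.betainc(a, b, x))  # SciPy reference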
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -32,6 +32,7 @@ Operators
batch_matmul
bessel_i0e
bessel_i1e
betainc
bitcast_convert_type
bitwise_not
bitwise_and
1 change: 1 addition & 0 deletions docs/jax.scipy.rst
@@ -43,6 +43,7 @@ jax.scipy.special
.. autosummary::
:toctree: _autosummary

betainc
digamma
entr
erf
34 changes: 17 additions & 17 deletions jax/interpreters/masking.py
@@ -328,25 +328,25 @@ def vectorized_masking_rule(prim, padded_vals, logical_shapes, **params):
return prim.bind(padded_val, **params)


def defbinop(prim):
shape_rules[prim] = binop_shape_rule
masking_rules[prim] = partial(binop_masking_rule, prim)

def binop_shape_rule(shape_exprs):
x_shape_expr, y_shape_expr = shape_exprs
if x_shape_expr == y_shape_expr:
return x_shape_expr
elif not x_shape_expr:
return y_shape_expr
elif not y_shape_expr:
return x_shape_expr
else:
raise ShapeError
def defnaryop(prim):
shape_rules[prim] = naryop_shape_rule
masking_rules[prim] = partial(naryop_masking_rule, prim)

def binop_masking_rule(prim, padded_vals, logical_shapes):
def naryop_masking_rule(prim, padded_vals, logical_shapes):
del logical_shapes # Unused.
padded_x, padded_y = padded_vals
return prim.bind(padded_x, padded_y)
return prim.bind(*padded_vals)

# Assumes n > 1
def naryop_shape_rule(shape_exprs):
if shape_exprs.count(shape_exprs[0]) == len(shape_exprs):
return shape_exprs[0]

filtered_exprs = [s for s in shape_exprs if s]
if filtered_exprs.count(filtered_exprs[0]) == len(filtered_exprs):
return filtered_exprs[0]

raise ShapeError
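To illustrate the new shape rule in isolation, here is a self-contained sketch (not part of the commit): plain tuples stand in for the masking shape expressions, and ShapeError is a local stand-in for the error class used above.

class ShapeError(Exception):
  pass

def naryop_shape_rule(shape_exprs):
  # All operands share one shape expression: return it.
  if shape_exprs.count(shape_exprs[0]) == len(shape_exprs):
    return shape_exprs[0]
  # Otherwise scalars (empty shapes) are dropped and the rest must agree.
  filtered_exprs = [s for s in shape_exprs if s]
  if filtered_exprs.count(filtered_exprs[0]) == len(filtered_exprs):
    return filtered_exprs[0]
  raise ShapeError

print(naryop_shape_rule([('n',), ('n',), ('n',)]))  # ('n',)
print(naryop_shape_rule([(), ('n',), ()]))          # ('n',) -- scalar operands broadcast
# naryop_shape_rule([('n',), ('m',)]) would raise ShapeError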



### definition-time (import-time) shape checker tracer machinery
92 changes: 58 additions & 34 deletions jax/lax/lax.py
@@ -192,6 +192,13 @@ def atan2(x, y):
:math:`\mathrm{atan}({x \over y})`."""
return atan2_p.bind(x, y)

def betainc(a, b, x):
r"""Elementwise regularized incomplete beta integral."""
a = _brcast(_brcast(a, b), x)
b = _brcast(b, a)
x = _brcast(x, a)
return regularized_incomplete_beta_p.bind(a, b, x)
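A rough sanity check of the elementwise semantics (a sketch, not from the commit): with matching float32 operands, I_x(1, 1) reduces to x, and the symmetry I_x(a, b) = 1 - I_{1-x}(b, a) holds.

import jax.numpy as jnp
from jax import lax

x = jnp.array([0.25, 0.5, 0.75], dtype=jnp.float32)
ones = jnp.ones_like(x)

print(lax.betainc(ones, ones, x))  # I_x(1, 1) == x

a = jnp.full_like(x, 2.0)
b = jnp.full_like(x, 5.0)
print(lax.betainc(a, b, x))
print(1.0 - lax.betainc(b, a, 1.0 - x))  # same values by symmetry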

def lgamma(x):
r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`."""
return lgamma_p.bind(x)
@@ -1335,7 +1342,7 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
len_axis = operand.shape[axis]
start_index = start_index if start_index is not None else 0
limit_index = limit_index if limit_index is not None else len_axis

# translate negative indices
if start_index < 0:
start_index = start_index + len_axis
@@ -1564,7 +1571,7 @@ def unop(result_dtype, accepted_dtypes, name, translation_rule=None):
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)


def binop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
aval_dtypes = [aval.dtype for aval in avals]
for i, (aval_dtype, types) in enumerate(zip(aval_dtypes, accepted_dtypes)):
if not any(dtypes.issubdtype(aval_dtype, t) for t in types):
@@ -1593,23 +1600,23 @@ def _broadcasting_shape_rule(name, *avals):
return tuple(result_shape)


def binop(result_dtype, accepted_dtypes, name, translation_rule=None):
dtype_rule = partial(binop_dtype_rule, result_dtype, accepted_dtypes, name)
def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(_broadcasting_shape_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
translation_rule=translation_rule)
batching.defbroadcasting(prim)
masking.defbinop(prim)
masking.defnaryop(prim)
return prim
standard_binop = partial(binop, _input_dtype)
standard_naryop = partial(naryop, _input_dtype)


# NOTE(mattjj): this isn't great for orchestrate fwd mode because it means JVPs
# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
# a broadcast). but saving the shape info with the primitives isn't great either
# because then we can't trace these ops without shape data.
def _brcast(x, *others):
# Used in jvprules to make binop broadcasting explicit for transposability.
# Used in jvprules to make naryop broadcasting explicit for transposability.
# Requires shape info during jvp tracing, which isn't strictly necessary.
# We don't need full numpy broadcasting, but otherwise the logic is the same
# so we reuse the broadcast_shapes function after filtering out scalars.
@@ -1659,7 +1666,7 @@ def _sign_translation_rule(c, x):
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)

nextafter_p = standard_binop(
nextafter_p = standard_naryop(
[_float, _float], 'nextafter',
translation_rule=lambda c, x1, x2: c.NextAfter(x1, x2))

@@ -1696,11 +1703,28 @@ def _sign_translation_rule(c, x):
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))

atan2_p = standard_binop([_float, _float], 'atan2')
atan2_p = standard_naryop([_float, _float], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: _brcast(g, y) * (y / (square(x) + square(y))),
lambda g, x, y: _brcast(g, x) * -x / (square(x) + square(y)))

regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')

def betainc_gradx(g, a, b, x):
lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
partial_x = exp((b - 1) * log1p(-x) +
(a - 1) * log(x) - lbeta)
return partial_x * g

def betainc_grad_not_implemented(g, a, b, x):
raise ValueError("Betainc gradient with respect to a and b not supported.")

ad.defjvp(regularized_incomplete_beta_p,
betainc_grad_not_implemented,
betainc_grad_not_implemented,
betainc_gradx)
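As a sketch of what the JVP above encodes (not part of the commit): the x-derivative of the regularized incomplete beta is x^(a-1) (1-x)^(b-1) / B(a, b), which betainc_gradx evaluates in log space; gradients with respect to a and b deliberately raise.

import jax
import jax.numpy as jnp
from jax.scipy import special as jsp_special

a, b, x = 2.0, 5.0, 0.3

# Reverse-mode gradient in x, via the JVP rule defined above.
grad_x = jax.grad(jsp_special.betainc, argnums=2)(a, b, x)

# Closed form: x**(a-1) * (1-x)**(b-1) / B(a, b), computed in log space.
log_beta = (jsp_special.gammaln(a) + jsp_special.gammaln(b)
            - jsp_special.gammaln(a + b))
closed_form = jnp.exp((a - 1) * jnp.log(x) + (b - 1) * jnp.log1p(-x) - log_beta)

print(grad_x, closed_form)  # should agree to float32 precision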

lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))

@@ -1738,7 +1762,7 @@ def _bessel_i1e_jvp(g, y, x):
ad.defjvp(imag_p, lambda g, _: real(mul(_const(g, -1j), g)))

_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
complex_p = binop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
'complex')
ad.deflinear(complex_p, lambda t: [real(t), imag(neg(t))])

@@ -1775,7 +1799,7 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x:
_safe_mul(g, mul(_const(x, -0.5), pow(x, _const(x, -1.5)))))

pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')

def _pow_jvp_lhs(g, ans, x, y):
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
@@ -1792,20 +1816,20 @@ def _pow_jvp_rhs(g, ans, x, y):

not_p = standard_unop(_int | _bool, 'not')

and_p = standard_binop([_any, _any], 'and')
and_p = standard_naryop([_any, _any], 'and')
ad.defjvp_zero(and_p)

or_p = standard_binop([_any, _any], 'or')
or_p = standard_naryop([_any, _any], 'or')
ad.defjvp_zero(or_p)

xor_p = standard_binop([_any, _any], 'xor')
xor_p = standard_naryop([_any, _any], 'xor')
ad.defjvp_zero(xor_p)

def _add_transpose(t, x, y):
# assert x is ad.undefined_primal and y is ad.undefined_primal # not affine
return [t, t]

add_p = standard_binop([_num, _num], 'add')
add_p = standard_naryop([_num, _num], 'add')
ad.defjvp(add_p, lambda g, x, y: _brcast(g, y), lambda g, x, y: _brcast(g, x))
ad.primitive_transposes[add_p] = _add_transpose

@@ -1814,13 +1838,13 @@ def _sub_transpose(t, x, y):
assert x is ad.undefined_primal and y is ad.undefined_primal # not affine
return [t, neg(t) if t is not ad_util.zero else ad_util.zero]

sub_p = standard_binop([_num, _num], 'sub')
sub_p = standard_naryop([_num, _num], 'sub')
ad.defjvp(sub_p,
lambda g, x, y: _brcast(g, y),
lambda g, x, y: _brcast(neg(g), x))
ad.primitive_transposes[sub_p] = _sub_transpose

mul_p = standard_binop([_num, _num], 'mul')
mul_p = standard_naryop([_num, _num], 'mul')
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul)


@@ -1833,7 +1857,7 @@ def _safe_mul_translation_rule(c, x, y):
c.Broadcast(zero, out_shape),
c.Mul(x, y))

safe_mul_p = standard_binop([_num, _num], 'safe_mul',
safe_mul_p = standard_naryop([_num, _num], 'safe_mul',
translation_rule=_safe_mul_translation_rule)
ad.defbilinear_broadcasting(_brcast, safe_mul_p, _safe_mul, _safe_mul)

@@ -1842,13 +1866,13 @@ def _div_transpose_rule(cotangent, x, y):
assert x is ad.undefined_primal and y is not ad.undefined_primal
res = ad_util.zero if cotangent is ad_util.zero else div(cotangent, y)
return res, None
div_p = standard_binop([_num, _num], 'div')
div_p = standard_naryop([_num, _num], 'div')
ad.defjvp(div_p,
lambda g, x, y: div(_brcast(g, y), y),
lambda g, x, y: div(mul(neg(_brcast(g, x)), x), square(y)))
ad.primitive_transposes[div_p] = _div_transpose_rule

rem_p = standard_binop([_num, _num], 'rem')
rem_p = standard_naryop([_num, _num], 'rem')
ad.defjvp(rem_p,
lambda g, x, y: _brcast(g, y),
lambda g, x, y: mul(_brcast(neg(g), x), floor(div(x, y))))
@@ -1879,44 +1903,44 @@ def _minmax_translation_rule(c, x, y, minmax=None, cmp=None):
x, y)
return minmax(c)(x, y)

max_p = standard_binop([_any, _any], 'max', translation_rule=partial(
max_p = standard_naryop([_any, _any], 'max', translation_rule=partial(
_minmax_translation_rule, minmax=lambda c: c.Max, cmp=lambda c: c.Gt))
ad.defjvp2(max_p,
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))

min_p = standard_binop([_any, _any], 'min', translation_rule=partial(
min_p = standard_naryop([_any, _any], 'min', translation_rule=partial(
_minmax_translation_rule, minmax=lambda c: c.Min, cmp=lambda c: c.Lt))
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))


shift_left_p = standard_binop([_int, _int], 'shift_left')
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)

shift_right_arithmetic_p = standard_binop([_int, _int], 'shift_right_arithmetic')
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)

shift_right_logical_p = standard_binop([_int, _int], 'shift_right_logical')
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)

eq_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'eq')
eq_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'eq')
ad.defjvp_zero(eq_p)

ne_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'ne')
ne_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'ne')
ad.defjvp_zero(ne_p)

ge_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'ge')
ge_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'ge')
ad.defjvp_zero(ge_p)

gt_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'gt')
gt_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'gt')
ad.defjvp_zero(gt_p)

le_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'le')
le_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'le')
ad.defjvp_zero(le_p)

lt_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'lt')
lt_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'lt')
ad.defjvp_zero(lt_p)


@@ -1992,7 +2016,7 @@ def _conv_general_dilated_shape_rule(
def _conv_general_dilated_dtype_rule(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, **unused_kwargs):
return binop_dtype_rule(_input_dtype, [_float, _float],
return naryop_dtype_rule(_input_dtype, [_float, _float],
'conv_general_dilated', lhs, rhs)

_conv_spec_transpose = lambda spec: (spec[1], spec[0]) + spec[2:]
@@ -2185,7 +2209,7 @@ def _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision):


def _dot_general_dtype_rule(lhs, rhs, dimension_numbers, precision):
return binop_dtype_rule(_input_dtype, [_num, _num], 'dot_general', lhs, rhs)
return naryop_dtype_rule(_input_dtype, [_num, _num], 'dot_general', lhs, rhs)


def _dot_general_transpose_lhs(g, y, dimension_numbers, precision,
@@ -2391,7 +2415,7 @@ def _clamp_shape_rule(min, operand, max):
raise TypeError(m.format(max.shape))
return operand.shape

_clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
'clamp')

clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
4 changes: 2 additions & 2 deletions jax/lax_linalg.py
@@ -34,7 +34,7 @@
from jax.util import partial, prod
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule,
from jax.lax import (standard_primitive, standard_unop, naryop_dtype_rule,
_float, _complex, _input_dtype, _broadcasting_select)
from jax.lib import xla_client
from jax.lib import lapack
@@ -311,7 +311,7 @@ def eigh_batching_rule(batched_args, batch_dims, lower):


triangular_solve_dtype_rule = partial(
binop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
1 change: 1 addition & 0 deletions jax/lax_reference.py
@@ -60,6 +60,7 @@
sinh = onp.sinh
cosh = onp.cosh

betainc = scipy.special.betainc
lgamma = scipy.special.gammaln
digamma = scipy.special.digamma
erf = scipy.special.erf
6 changes: 6 additions & 0 deletions jax/scipy/special.py
@@ -39,6 +39,12 @@ def betaln(x, y):
return lax.lgamma(x) + lax.lgamma(y) - lax.lgamma(x + y)


@_wraps(osp_special.betainc)
def betainc(a, b, x):
a, b, x = _promote_args_inexact("betainc", a, b, x)
return lax.betainc(a, b, x)
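One consequence of the JVP rules worth noting (a sketch, not part of the commit): differentiating through this wrapper works for the x argument, while asking for gradients with respect to a or b raises the ValueError defined in lax.py.

import jax
from jax.scipy import special as jsp_special

a, b, x = 2.0, 5.0, 0.3

jax.grad(jsp_special.betainc, argnums=2)(a, b, x)    # d/dx: supported

try:
  jax.grad(jsp_special.betainc, argnums=0)(a, b, x)  # d/da: not implemented
except ValueError as err:
  print(err)  # "Betainc gradient with respect to a and b not supported."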


@_wraps(osp_special.digamma, update_doc=False)
def digamma(x):
x, = _promote_args_inexact("digamma", x)
6 changes: 6 additions & 0 deletions tests/lax_scipy_test.py
@@ -30,6 +30,7 @@
import scipy.stats as osp_stats

from jax import api
from jax import lib
from jax import test_util as jtu
from jax.scipy import special as lsp_special
from jax.scipy import stats as lsp_stats
@@ -76,6 +77,11 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, test_name=None):
op_record("entr", 1, float_dtypes, jtu.rand_default, False),
]

if lib.version > (0, 1, 37):
JAX_SPECIAL_FUNCTION_RECORDS.append(
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False)
)
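The record added above routes betainc through the generic special-function test against SciPy. A standalone spot-check equivalent to what that harness does (a sketch, with made-up tolerances):

import numpy as np
import scipy.special as osp_special
from jax.scipy import special as lsp_special

rng = np.random.RandomState(0)
a = (rng.rand(10) + 0.5).astype(np.float32)
b = (rng.rand(10) + 0.5).astype(np.float32)
x = rng.rand(10).astype(np.float32)

np.testing.assert_allclose(lsp_special.betainc(a, b, x),
                           osp_special.betainc(a, b, x), rtol=1e-4)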

CombosWithReplacement = itertools.combinations_with_replacement

