diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index bd966c794e4a..28f151011015 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -198,10 +198,11 @@ def zeros_like_array(x): dtype = dtypes.canonicalize_dtype(dtypes.result_type(x)) return onp.broadcast_to(onp.array(0, dtype), onp.shape(x)) -array_types = {onp.ndarray, onp.float64, onp.float32, onp.float16, +array_types = {onp.ndarray, onp.bool_, + onp.int8, onp.int16, onp.int32, onp.int64, + onp.uint8, onp.uint16, onp.uint32, onp.uint64, + dtypes.bfloat16, onp.float16, onp.float32, onp.float64, onp.complex64, onp.complex128, - onp.int64, onp.int32, onp.int16, onp.int8, - onp.bool_, onp.uint64, onp.uint32, onp.uint16, onp.uint8, onp.longlong} for t in array_types: diff --git a/jax/dtypes.py b/jax/dtypes.py index 6ddeacae03dd..efd2547ed20f 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -12,24 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Array type functions. +# +# JAX dtypes differ from NumPy in both: +# a) their type promotion rules, and +# b) the set of supported types (e.g., bfloat16), +# so we need our own implementation that deviates from NumPy in places. + from __future__ import absolute_import from __future__ import division from __future__ import print_function from distutils.util import strtobool +import functools import os import numpy as onp import six -from .config import flags from . import util +from .config import flags +from .lib import xla_client FLAGS = flags.FLAGS flags.DEFINE_bool('jax_enable_x64', strtobool(os.getenv('JAX_ENABLE_X64', 'False')), 'Enable 64-bit types to be used.') +# bfloat16 support +bfloat16 = xla_client.bfloat16 +_bfloat16_dtype = onp.dtype(bfloat16) + +class _bfloat16_finfo(object): + bits = 16 + eps = bfloat16(float.fromhex("0x1p-7")) + epsneg = bfloat16(float.fromhex("0x1p-8")) + machep = -7 + negep = -8 + max = bfloat16(float.fromhex("0x1.FEp127")) + min = -max + nexp = 8 + nmant = 7 + iexp = nexp + precision = 2 + resolution = 10 ** -2 + tiny = bfloat16(float.fromhex("0x1p-126")) # Default types. @@ -96,12 +123,56 @@ def coerce_to_array(x): return onp.array(x, dtype) if dtype else onp.array(x) iinfo = onp.iinfo -finfo = onp.finfo + +def finfo(dtype): + # Since NumPy doesn't consider bfloat16 a floating-point type, we have to + # provide an alternative implementation of finfo that does so. 
+ if onp.result_type(dtype) == _bfloat16_dtype:
+ return _bfloat16_finfo
+ else:
+ return onp.finfo(dtype)
+
+
+def issubdtype(a, b):
+ if a == bfloat16:
+ return b in [onp.floating, onp.inexact, onp.number]
+ return onp.issubdtype(a, b)
can_cast = onp.can_cast
-issubdtype = onp.issubdtype
issubsctype = onp.issubsctype
-promote_types = onp.promote_types
+
+_bfloat16_type_promotions = {
+ onp.dtype('bool'): onp.dtype(bfloat16),
+ onp.dtype(bfloat16): onp.dtype(bfloat16),
+ onp.dtype('float16'): onp.dtype('float32'),
+ onp.dtype('float32'): onp.dtype('float32'),
+ onp.dtype('float64'): onp.dtype('float64'),
+ onp.dtype('complex64'): onp.dtype('complex64'),
+ onp.dtype('complex128'): onp.dtype('complex128'),
+ onp.dtype('int8'): onp.dtype(bfloat16),
+ onp.dtype('int16'): onp.dtype('float32'),
+ onp.dtype('int32'): onp.dtype('float64'),
+ onp.dtype('int64'): onp.dtype('float64'),
+ onp.dtype('uint8'): onp.dtype(bfloat16),
+ onp.dtype('uint16'): onp.dtype('float32'),
+ onp.dtype('uint32'): onp.dtype('float64'),
+ onp.dtype('uint64'): onp.dtype('float64'),
+}
+
+def promote_types(a, b):
+ a = onp.dtype(a)
+ b = onp.dtype(b)
+ if b == _bfloat16_dtype:
+ a, b = b, a
+
+ if a == _bfloat16_dtype:
+ try:
+ return _bfloat16_type_promotions[b]
+ except KeyError:
+ raise TypeError("invalid type promotion between bfloat16 and {}"
+ .format(b))
+
+ return onp.promote_types(a, b)
def is_python_scalar(x):
@@ -138,4 +209,4 @@ def result_type(*args):
(scalars if is_python_scalar(x) else dtypes).append(dtype(x))
array_priority = max(map(_dtype_priority, dtypes)) if dtypes else -1
dtypes += [x for x in scalars if _dtype_priority(x) > array_priority]
- return canonicalize_dtype(onp.result_type(*dtypes))
\ No newline at end of file
+ return canonicalize_dtype(functools.reduce(promote_types, dtypes))
diff --git a/jax/lax/lax.py b/jax/lax/lax.py
index d3f1b5d1970e..f74c0c966e0c 100644
--- a/jax/lax/lax.py
+++ b/jax/lax/lax.py
@@ -351,17 +351,23 @@ def convert_element_type(operand, new_dtype):
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
old_dtype = dtypes.canonicalize_dtype(_dtype(operand))
- if old_dtype != new_dtype:
- if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
- not dtypes.issubdtype(new_dtype, onp.complexfloating)):
- msg = "Casting complex values to real discards the imaginary part"
- warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
- operand = real(operand)
- old_dtype = _dtype(operand)
- return convert_element_type_p.bind(
- operand, new_dtype=new_dtype, old_dtype=old_dtype)
- else:
+ if old_dtype == new_dtype:
return operand
+ if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
+ not dtypes.issubdtype(new_dtype, onp.complexfloating)):
+ msg = "Casting complex values to real discards the imaginary part"
+ warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
+ operand = real(operand)
+ old_dtype = _dtype(operand)
+ # TODO(b/143311238, b/142974574): work around bfloat16 conversion bugs by
+ # introducing an intermediate cast via float32.
+ if ((old_dtype == dtypes.bfloat16 and new_dtype != onp.float32) or
+ (new_dtype == dtypes.bfloat16 and old_dtype != onp.float32)):
+ operand = convert_element_type_p.bind(
+ operand, new_dtype=onp.float32, old_dtype=old_dtype)
+ old_dtype = onp.float32
+ return convert_element_type_p.bind(
+ operand, new_dtype=new_dtype, old_dtype=old_dtype)
def bitcast_convert_type(operand, new_dtype):
"""Elementwise bitcast.
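For reference, here is a minimal sketch (not part of the patch) of how the promotion table above is intended to behave once this change lands, assuming it is exposed as jax.numpy.promote_types and jax.numpy.bfloat16 as in the lax_numpy.py hunks later in this diff:

# Illustrative only; expected results follow _bfloat16_type_promotions above.
import numpy as onp
import jax.numpy as jnp

assert jnp.promote_types(jnp.bfloat16, jnp.bfloat16) == onp.dtype(jnp.bfloat16)
# bool and small integer types promote to bfloat16 itself.
assert jnp.promote_types(jnp.bfloat16, onp.int8) == onp.dtype(jnp.bfloat16)
# Neither 16-bit float format can represent the other, so they meet at float32.
assert jnp.promote_types(jnp.bfloat16, onp.float16) == onp.dtype('float32')
# Wider integer types promote to a wider float (int32 -> float64 per the table).
assert jnp.promote_types(jnp.bfloat16, onp.int32) == onp.dtype('float64')
# Pairs that do not involve bfloat16 still defer to onp.promote_types.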
@@ -1377,7 +1383,19 @@ def reciprocal(x): r"""Elementwise reciprocal: :math:`1 \over x`.""" return div(_const(x, 1), x) +def _upcast_fp16_for_computation(f): + @functools.wraps(f) + def f_wrapped(x): + dtype = _dtype(x) + if dtype == onp.float16 or dtype == dtypes.bfloat16: + return convert_element_type( + f(convert_element_type(x, onp.float32)), dtype) + return f(x) + + return f_wrapped + @api.jit +@_upcast_fp16_for_computation def tan(x): r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" return div(sin(x), cos(x)) @@ -1401,17 +1419,6 @@ def atan(x): r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.""" return atan2(x, _const(x, 1)) -def _upcast_fp16_for_computation(f): - @functools.wraps(f) - def f_wrapped(x): - dtype = _dtype(x) - if dtype == onp.float16: - return convert_element_type( - f(convert_element_type(x, onp.float32)), dtype) - return f(x) - - return f_wrapped - @api.jit @_upcast_fp16_for_computation def sinh(x): @@ -1586,7 +1593,7 @@ def _brcast_to(x, shape): return broadcast(x, shape) -_float = {onp.floating} +_float = {onp.floating, dtypes.bfloat16} _complex = {onp.complexfloating} _complex_elem_types = {onp.float32, onp.float64} _int = {onp.integer} diff --git a/jax/lax_reference.py b/jax/lax_reference.py index 4dd5327dec84..6fda3d5365f9 100644 --- a/jax/lax_reference.py +++ b/jax/lax_reference.py @@ -25,6 +25,8 @@ from six.moves import builtins +from . import dtypes + _slice = builtins.slice _max = builtins.max _min = builtins.min @@ -88,7 +90,7 @@ def complex(x, y): mul = onp.multiply def div(lhs, rhs): - if onp.issubdtype(onp.result_type(lhs), onp.integer): + if dtypes.issubdtype(dtypes.result_type(lhs), onp.integer): quotient = onp.floor_divide(lhs, rhs) select = onp.logical_and(onp.sign(lhs) != onp.sign(rhs), onp.remainder(lhs, rhs) != 0) @@ -176,7 +178,11 @@ def dot_general(lhs, rhs, dimension_numbers): not_none = lambda x: x is not None out_axis_ids = filter(not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids) - return onp.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids) + assert lhs.dtype == rhs.dtype + dtype = onp.float32 if lhs.dtype == dtypes.bfloat16 else None + out = onp.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids, + dtype=dtype) + return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out def broadcast(operand, sizes): return onp.broadcast_to(operand, sizes + onp.shape(operand)) @@ -352,15 +358,15 @@ def _make_reducer(py_binop, init_val): monoid_record = _monoids.get(getattr(py_binop, '__name__')) if monoid_record: reducer, monoid_identity = monoid_record - if init_val == monoid_identity(onp.result_type(init_val)): + if init_val == monoid_identity(dtypes.result_type(init_val)): return reducer return _reducer_from_pyfunc(py_binop, init_val) def _get_max_identity(dt): - return -onp.inf if onp.issubdtype(dt, onp.floating) else onp.iinfo(dt).min + return -onp.inf if dtypes.issubdtype(dt, onp.floating) else onp.iinfo(dt).min def _get_min_identity(dt): - return onp.inf if onp.issubdtype(dt, onp.floating) else onp.iinfo(dt).max + return onp.inf if dtypes.issubdtype(dt, onp.floating) else onp.iinfo(dt).max def _identity_getter(op): return lambda dtype: onp.asarray(op.identity, dtype=dtype) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index c82bafb05104..ad265dc1d8a1 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -134,6 +134,7 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, int16 = onp.int16 int32 = onp.int32 int64 = onp.int64 +bfloat16 = dtypes.bfloat16 
float16 = onp.float16 float32 = single = onp.float32 float64 = double = onp.float64 @@ -152,10 +153,8 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, unsignedinteger = onp.unsignedinteger iinfo = dtypes.iinfo -finfo = dtypes.finfo can_cast = dtypes.can_cast -issubdtype = dtypes.issubdtype issubsctype = dtypes.issubsctype result_type = dtypes.result_type promote_types = dtypes.promote_types @@ -338,6 +337,14 @@ def _canonicalize_axis(axis, num_dims): ### implementations of numpy functions in terms of lax +@_wraps(onp.finfo) +def finfo(dtype): return dtypes.finfo(dtype) + +@_wraps(onp.issubdtype) +def issubdtype(arg1, arg2): return dtypes.issubdtype(arg1, arg2) + +issubdtype = dtypes.issubdtype + @_wraps(onp.isscalar) def isscalar(num): return dtypes.is_python_scalar(num) or onp.isscalar(num) @@ -1143,9 +1150,12 @@ def nan_to_num(x, copy=True): ### Reducers -def _make_reduction(np_fun, op, init_val, preproc=None): +def _make_reduction(np_fun, op, init_val, preproc=None, bool_op=None, + upcast_f16_for_computation=False): """Creates reduction function given a binary operation and monoid identity.""" + bool_op = bool_op or op + @_wraps(np_fun) def reduction(a, axis=None, dtype=None, out=None, keepdims=False): if out is not None: @@ -1154,16 +1164,18 @@ def reduction(a, axis=None, dtype=None, out=None, keepdims=False): a = a if isinstance(a, ndarray) else asarray(a) a = preproc(a) if preproc else a dims = _reduction_dims(a, axis) - result_dtype = _dtype(np_fun(onp.ones((), dtype=dtype or _dtype(a)))) - if _dtype(a) != result_dtype: - a = lax.convert_element_type(a, result_dtype) - result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims) + result_dtype = dtype or _dtype(np_fun(onp.ones((), dtype=_dtype(a)))) + if upcast_f16_for_computation and issubdtype(result_dtype, inexact): + computation_dtype = promote_types(result_dtype, float32) + else: + computation_dtype = result_dtype + a = lax.convert_element_type(a, computation_dtype) + result = lax.reduce(a, _reduction_init_val(a, init_val), + op if computation_dtype != bool_ else bool_op, dims) if keepdims: shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims))) result = lax.reshape(result, shape_with_singletons) - if dtype and onp.dtype(dtype) != onp.dtype(result_dtype): - result = lax.convert_element_type(result, dtype) - return result + return lax.convert_element_type(result, dtype or result_dtype) return reduction @@ -1188,10 +1200,12 @@ def _reduction_init_val(a, init_val): sign, info = onp.sign(init_val), iinfo(a_dtype) return onp.array(info.min if sign < 0 else info.max, dtype=a_dtype) -_cast_to_bool = partial(lax.convert_element_type, new_dtype=onp.bool_) +_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_) -sum = _make_reduction(onp.sum, lax.add, 0) -product = prod = _make_reduction(onp.prod, lax.mul, 1) +sum = _make_reduction(onp.sum, lax.add, 0, upcast_f16_for_computation=True, + bool_op=lax.bitwise_or) +product = prod = _make_reduction(onp.prod, lax.mul, 1, bool_op=lax.bitwise_and, + upcast_f16_for_computation=True) amax = max = _make_reduction(onp.max, lax.max, -onp.inf) amin = min = _make_reduction(onp.min, lax.min, onp.inf) all = alltrue = _make_reduction(onp.all, lax.bitwise_and, True, _cast_to_bool) @@ -1210,7 +1224,7 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): if dtype is None: if (issubdtype(_dtype(a), onp.bool_) or issubdtype(_dtype(a), onp.integer)): - dtype = dtypes.canonicalize_dtype(onp.float64) + dtype = float_ else: dtype = _dtype(a) @@ 
-1267,19 +1281,28 @@ def average(a, axis=None, weights=None, returned=False): return avg, weights_sum return avg +_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype @_wraps(onp.var) def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): if out is not None: raise ValueError("var does not support the `out` argument.") - if dtype is None: - if (issubdtype(_dtype(a), onp.bool_) or - issubdtype(_dtype(a), onp.integer)): - dtype = dtypes.canonicalize_dtype(onp.float64) - centered = subtract(a, mean(a, axis, dtype=dtype, keepdims=True)) - if iscomplexobj(centered): - centered = lax.abs(centered) + a_dtype = _dtype(a) + if dtype: + a_dtype = promote_types(a_dtype, dtype) + else: + if not issubdtype(a_dtype, inexact): + dtype = a_dtype = float_ + else: + dtype = _complex_basetype(a_dtype) + a_dtype = promote_types(a_dtype, float32) + a_mean = mean(a, axis, dtype=a_dtype, keepdims=True) + centered = a - a_mean + if issubdtype(centered.dtype, complexfloating): + centered = lax.real(lax.mul(centered, lax.conj(centered))) + else: + centered = lax.square(centered) if axis is None: normalizer = size(a) @@ -1287,9 +1310,10 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): normalizer = onp.prod(onp.take(shape(a), axis)) normalizer = normalizer - ddof - result = sum(lax.mul(centered, centered), axis, - dtype=dtype, keepdims=keepdims) - return lax.div(result, lax.convert_element_type(normalizer, _dtype(result))) + result = sum(centered, axis, keepdims=keepdims) + out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) + return lax.convert_element_type(out, dtype) + @_wraps(onp.std) @@ -1767,24 +1791,28 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, @_wraps(onp.logspace) def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): """Implementation of logspace differentiable in start and stop args.""" + dtype = dtype or result_type(start, stop, float_) + computation_dtype = promote_types(dtype, float_) + start = asarray(start, dtype=computation_dtype) + stop = asarray(stop, dtype=computation_dtype) lin = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis) - if dtype is None: - return power(base, lin) - else: - return lax.convert_element_type(power(base, lin), dtype) + return lax.convert_element_type(power(base, lin), dtype) @_wraps(onp.geomspace) def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): """Implementation of geomspace differentiable in start and stop args.""" dtype = dtype or result_type(start, stop, float(num), zeros((), dtype)) + computation_dtype = promote_types(dtype, float32) + start = asarray(start, dtype=computation_dtype) + stop = asarray(stop, dtype=computation_dtype) # follow the numpy geomspace convention for negative and complex endpoints signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2 res = signflip * logspace(log10(signflip * start), log10(signflip * stop), num, endpoint=endpoint, base=10.0, - dtype=dtype, axis=0) + dtype=computation_dtype, axis=0) if axis != 0: res = moveaxis(res, 0, axis) return lax.convert_element_type(res, dtype) @@ -3104,6 +3132,7 @@ def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None): c = real_part return c + @_wraps(getattr(onp, "quantile", None)) def quantile(a, q, axis=None, out=None, overwrite_input=False, interpolation="linear", keepdims=False): @@ -3113,9 +3142,11 @@ def quantile(a, q, axis=None, out=None, overwrite_input=False, raise ValueError(msg) if 
interpolation != "linear": raise NotImplementedError("Only interpolation='linear' is implemented") + return _quantile(a, q, axis, keepdims) +@partial(jit, static_argnums=(2, 3)) +def _quantile(a, q, axis, keepdims): a = asarray(a) - if axis is None: a = ravel(a) axis = 0 @@ -3128,11 +3159,15 @@ def quantile(a, q, axis=None, out=None, overwrite_input=False, if q_ndim > 1: raise ValueError("q must be have rank <= 1, got shape {}".format(shape(q))) - a, q = _promote_dtypes(a, q) - if not issubdtype(a.dtype, floating): + q = asarray(q) + + if not issubdtype(a.dtype, floating) or not issubdtype(q.dtype, floating): msg = "q and a arguments to quantile must be of float type, got {} and {}" raise TypeError(msg.format(a.dtype, q.dtype)) + # Promote q to at least float32 for precise interpolation. + q = lax.convert_element_type(q, promote_types(q.dtype, float32)) + a_shape = shape(a) a = lax.sort(a, dimension=axis) @@ -3168,8 +3203,9 @@ def quantile(a, q, axis=None, out=None, overwrite_input=False, broadcast_dimensions=(0,)) high_weight = lax.broadcast_in_dim(high_weight, high_value.shape, broadcast_dimensions=(0,)) - return lax.add(lax.mul(low_value, low_weight), - lax.mul(high_value, high_weight)) + return lax.convert_element_type( + lax.add(lax.mul(low_value.astype(q.dtype), low_weight), + lax.mul(high_value.astype(q.dtype), high_weight)), a.dtype) @_wraps(onp.percentile) diff --git a/jax/test_util.py b/jax/test_util.py index 6834c804fa47..84a8746f4149 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -67,7 +67,7 @@ def is_sequence(x): else: return True -default_tolerance = { +_default_tolerance = { onp.dtype(onp.bool_): 0, onp.dtype(onp.int8): 0, onp.dtype(onp.int16): 0, @@ -77,6 +77,7 @@ def is_sequence(x): onp.dtype(onp.uint16): 0, onp.dtype(onp.uint32): 0, onp.dtype(onp.uint64): 0, + onp.dtype(dtypes.bfloat16): 1e-2, onp.dtype(onp.float16): 1e-3, onp.dtype(onp.float32): 1e-6, onp.dtype(onp.float64): 1e-15, @@ -84,11 +85,16 @@ def is_sequence(x): onp.dtype(onp.complex128): 1e-15, } -tpu_default_tolerance = default_tolerance.copy() -tpu_default_tolerance[onp.dtype(onp.float32)] = 1e-3 -tpu_default_tolerance[onp.dtype(onp.complex64)] = 1e-3 +def default_tolerance(): + if device_under_test() != "tpu": + return _default_tolerance + tol = _default_tolerance.copy() + tol[onp.dtype(onp.float32)] = 1e-3 + tol[onp.dtype(onp.complex64)] = 1e-3 + return tol default_gradient_tolerance = { + onp.dtype(dtypes.bfloat16): 1e-1, onp.dtype(onp.float16): 1e-2, onp.dtype(onp.float32): 2e-3, onp.dtype(onp.float64): 1e-5, @@ -96,28 +102,46 @@ def is_sequence(x): onp.dtype(onp.complex128): 1e-5, } -def _assert_numpy_eq(x, y): - onp.testing.assert_allclose(x, y) +def _assert_numpy_allclose(a, b, atol=None, rtol=None): + a = a.astype(onp.float32) if a.dtype == dtypes.bfloat16 else a + b = b.astype(onp.float32) if b.dtype == dtypes.bfloat16 else b + kw = {} + if atol: kw["atol"] = atol + if rtol: kw["rtol"] = rtol + onp.testing.assert_allclose(a, b, **kw) def tolerance(dtype, tol=None): tol = tol or {} if not isinstance(tol, dict): return tol tol = {onp.dtype(key): value for key, value in tol.items()} - default = (tpu_default_tolerance if device_under_test() == "tpu" - else default_tolerance) dtype = dtypes.canonicalize_dtype(onp.dtype(dtype)) - return tol.get(dtype, default[dtype]) + return tol.get(dtype, default_tolerance()[dtype]) + +def _normalize_tolerance(tol): + tol = tol or 0 + if isinstance(tol, dict): + return {onp.dtype(k): v for k, v in tol.items()} + else: + return {k: tol for k in 
_default_tolerance.keys()} + +def join_tolerance(tol1, tol2): + tol1 = _normalize_tolerance(tol1) + tol2 = _normalize_tolerance(tol2) + out = tol1 + for k, v in tol2.items(): + out[k] = max(v, tol1.get(k, 0)) + return out def _assert_numpy_close(a, b, atol=None, rtol=None): assert a.shape == b.shape atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol)) rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol)) - onp.testing.assert_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) + _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) def check_eq(xs, ys): - tree_all(tree_multimap(_assert_numpy_eq, xs, ys)) + tree_all(tree_multimap(_assert_numpy_allclose, xs, ys)) def check_close(xs, ys, atol=None, rtol=None): @@ -126,7 +150,8 @@ def check_close(xs, ys, atol=None, rtol=None): def inner_prod(xs, ys): - contract = lambda x, y: onp.real(onp.vdot(x, y)) + def contract(x, y): + return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1))) return tree_reduce(onp.add, tree_multimap(contract, xs, ys)) @@ -226,12 +251,12 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": return {onp.bool_, onp.int32, onp.int64, onp.uint32, onp.uint64, - onp.float32, onp.complex64} + dtypes.bfloat16, onp.float32, onp.complex64} else: return {onp.bool_, onp.int8, onp.int16, onp.int32, onp.int64, onp.uint8, onp.uint16, onp.uint32, onp.uint64, - onp.float16, onp.float32, onp.float64, onp.complex64, - onp.complex128} + dtypes.bfloat16, onp.float16, onp.float32, onp.float64, + onp.complex64, onp.complex128} def skip_on_devices(*disabled_devices): """A decorator for test methods to skip the test on certain devices.""" @@ -352,7 +377,7 @@ def rand_default(): def rand_nonzero(): - post = lambda x: onp.where(x == 0, 1, x) + post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x) randn = npr.RandomState(0).randn return partial(_rand_dtype, randn, scale=3, post=post) @@ -423,8 +448,8 @@ def rand(shape, dtype): neginf_flips = rng.rand(*dims) < 0.1 vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.inf, vals) - vals = onp.where(neginf_flips, -onp.inf, vals) + vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) + vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) @@ -449,7 +474,7 @@ def rand(shape, dtype): nan_flips = rng.rand(*dims) < 0.1 vals = base_rand(shape, dtype) - vals = onp.where(nan_flips, onp.nan, vals) + vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) @@ -480,9 +505,9 @@ def rand(shape, dtype): nan_flips = rng.rand(*dims) < 0.1 vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.inf, vals) - vals = onp.where(neginf_flips, -onp.inf, vals) - vals = onp.where(nan_flips, onp.nan, vals) + vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) + vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) + vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) @@ -500,7 +525,7 @@ def rand(shape, dtype): zeros = rng.rand(*dims) < 0.5 vals = base_rand(shape, dtype) - vals = onp.where(zeros, 0, vals) + vals = onp.where(zeros, onp.array(0, dtype=dtype), vals) return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) @@ -564,7 +589,7 @@ def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None): atol 
= max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) - onp.testing.assert_allclose(x, y, atol=atol, rtol=rtol) + _assert_numpy_allclose(x, y, atol=atol, rtol=rtol) if check_dtypes: self.assertDtypesMatch(x, y) @@ -639,5 +664,5 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() numpy_ans = numpy_reference_op(*args) lax_ans = lax_op(*args) - self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, + self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, atol=tol, rtol=tol) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 8fd84981ac58..d4e29fb7f8ae 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -57,7 +57,7 @@ all_shapes = scalar_shapes + array_shapes float_dtypes = list(jtu.supported_dtypes().intersection( - {onp.float16, onp.float32, onp.float64})) + {lnp.bfloat16, onp.float16, onp.float32, onp.float64})) complex_dtypes = [onp.complex64, onp.complex128] int_dtypes = [onp.int32, onp.int64] unsigned_dtypes = [onp.uint32, onp.uint64] @@ -102,8 +102,9 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("float_power", 2, inexact_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.float32: 1e-3, onp.float64: 1e-12, - onp.complex64: 2e-4, onp.complex128: 1e-12}), + tolerance={lnp.bfloat16: 1e-2, onp.float32: 1e-3, + onp.float64: 1e-12, onp.complex64: 2e-4, + onp.complex128: 1e-12}), op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []), op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []), @@ -158,7 +159,7 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes, jtu.rand_nonzero, []), op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.float16: 1e-2}), + tolerance={lnp.bfloat16: 2e-2, onp.float16: 1e-2}), # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 # precision. 
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [], @@ -204,7 +205,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], tolerance={onp.float16: 1e-2}), op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("sinc", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + op_record("sinc", 1, [t for t in number_dtypes if t != lnp.bfloat16], + all_shapes, jtu.rand_default, ["rev"], tolerance={onp.complex64: 1e-5}), op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]), @@ -228,10 +230,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, JAX_REDUCER_RECORDS = [ op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []), - op_record("prod", 1, number_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("var", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []), - op_record("std", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan, []), op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_some_nan, []), op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []), @@ -242,6 +242,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), ] JAX_ARGMINMAX_RECORDS = [ @@ -296,7 +298,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, numpy_version = tuple(map(int, onp.version.version.split('.'))) if numpy_version >= (1, 15): JAX_COMPOUND_OP_RECORDS += [ - op_record("isclose", 2, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("isclose", 2, [t for t in all_dtypes if t != lnp.bfloat16], + all_shapes, jtu.rand_small_positive, []), op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []), op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []), ] @@ -356,7 +359,7 @@ def _GetArgsMaker(self, rng, shapes, dtypes): dtypes), "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "check_dtypes": rec.check_dtypes, "tol": rec.tolerance} + "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance} for shapes in filter( _shapes_are_broadcast_compatible, CombosWithReplacement(rec.shapes, rec.nargs)) @@ -365,7 +368,7 @@ def _GetArgsMaker(self, rng, shapes, dtypes): for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS))) def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes, - tol): + tolerance): rng = rng_factory() args_maker = self._GetArgsMaker(rng, shapes, dtypes) python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes @@ -373,6 +376,9 @@ def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes, jtu.NUMPY_SCALAR_SHAPE in shapes or () in 
shapes) empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) + tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) + tol = functools.reduce(jtu.join_tolerance, + [tolerance, tol, jtu.default_tolerance()]) self._CheckAgainstNumpy( onp_op, lnp_op, args_maker, check_dtypes=check_dtypes and not scalar_arg and not empty_shape, @@ -456,29 +462,34 @@ def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes): check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, - "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "axis": axis, "keepdims": keepdims} - for rec in JAX_REDUCER_RECORDS - for shape in rec.shapes for dtype in rec.dtypes - for out_dtype in [None] + rec.dtypes - for axis in set(range(-len(shape), len(shape))) | set([None]) - for keepdims in [False, True])) + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, + "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, + "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), + "axis": axis, "keepdims": keepdims} + for shape in rec.shapes for dtype in rec.dtypes + for out_dtype in [None] + rec.dtypes + for axis in set(range(-len(shape), len(shape))) | set([None]) + for keepdims in [False, True]) + for rec in JAX_REDUCER_RECORDS)) def testReducer(self, onp_op, lnp_op, rng_factory, shape, dtype, out_dtype, axis, keepdims): rng = rng_factory() - onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims) + def onp_fun(x): + x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32) + t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32 + return onp_op(x_cast, axis, dtype=t, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3, onp.float64: 1e-5, onp.complex128: 1e-5} tol = jtu.tolerance(dtype, tol_spec) tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, + check_dtypes=lnp.bfloat16 not in (dtype, out_dtype), tol=tol) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) @@ -568,8 +579,12 @@ def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factor args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] axisa, axisb, axisc, axis = axes lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis) - onp_fun = lambda a, b: onp.cross(a, b, axisa, axisb, axisc, axis) - tol_spec = {onp.float16: 1e-2} + def onp_fun(a, b): + a = a.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else a + b = b.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else b + out = onp.cross(a, b, axisa, 
axisb, axisc, axis) + return out.astype(lnp.promote_types(lhs_dtype, rhs_dtype)) + tol_spec = {dtypes.bfloat16: 3e-1, onp.float16: 1e-2} tol = max(jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec)) self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, @@ -605,7 +620,11 @@ def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): onp.complex128: 1e-14} if jtu.device_under_test() == "tpu": tol[onp.float32] = tol[onp.complex64] = 2e-1 - self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True, + def onp_dot(x, y): + x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x + y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y + return onp.dot(x, y).astype(lnp.promote_types(lhs_dtype, rhs_dtype)) + self._CheckAgainstNumpy(onp_dot, lnp.dot, args_maker, check_dtypes=True, tol=tol) self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True, atol=tol, rtol=tol) @@ -633,11 +652,15 @@ def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2))) def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): rng = rng_factory() + def onp_fun(x, y): + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.matmul(x, y).astype(dtype) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12} + tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12, + onp.complex128: 1e-12} if jtu.device_under_test() == "tpu": tol[onp.float32] = tol[onp.complex64] = 4e-2 - self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker, + self._CheckAgainstNumpy(onp_fun, lnp.matmul, args_maker, check_dtypes=True, tol=tol) self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True, atol=tol, rtol=tol) @@ -663,7 +686,11 @@ def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_fa rng = rng_factory() args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] lnp_fun = lambda a, b: lnp.tensordot(a, b, axes) - onp_fun = lambda a, b: onp.tensordot(a, b, axes) + def onp_fun(a, b): + a = a if lhs_dtype != lnp.bfloat16 else a.astype(onp.float32) + b = b if rhs_dtype != lnp.bfloat16 else b.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.tensordot(a, b, axes).astype(dtype) tol = {onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 1e-12, onp.complex64: 1e-3, onp.complex128: 1e-12} if jtu.device_under_test() == "tpu": @@ -688,9 +715,13 @@ def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_fa def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): rng = rng_factory() args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs) + def onp_fun(lhs, rhs): + lhs = lhs if lhs_dtype != lnp.bfloat16 else lhs.astype(onp.float32) + rhs = rhs if rhs_dtype != lnp.bfloat16 else rhs.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.inner(lhs, rhs).astype(dtype) lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs) - tol_spec = {onp.float16: 1e-2, onp.float64: 1e-13} + tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 1e-13} if jtu.device_under_test() == "tpu": tol_spec[onp.float32] = tol_spec[onp.complex64] = 2e-1 tol = max(jtu.tolerance(lhs_dtype, tol_spec), @@ -734,7 +765,7 @@ def testRoundStaticDecimals(self, shape, dtype, decimals, 
rng_factory): onp_fun = lambda x: onp.round(x, decimals=decimals) lnp_fun = lambda x: lnp.round(x, decimals=decimals) args_maker = lambda: [rng(shape, dtype)] - tol = {onp.float16: 1e-2} + tol = {lnp.bfloat16: 5e-2, onp.float16: 1e-2} check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol) @@ -805,23 +836,27 @@ def testTile(self, shape, dtype, reps, rng_factory): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), - ",".join(onp.dtype(dtype).name for dtype in dtypes)), - "axis": axis, "base_shape": base_shape, "dtypes": dtypes, + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)), + "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes, "rng_factory": jtu.rand_default} for num_arrs in [3] - for dtypes in CombosWithReplacement(default_dtypes, num_arrs) + for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs) for base_shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape)+1, len(base_shape)))) - def testConcatenate(self, axis, base_shape, dtypes, rng_factory): + def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory): rng = rng_factory() wrapped_axis = axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)] - onp_fun = lambda *args: onp.concatenate(args, axis=axis) + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + def onp_fun(*args): + args = [x if x.dtype != lnp.bfloat16 else x.astype(onp.float32) + for x in args] + dtype = functools.reduce(lnp.promote_types, arg_dtypes) + return onp.concatenate(args, axis=axis).astype(dtype) lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @@ -829,22 +864,27 @@ def args_maker(): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), - ",".join(onp.dtype(dtype).name for dtype in dtypes)), - "axis": axis, "base_shape": base_shape, "dtypes": dtypes, + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)), + "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes, "rng_factory": jtu.rand_default} - for dtypes in CombosWithReplacement(default_dtypes, 2) + for arg_dtypes in CombosWithReplacement(default_dtypes, 2) for base_shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape)+1, len(base_shape)))) - def testAppend(self, axis, base_shape, dtypes, rng_factory): + def testAppend(self, axis, base_shape, arg_dtypes, rng_factory): rng = rng_factory() wrapped_axis = axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)] - onp_fun = lambda arr, values: onp.append(arr, values, axis=axis) + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + def onp_fun(arr, values): + arr = arr.astype(onp.float32) if arr.dtype == lnp.bfloat16 else arr + values = (values.astype(onp.float32) if values.dtype == lnp.bfloat16 + else values) + out = onp.append(arr, values, axis=axis) + return 
out.astype(lnp.promote_types(*arg_dtypes)) lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis) def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @@ -1014,7 +1054,11 @@ def testIdentity(self, n, dtype): for offset in list(range(-4, 4)))) def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng_factory): rng = rng_factory() - onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype) + def onp_fun(arg): + if out_dtype == lnp.bfloat16: + return onp.trace(arg, offset, axis1, axis2, onp.float32).astype(lnp.bfloat16) + else: + return onp.trace(arg, offset, axis1, axis2, out_dtype) lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) @@ -1268,8 +1312,8 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory): lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned) args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] - tol = {onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 1e-10, - onp.complex64: 1e-3, onp.complex128: 1e-10} + tol = {lnp.bfloat16: 1e-1, onp.float16: 1e-1, onp.float32: 1e-3, + onp.float64: 1e-10, onp.complex64: 1e-3, onp.complex128: 1e-10} check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE try: self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, @@ -1682,7 +1726,9 @@ def args_maker(): for increasing in [False, True])) def testVander(self, shape, dtype, n, increasing, rng_factory): rng = rng_factory() - onp_fun = lambda arg: onp.vander(arg, N=n, increasing=increasing) + def onp_fun(arg): + arg = arg.astype(onp.float32) if dtype == lnp.bfloat16 else arg + return onp.vander(arg, N=n, increasing=increasing) lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing) args_maker = lambda: [rng([shape], dtype)] # np.vander seems to return float64 for all floating types. 
We could obey @@ -1701,9 +1747,18 @@ def testVander(self, shape, dtype, n, increasing, rng_factory): def testNanToNum(self, rng_factory, shape, dtype): rng = rng_factory() dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type + def onp_fun(x): + if dtype == lnp.bfloat16: + x = onp.where(onp.isnan(x), dtype(0), x) + x = onp.where(onp.isposinf(x), lnp.finfo(dtype).max, x) + x = onp.where(onp.isneginf(x), lnp.finfo(dtype).min, x) + return x + else: + return onp.nan_to_num(x).astype(dtype) + args_maker = lambda: [rng(shape, dtype)] check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(onp.nan_to_num, lnp.nan_to_num, args_maker, + self._CheckAgainstNumpy(onp_fun, lnp.nan_to_num, args_maker, check_dtypes=check_dtypes) self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=check_dtypes) @@ -1758,12 +1813,15 @@ def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, args_maker = lambda: [a_rng(a_shape, a_dtype)] else: args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] - onp_fun = partial(getattr(onp, op), axis=axis, keepdims=keepdims) + def onp_fun(*args): + args = [x if lnp.result_type(x) != lnp.bfloat16 else + onp.asarray(x, onp.float32) for x in args] + return getattr(onp, op)(*args, axis=axis, keepdims=keepdims) lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims) # TODO(phawkins): we currently set dtype=False because we aren't as # aggressive about promoting to float64. It's not clear we want to mimic # Numpy here. - tol_spec = {onp.float16: 2e-2, onp.float32: 1e-5, onp.float64: 5e-6} + tol_spec = {onp.float32: 1e-5, onp.float64: 5e-6} tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False, @@ -1789,7 +1847,12 @@ def args_maker(): default = rng(shapes[-1], dtypes[-1]) return condlist, choicelist, default # TODO(phawkins): float32/float64 type mismatches - self._CheckAgainstNumpy(onp.select, lnp.select, args_maker, + def onp_fun(condlist, choicelist, default): + choicelist = [x if lnp.result_type(x) != lnp.bfloat16 + else x.astype(onp.float32) for x in choicelist] + dtype = lnp.result_type(default, *choicelist) + return onp.select(condlist, choicelist, default).astype(dtype) + self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker, check_dtypes=False) self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True) @@ -1978,7 +2041,7 @@ def testIssue956(self): "ddof": ddof, "keepdims": keepdims, "rng_factory": rng_factory} for shape in [(5,), (10, 5)] for dtype in all_dtypes - for out_dtype in number_dtypes + for out_dtype in inexact_dtypes for axis in [None, 0, -1] for ddof in [0, 1, 2] for keepdims in [False, True] @@ -1986,9 +2049,13 @@ def testIssue956(self): def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory): rng = rng_factory() args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - onp_fun = partial(onp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + def onp_fun(x): + out = onp.var(x.astype(lnp.promote_types(onp.float32, dtype)), + axis=axis, ddof=ddof, keepdims=keepdims) + return out.astype(out_dtype) lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {onp.float16: 1e-1}) + tol = jtu.tolerance(out_dtype, {onp.float16: 1e-1, onp.float32: 1e-3, + onp.float64: 1e-3, onp.complex128: 1e-6}) self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol) 
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol, @@ -2129,7 +2196,8 @@ def testLinspace(self, start_shape, stop_shape, num, endpoint, jtu.cases_from_list( {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" "_base={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, base, dtype), + start_shape, stop_shape, num, endpoint, base, + dtype.__name__ if dtype else "None"), "start_shape": start_shape, "stop_shape": stop_shape, "num": num, "endpoint": endpoint, "base": base, @@ -2139,7 +2207,7 @@ def testLinspace(self, start_shape, stop_shape, num, endpoint, for num in [0, 1, 2, 5, 20] for endpoint in [True, False] for base in [10.0, 2, onp.e] - for dtype in number_dtypes + [None,] + for dtype in inexact_dtypes + [None,] for rng_factory in [jtu.rand_default])) def testLogspace(self, start_shape, stop_shape, num, endpoint, base, dtype, rng_factory): @@ -2212,10 +2280,16 @@ def args_maker(): start, stop = args_maker() ndim = len(onp.shape(start + stop)) for axis in range(-ndim, ndim): - lnp_op = lambda start, stop: lnp.geomspace( - start, stop, num, endpoint=endpoint, dtype=dtype, axis=axis) - onp_op = lambda start, stop: onp.geomspace( - start, stop, num, endpoint=endpoint, dtype=dtype, axis=axis) + def lnp_op(start, stop): + return lnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, + axis=axis) + def onp_op(start, stop): + start = start.astype(onp.float32) if dtype == lnp.bfloat16 else start + stop = stop.astype(onp.float32) if dtype == lnp.bfloat16 else stop + return onp.geomspace( + start, stop, num, endpoint=endpoint, + dtype=dtype if dtype != lnp.bfloat16 else onp.float32, + axis=axis).astype(dtype) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol) # Check dtype equivalence within expected 32bit downcasting. diff --git a/tests/lax_test.py b/tests/lax_test.py index d3f1c34b0366..7cfff54b0dd4 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -55,7 +55,7 @@ def num_float_bits(dtype): # arguments of appropriate shapes and dtypes using the following table. 
float_dtypes = list(jtu.supported_dtypes().intersection( - {onp.float16, onp.float32, onp.float64})) + {dtypes.bfloat16, onp.float16, onp.float32, onp.float64})) complex_dtypes = [onp.complex64, onp.complex128] inexact_dtypes = float_dtypes + complex_dtypes int_dtypes = [onp.int32, onp.int64] @@ -65,27 +65,6 @@ def num_float_bits(dtype): compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] -default_tolerance = { - onp.bool_: 0, - onp.int16: 0, - onp.int32: 0, - onp.int64: 0, - onp.float16: 1e-3, - onp.float32: 1e-6, - onp.float64: 1e-15, - onp.complex64: 1e-6, - onp.complex128: 1e-15, -} - -def tolerance(dtype, tol=None): - if not FLAGS.jax_enable_x64: - if dtype == onp.float64: - dtype = onp.float32 - elif dtype == onp.complex128: - dtype = onp.complex64 - tol = tol or {} - return tol.get(dtype, default_tolerance[dtype]) - OpRecord = collections.namedtuple( "OpRecord", ["op", "nargs", "dtypes", "rng_factory", "tol"]) @@ -671,8 +650,10 @@ def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory): def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory() args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - tol = {onp.float16: 1e-2, - onp.float64: max(default_tolerance[onp.float64], 1e-14)} + tol = { + onp.float16: 1e-2, + onp.float64: max(jtu.default_tolerance()[onp.dtype(onp.float64)], 1e-14) + } lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckAgainstNumpy(lax_op, lax_reference.dot, args_maker, tol=tol) @@ -1740,26 +1721,6 @@ def check_grads_bilinear(f, args, order, check_grads(lambda rhs: f(lhs, rhs), (rhs,), order, modes=modes, atol=atol, rtol=rtol, eps=1.) - -default_gradient_tolerance = { - onp.float16: 1e-2, - onp.float32: 1e-6, - onp.float64: 1e-10, - onp.complex64: 1e-6, - onp.complex128: 1e-10, -} - -def gradient_tolerance(dtype, tol=None): - if dtype == onp.complex64: - dtype = onp.float32 - elif dtype == onp.complex128: - dtype = onp.float64 - if not FLAGS.jax_enable_x64 and dtype == onp.float64: - dtype = onp.float32 - tol = tol or {} - return tol.get(dtype, default_gradient_tolerance[dtype]) - - class LaxAutodiffTest(jtu.JaxTestCase): @parameterized.named_parameters(itertools.chain.from_iterable( @@ -1798,7 +1759,8 @@ def testOpGradSpecialValue(self, op, special_value, tol): for rng_factory in [jtu.rand_default])) def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): rng = rng_factory() - tol = max(gradient_tolerance(to_dtype), gradient_tolerance(from_dtype)) + tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), + jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) @@ -1815,15 +1777,16 @@ def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): [(), (2, 3), ()], [(2, 3), (2, 3), (2, 3)], ] - for dtype in float_dtypes + # TODO(phawkins): this test fails for bfloat16. 
+ for dtype in [t for t in float_dtypes if t != dtypes.bfloat16] for rng_factory in [jtu.rand_default])) def testClampGrad(self, min_shape, operand_shape, max_shape, dtype, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype, {onp.float16: 1e-1, onp.float32: 1e-2}) + tol = {dtypes.bfloat16: 1e-1, onp.float16: 1e-1, onp.float32: 1e-2} shapes = [min_shape, operand_shape, max_shape] min, operand, max = (rng(shape, dtype) for shape in shapes) min, max = onp.minimum(min, max), onp.maximum(min, max) # broadcast - eps = 1e-1 if dtype == onp.float16 else 1e-2 + eps = 1e-1 if dtypes.finfo(dtype).bits == 16 else 1e-2 check_grads(lax.clamp, (min, operand, max), 2, ["fwd", "rev"], tol, tol, eps=eps) @@ -1840,12 +1803,11 @@ def testClampGrad(self, min_shape, operand_shape, max_shape, dtype, rng_factory) for rng_factory in [jtu.rand_default])) def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] operands = tuple(rng(shape, dtype) for shape in shapes) concatenate = lambda *args: lax.concatenate(args, dim) - check_grads(concatenate, operands, 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1944,7 +1906,7 @@ def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype, {onp.float16: 5e-1, onp.float32: 1e-4}) + tol = {dtypes.bfloat16: 3e-1, onp.float16: 5e-1, onp.float32: 1e-4} # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms @@ -1974,7 +1936,7 @@ def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, for dtype in float_dtypes)) def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype, {onp.float16: 1e-1, onp.float32: 1e-4}) + tol = {onp.float16: 1e-1, onp.float32: 1e-4} lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot = partial(lax.dot, precision=lax.Precision.HIGHEST) @@ -2004,13 +1966,11 @@ def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) - check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"], - atol=tol, rtol=tol) + check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) # check that precision config is preserved result, pullback = api.vjp(dot_general, lhs, rhs) gresult = lax.zeros_like_array(result) @@ -2028,10 +1988,9 @@ def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, for rng_factory in [jtu.rand_default])) def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) args = (rng(shape, dtype),) broadcast = lambda x: lax.broadcast(x, broadcast_sizes) - check_grads(broadcast, args, 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) 
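The gradient tests in this file now resolve per-dtype tolerances through jax.test_util instead of the module-local gradient_tolerance helper this patch deletes. A small sketch of the lookup those tests rely on (illustrative only, assuming the test_util.py changes earlier in this diff):

# jtu.tolerance looks a dtype up in the supplied dict and falls back to the
# shared defaults; default_gradient_tolerance now carries a bfloat16 entry.
import numpy as onp
from jax import dtypes
import jax.test_util as jtu

tol_bf16 = jtu.tolerance(dtypes.bfloat16, jtu.default_gradient_tolerance)  # 1e-1
tol_f32 = jtu.tolerance(onp.float32, jtu.default_gradient_tolerance)       # 2e-3
# join_tolerance merges two tolerance specs, keeping the looser (larger) value
# for any dtype that appears in both.
merged = jtu.join_tolerance({onp.float32: 1e-2}, jtu.default_gradient_tolerance)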
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( @@ -2049,11 +2008,9 @@ def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): for rng_factory in [jtu.rand_default])) def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(inshape, dtype) broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) - check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], tol, tol, - eps=1.) + check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_perm={}".format( @@ -2077,10 +2034,9 @@ def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_facto for rng_factory in [jtu.rand_default])) def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(arg_shape, dtype) reshape = lambda x: lax.reshape(x, out_shape, permutation) - check_grads(reshape, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}" @@ -2091,17 +2047,14 @@ def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory) for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) def testPadGrad(self, shape, dtype, pads, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) - operand = rng(shape, dtype) pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) - check_grads(pad, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) operand = rng(shape, dtype) padding_value = onp.array(0., dtype) pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) - check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], tol, tol, - eps=1.) + check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) def testReverseGrad(self): rev = lambda operand: lax.rev(operand, dimensions) @@ -2125,13 +2078,11 @@ def testReverseGrad(self): for rng_factory in [jtu.rand_default])) def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) pred = rng(pred_shape, onp.bool_) on_true = rng(arg_shape, dtype) on_false = rng(arg_shape, dtype) select = lambda on_true, on_false: lax.select(pred, on_true, on_false) - check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], tol, tol, - eps=1.) + check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -2155,10 +2106,9 @@ def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): for rng_factory in [jtu.rand_default])) def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(shape, dtype) slice = lambda x: lax.slice(x, starts, limits, strides) - check_grads(slice, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) 
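Several of the remaining gradient tests choose their finite-difference eps from dtypes.finfo(dtype).bits, which works for bfloat16 only because of the finfo shim added to jax/dtypes.py at the top of this patch. A small illustrative check, with the expected values taken directly from the _bfloat16_finfo class (assumes a build containing this patch):

from jax import dtypes

info = dtypes.finfo(dtypes.bfloat16)
print(info.bits)              # 16, so bfloat16 takes the same eps branch as float16
print(float(info.eps))        # 2**-7 == 0.0078125
print(info.nmant, info.nexp)  # 7 mantissa bits, 8 exponent bits
# Any other floating dtype falls through to numpy.finfo unchanged.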
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( @@ -2176,10 +2126,9 @@ def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(shape, dtype) dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) - check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( @@ -2197,19 +2146,18 @@ def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(shape, dtype) update = rng(update_shape, dtype) start_indices = onp.array(start_indices) dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices) - check_grads(dus, (operand, update), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.) dus = lambda x: lax.dynamic_update_slice(x, update, start_indices) - check_grads(dus, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.) dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices) - check_grads(dus, (update,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}".format( @@ -2225,10 +2173,9 @@ def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, for rng_factory in [jtu.rand_default])) def testTransposeGrad(self, shape, dtype, perm, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype) operand = rng(shape, dtype) transpose = lambda x: lax.transpose(x, perm) - check_grads(transpose, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1.) + check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}" @@ -2241,7 +2188,8 @@ def testTransposeGrad(self, shape, dtype, perm, rng_factory): (-onp.inf, lax.max, [t for t in inexact_dtypes if t != onp.float16]), (onp.inf, lax.min, [t for t in inexact_dtypes if t != onp.float16]), # The mul test overflows the range of a float16. 
- (1, lax.mul, [t for t in inexact_dtypes if t != onp.float16]), + (1, lax.mul, [t for t in inexact_dtypes + if t not in (onp.float16, dtypes.bfloat16)]), ] for dtype in dtypes for shape, dims in [ @@ -2257,12 +2205,13 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): rng = rng_factory() if jtu.device_under_test() == "tpu" and op is lax.mul: raise SkipTest("unimplemented case") - tol = gradient_tolerance( - dtype, {onp.float16: 1e-1, onp.float32: 4e-2, onp.float64: 1e-3}) + tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 4e-2, + onp.float64: 1e-3, onp.complex64: 1e-2} operand = rng(shape, dtype) init_val = onp.asarray(init_val, dtype=dtype) reduce = lambda operand: lax.reduce(operand, init_val, op, dims) eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else + 1e-1 if dtype == dtypes.bfloat16 else 1e-2 if dtypes.finfo(dtype).bits == 32 else None) check_grads(reduce, (operand,), 1, ["fwd", "rev"], tol, tol, eps) @@ -2281,7 +2230,7 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): for rng_factory in [jtu.rand_default])) def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype, {onp.float16: 1e-1, onp.float32: 1e-3}) + tol = {onp.float16: 1e-1, onp.float32: 1e-3} init_val = onp.asarray(init_val, dtype=dtype) # We need this conditional and the corresponding loop logic to be in the @@ -2333,10 +2282,9 @@ def fun(operand): for rng_factory in [jtu.rand_default])) def testSortGrad(self, shape, dtype, axis, rng_factory): rng = rng_factory() - tol = gradient_tolerance(dtype, {onp.float32: 1e-3}) operand = rng(shape, dtype) sort = lambda x: lax.sort(x, axis) - check_grads(sort, (operand,), 2, ["fwd", "rev"], tol, tol, eps=1e-2) + check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) # TODO(b/205052657): enable more tests when supported @parameterized.named_parameters(jtu.cases_from_list( @@ -2384,7 +2332,7 @@ def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory): rng = rng_factory() src = rng(shape, dtype) index_take = lambda src: lax.index_take(src, idxs, axes) - check_grads(index_take, (src,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1) + check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( @@ -2563,8 +2511,7 @@ def _CheckBatching(self, op, bdim_size, bdims, shapes, dtype, rng, def testOp(self, op_name, rng_factory, shapes, dtype, bdims): rng = rng_factory() op = getattr(lax, op_name) - tol = tolerance(dtype) - self._CheckBatching(op, 10, bdims, shapes, dtype, rng, tol, tol) + self._CheckBatching(op, 10, bdims, shapes, dtype, rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -2694,7 +2641,7 @@ def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory): rng = rng_factory() op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), dtype, rng, - rtol=tolerance(dtype, {onp.float16: 5e-2})) + rtol={onp.float16: 5e-2}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name":