Add lax.reduce_precision()
jakevdp committed Apr 5, 2021
1 parent 3c1ee06 commit 33fde77
Showing 5 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -109,6 +109,7 @@ Operators
  real
  reciprocal
  reduce
  reduce_precision
  reduce_window
  reshape
  rem
28 changes: 28 additions & 0 deletions jax/_src/lax/lax.py
@@ -5449,6 +5449,34 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
    _reduce_window_batch_rule, _reduce_window_min)


def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits):
  exponent_bits = operator.index(exponent_bits)
  mantissa_bits = operator.index(mantissa_bits)
  if exponent_bits < 1:
    raise ValueError(f"reduce_precision: exponent_bits must be positive; got {exponent_bits}")
  if mantissa_bits < 0:
    raise ValueError(f"reduce_precision: mantissa_bits must be non-negative; got {mantissa_bits}")
  return operand.shape


reduce_precision_p = standard_primitive(
    _reduce_precision_shape_rule,
    partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
    name='reduce_precision')


def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Wraps XLA's `ReducePrecision
  <https://www.tensorflow.org/xla/operation_semantics#reduceprecision>`_
  operator.
  """
  exponent_bits = core.concrete_or_error(
    operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision")
  mantissa_bits = core.concrete_or_error(
    operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision")
  return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits)


def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
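
Note (not part of the commit): reduce_precision keeps the operand's dtype and only rounds the stored values, so it can be used to emulate a narrower floating-point format. A minimal usage sketch, assuming the function is exposed as jax.lax.reduce_precision as defined above:

# Usage sketch (illustration only): emulate bfloat16 rounding while keeping
# the float32 dtype. bfloat16 carries 8 exponent bits and 7 explicit
# mantissa bits.
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 1.0 + 2**-10, 3.14159265], dtype=jnp.float32)
y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
print(y.dtype)  # float32 -- only the values are rounded; 1 + 2**-10 rounds back to 1.0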
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -794,6 +794,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):

"igamma_grad_a",
"random_gamma_grad",
"reduce_precision",

# Not high priority?
"after_all", "all_to_all", "create_token",
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
@@ -209,6 +209,8 @@
  reduce_min_p,
  reduce_or_p,
  reduce_p,
  reduce_precision,
  reduce_precision_p,
  reduce_prod_p,
  reduce_sum_p,
  reduce_window,
19 changes: 19 additions & 0 deletions tests/lax_test.py
@@ -1693,6 +1693,25 @@ def np_fun(x):
    self._CompileAndCheck(fun, args_maker)
    self._CheckAgainstNumpy(np_fun, fun, args_maker)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_out_dtype={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(shape, out_dtype)),
       "shape": shape, "dtype": dtype, "out_dtype": out_dtype}
      for shape in [(), (3,), (3, 4)]
      for dtype in float_dtypes
      for out_dtype in float_dtypes))
  def testReducePrecision(self, shape, dtype, out_dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    info = dtypes.finfo(out_dtype)
    fun = lambda x: lax.reduce_precision(x, info.nexp, info.nmant)
    np_fun = lambda x: np.asarray(x).astype(out_dtype).astype(dtype)
    self._CheckAgainstNumpy(np_fun, fun, args_maker)
    self._CompileAndCheck(fun, args_maker)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
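
Note (not part of the commit): the reference implementation in testReducePrecision is simply a round-trip cast, i.e. rounding with a dtype's exponent/mantissa widths should agree with casting to that dtype and back. A standalone sketch of the same check, assuming NumPy and the newly added lax.reduce_precision:

# Standalone version of the check above (illustration only): reducing to
# float16's widths (nexp=5, nmant=10) should match a float16 round-trip
# cast for in-range values.
import numpy as np
import jax.numpy as jnp
from jax import dtypes, lax

x = jnp.linspace(0.1, 10.0, 5, dtype=jnp.float32)
info = dtypes.finfo(np.float16)
lowered = lax.reduce_precision(x, info.nexp, info.nmant)
reference = np.asarray(x).astype(np.float16).astype(np.float32)
np.testing.assert_allclose(np.asarray(lowered), reference)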
