Add support for higher derivatives of reduce-window-min/max at reduced precision.

On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.

Warn if we are forced to be imprecise.
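For context, here is a minimal sketch (not part of the commit; names, shapes, and values are illustrative) of the kind of program this change targets: a second derivative of a max-pooling built from lax.reduce_window. Before this change, running it in float64 on CPU/GPU hit the 32-bit packing limit in select_and_gather_add and raised NotImplementedError.

import jax
import jax.numpy as jnp
from jax import lax

def pooled_loss(x):
  # 2x2 max pooling with stride 2 over an NHWC input, then a scalar loss.
  pooled = lax.reduce_window(x, -jnp.inf, lax.max,
                             (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
  return jnp.sum(pooled ** 2)

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
g = jax.grad(pooled_loss)(x)  # first derivative
# Hessian-vector product: a second derivative that exercises the
# select_and_gather_add lowering patched below.
hvp = jax.grad(lambda y: jnp.vdot(jax.grad(pooled_loss)(y), g))(x)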
hawkinsp committed Jun 29, 2019
1 parent acda3f3 commit db36909
Showing 2 changed files with 55 additions and 19 deletions.
66 changes: 51 additions & 15 deletions jax/lax/lax.py
@@ -3612,11 +3612,29 @@ def _select_and_gather_add_shape_rule(
64: onp.uint64,
}

def _select_and_gather_add_pair_reducer(dtype, select_prim):
bits = onp.finfo(dtype).bits
_float_bitwidths = {
xla_client.PrimitiveType.BF16: 16,
xla_client.PrimitiveType.F16: 16,
xla_client.PrimitiveType.F32: 32,
xla_client.PrimitiveType.F64: 64,
}

_select_and_gather_add_reduction_types = {
xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
xla_client.PrimitiveType.F16: xla_client.PrimitiveType.F16,
xla_client.PrimitiveType.F32: xla_client.PrimitiveType.F32,
xla_client.PrimitiveType.F64: xla_client.PrimitiveType.F32,
}

_select_and_gather_add_tpu_reduction_types = {
xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
xla_client.PrimitiveType.F32: xla_client.PrimitiveType.BF16,
}

def _select_and_gather_add_pair_reducer(etype, select_prim):
bits = _float_bitwidths[etype]
pair_uint_dtype = _UINT_DTYPES[bits * 2]
uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
etype = xla_bridge.dtype_to_etype_exact(dtype)

c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
x = c.ParameterWithShape(
@@ -3639,41 +3657,56 @@ def fst(t):

def _select_and_gather_add_translation(
c, tangents, operand, select_prim, window_dimensions, window_strides,
padding):
padding, reduction_types=None):
reduction_types = reduction_types or _select_and_gather_add_reduction_types
# XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
# we implement a pair-wise ReduceWindow by packing two k-bit values into
# 2k-bit unsigned integer using bit tricks. This will only work for <= 32-bit
# inputs (since we don't have 128-bit integer types).
dtype = c.GetShape(operand).numpy_dtype()
bits = onp.finfo(dtype).bits
if bits > 32:
raise NotImplementedError(
"select_and_gather_add is not implemented for type larger than 32 bits")
etype = xla_bridge.dtype_to_etype(dtype)
uint_etype = xla_bridge.dtype_to_etype(_UINT_DTYPES[bits])
shape = c.GetShape(operand)
etype = shape.xla_element_type()
reduction_etype = reduction_types.get(etype, None)
if reduction_etype is None:
msg = "Unsupported type for select_and_gather_add: {}"
raise ValueError(msg.format(etype))

if reduction_etype != etype:
warnings.warn("Using reduced precision for gradient of reduce-window "
"min/max operator. This is likely from a second or "
"higher derivative of a max-pooling operation and is to work"
"around a missing XLA feature.")

bits = _float_bitwidths[reduction_etype]
uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
pair_uint_dtype = _UINT_DTYPES[bits * 2]
pair_uint_etype = xla_bridge.dtype_to_etype_exact(pair_uint_dtype)

operand = c.ConvertElementType(operand, reduction_etype)
operand = c.BitcastConvertType(operand, uint_etype)
tangents = c.BitcastConvertType(tangents, uint_etype)
operand = c.ConvertElementType(operand, pair_uint_etype)
tangents = c.ConvertElementType(tangents, reduction_etype)
tangents = c.BitcastConvertType(tangents, uint_etype)
tangents = c.ConvertElementType(tangents, pair_uint_etype)
operand = c.ShiftLeft(
operand, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))

assert select_prim is ge_p or select_prim is le_p
init = -onp.inf if select_prim is ge_p else onp.inf
init = c.BitcastConvertType(c.Constant(dtype.type(init)), uint_etype)
init = c.Constant(shape.numpy_dtype().type(init))
init = c.ConvertElementType(init, reduction_etype)
init = c.BitcastConvertType(init, uint_etype)
init = c.ConvertElementType(init, pair_uint_etype)
init = c.ShiftLeft(
init, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))

xla_computation = _select_and_gather_add_pair_reducer(dtype, select_prim)
xla_computation = _select_and_gather_add_pair_reducer(reduction_etype,
select_prim)
out = c.ReduceWindow(c.Or(operand, tangents), init,
xla_computation, window_dimensions, window_strides,
padding)
out = c.ConvertElementType(out, uint_etype)
return c.BitcastConvertType(out, etype)
out = c.BitcastConvertType(out, reduction_etype)
return c.ConvertElementType(out, etype)

def _select_and_gather_add_jvp(
primals, tangents, select_prim, window_dimensions, window_strides,
@@ -3706,6 +3739,9 @@ def _select_and_gather_add_transpose(
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
_select_and_gather_add_transpose
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
_select_and_gather_add_translation,
reduction_types=_select_and_gather_add_tpu_reduction_types)


sort_shape = lambda operand, dimension: operand.shape
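The comment in _select_and_gather_add_translation above describes the workaround this diff generalizes: XLA's ReduceWindow cannot yet reduce over tuples, so the value and its tangent are packed into a single wider unsigned integer. A standalone NumPy sketch of the same trick for the float32 → uint64 case (an illustration, not code from the commit):

import numpy as np

def pack(value, tangent):
  # Bitcast each float32 to uint32, then place the value in the high 32 bits
  # of a uint64 and the tangent in the low 32 bits.
  hi = np.float32(value).view(np.uint32).astype(np.uint64) << np.uint64(32)
  lo = np.float32(tangent).view(np.uint32).astype(np.uint64)
  return hi | lo

def unpack(packed):
  value = np.uint32(packed >> np.uint64(32)).view(np.float32)
  tangent = np.uint32(packed & np.uint64(0xFFFFFFFF)).view(np.float32)
  return value, tangent

def pair_max(a, b):
  # Compare the values stored in the high bits and keep the whole packed pair
  # of the winner, mirroring _select_and_gather_add_pair_reducer with ge_p.
  return a if unpack(a)[0] >= unpack(b)[0] else b

a = pack(3.0, 0.125)
b = pack(5.0, 0.25)
print(unpack(pair_max(a, b)))  # pair with the larger value wins: 5.0 with tangent 0.25

For float64 there is no 128-bit integer to pack into, which is why the new _select_and_gather_add_reduction_types table falls back to F32 packing on CPU/GPU (and _select_and_gather_add_tpu_reduction_types falls back from F32 to BF16 on TPU), emitting the reduced-precision warning.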
8 changes: 4 additions & 4 deletions tests/lax_test.py
@@ -2029,9 +2029,9 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng):
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
"rng": rng}
for init_val, op, dtypes, rng in [
(0, lax.add, [onp.float32], jtu.rand_small()),
(-onp.inf, lax.max, [onp.float32], jtu.rand_default()),
(onp.inf, lax.min, [onp.float32], jtu.rand_default()),
(0, lax.add, float_dtypes, jtu.rand_small()),
(-onp.inf, lax.max, float_dtypes, jtu.rand_default()),
(onp.inf, lax.min, float_dtypes, jtu.rand_default()),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]
@@ -2045,7 +2045,7 @@ def testReduceWindowGrad(self, op, init_val, dtype, padding, rng):
# TODO(b/31565929): enable when fixed.
if FLAGS.jax_test_dut == "tpu" and op is not lax.add:
all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
test_gradients = False # TODO(b/73062247): need variadic reduce-window.
test_gradients = True # TODO(b/73062247): need variadic reduce-window.
else:
all_configs = itertools.chain(
itertools.product(
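A hedged sketch (not taken from lax_test.py) of the kind of check the re-enabled test_gradients path performs, using jax.test_util.check_grads to verify first- and second-order derivatives of a max reduce-window numerically; the window shape, data, and eps here are arbitrary:

import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

def pool(x):
  return lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (1, 1), 'VALID')

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.],
               [7., 8., 9.]])
# order=2 checks first and second derivatives against finite differences;
# distinct entries keep the max selection stable under the perturbation.
check_grads(pool, (x,), order=2, eps=1e-2)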
