[dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's convenient given that I'm looking at indexing again.
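
(Aside, not part of the commit message: a minimal usage sketch of what this
change enables, mirroring the new test_vmap_of_indexing_basic test added
below; it assumes the jax_dynamic_shapes flag, as in that test.)

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

x = jnp.arange(3.)

def f(idxs):
  # each per-example index produces a gather whose batch shape is dynamic
  return jax.vmap(lambda i: x[i])(idxs)

# abstract the leading axis of the indices so its size is a tracer `b`
print(jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)))
# traces to lt/add/select_n (index normalization), gather, then squeeze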
mattjj committed Sep 9, 2022
1 parent 49672cd commit 5882650
Showing 6 changed files with 72 additions and 57 deletions.
3 changes: 3 additions & 0 deletions jax/_src/iree.py
@@ -110,6 +110,9 @@ def __repr__(self): return f'IreeBuffer({np.asarray(self)})'
   def _value(self):
     return np.asarray(self)
 
+  def copy_to_host_async(self):
+    return self
+
 class IreeExecutable:
 
   def __init__(self, client, devices, module_object, function_name):
28 changes: 24 additions & 4 deletions jax/_src/lax/lax.py
@@ -629,8 +629,14 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
   Returns:
     An array containing the concatenation.
   """
+  from jax.experimental import array
+
   if len(operands) == 0:
     raise ValueError("concatenate requires a non-empty sequence of arrays")
+  if len(operands) == 1:
+    op, = operands
+    if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
+      return op
   return concatenate_p.bind(*operands, dimension=dimension)
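
(Aside, not part of the diff: a sketch of the new single-operand fast path
above. Assuming the operand is a tracer or a device array, concatenate now
returns it unchanged instead of binding concatenate_p.)

import jax.numpy as jnp
from jax import lax

x = jnp.arange(4)
y = lax.concatenate([x], dimension=0)  # hits the fast path: no concatenate_p
assert y is x                          # the operand comes back as-is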


@@ -2249,8 +2255,12 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
                        partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
 
 def _compare_lower_mhlo(direction: str, ctx, x, y):
-  x_aval, y_aval = ctx.avals_in
-  aval_out, = ctx.avals_out
+  avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
+  if config.jax_dynamic_shapes:
+    substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
+    avals_in = map(substitute, avals_in)
+    aval_out = substitute(aval_out)
+  x_aval, y_aval = avals_in
   x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in,
                         (x, y))
   if dtypes.issubdtype(x_aval.dtype, np.inexact):
@@ -2757,8 +2767,9 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
   return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
           [None] * len(dyn_shape))
 
-def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
+def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *dyn_shape, shape,
                                  broadcast_dimensions):
+  if dyn_shape: raise NotImplementedError  # TODO(mattjj)
   operand, = batched_args
   bdim, = batch_dims
   new_operand = batching.moveaxis(operand, bdim, 0)
@@ -3157,7 +3168,16 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
 def _squeeze_lower(ctx, operand, *, dimensions):
   del dimensions  # Implied by the output aval.
   aval_out, = ctx.avals_out
-  return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
+  if config.jax_dynamic_shapes:
+    substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
+    aval_out = substitute(aval_out)
+  if any(isinstance(d, ir.Value) for d in aval_out.shape):
+    return mhlo.DynamicReshapeOp(
+        mlir.aval_to_ir_type(aval_out), operand,
+        mlir.shape_tensor(aval_out.shape),
+    ).results
+  else:
+    return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
 
 mlir.register_lowering(squeeze_p, _squeeze_lower)

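(Aside, not part of the diff: where the comparison and squeeze updates bite.
Under dynamic shapes, negative-index normalization compares an array whose
leading axis size is a tracer, and the post-gather squeeze has a dynamic
output shape, so both lowerings must substitute tracer-valued dimensions. A
sketch of the comparison pattern, mirroring the lt/add/select_n lines in the
new test's jaxpr:)

import jax
import jax.numpy as jnp
from jax import lax

jax.config.update('jax_dynamic_shapes', True)

def f(idxs):             # idxs: i32[b], with b an abstracted axis size
  d = lax.lt(idxs, 0)    # d:bool[b] = lt c 0
  e = lax.add(idxs, 3)   # e:i32[b] = add c 3
  return lax.select(d, e, idxs)  # select_n over a dynamic batch shape

print(jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)))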
44 changes: 3 additions & 41 deletions jax/_src/lax/slicing.py
@@ -287,7 +287,6 @@ def gather(operand: Array, start_indices: Array,
                 fill_value=fill_value)
 
 
-
 class ScatterDimensionNumbers(NamedTuple):
   """
   Describes the dimension number arguments to an `XLA's Scatter operator
@@ -1161,7 +1160,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
   expanded_indices_shape.pop(index_vector_dim)
   indices_shape = iter(expanded_indices_shape)
 
-  slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims))
+  slice_sizes = (s for i, s in enumerate(slice_sizes)
+                 if i not in collapsed_slice_dims)
   return tuple(next(slice_sizes) if i in offset_dims
                else next(indices_shape) for i in range(output_shape_rank))

@@ -1250,7 +1250,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
 
   elif operand_bdim is None and indices_bdim is not None:
     indices = batching.moveaxis(indices, indices_bdim, 0)
-    offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
+    offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims)
     dnums = GatherDimensionNumbers(
         offset_dims=offset_dims,
         collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
@@ -1308,12 +1308,10 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
     _gather_shape_rule, _gather_dtype_rule, 'gather',
     weak_type_rule=_argnum_weak_type(0))
 ad.defjvp(gather_p, _gather_jvp_rule, None)
-
 ad.primitive_transposes[gather_p] = _gather_transpose_rule
 batching.primitive_batchers[gather_p] = _gather_batching_rule
 
 
-
 def _gather_lower(ctx, operand, indices, *,
                   dimension_numbers, slice_sizes, unique_indices,
                   indices_are_sorted, mode, fill_value):
@@ -2054,39 +2052,3 @@ def _dynamic_slice_indices(operand, start_indices: Any):
     d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
     result.append(lax.select(i < 0, i + d, i))
   return result
-
-
-# TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it
-def _getslice(x, lo, hi):
-  return getslice_p.bind(x, lo, hi)
-
-getslice_p = core.Primitive('getslice')
-
-@getslice_p.def_impl
-def getslice_impl(x, lo, hi):
-  return x[lo:hi]
-
-def _getslice_staging_rule(trace, x, lo, hi):
-  size = lax.make_bint(lax.clamp(0, hi - lo, x.shape[0]), x.shape[0])
-  aval = core.DShapedArray((size,), x.dtype, x.weak_type)
-  source_info = source_info_util.current()
-  out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
-  invars = map(trace.getvar, [x, lo, hi])
-  eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
-                         getslice_p, {}, source_info)
-  trace.frame.eqns.append(eqn)
-  return out_tracer
-pe.custom_staging_rules[getslice_p] = _getslice_staging_rule
-
-def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
-  xx = lax.concatenate([x, x], 0)
-  return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
-pe.padding_rules[getslice_p] = _getslice_padding_rule
-
-def _getslice_lower(ctx, x, lo, hi):
-  aval_out, = ctx.avals_out
-  return mhlo.RealDynamicSliceOp(
-      mlir.aval_to_ir_type(aval_out), x,
-      mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
-  ).results
-mlir.register_lowering(getslice_p, _getslice_lower)
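
(Aside, not part of the diff: why the slice_sizes line in _gather_shape_rule
changed. np.delete round-trips the sizes through a NumPy array, which breaks
once a size can be a tracer under dynamic shapes; the generator keeps them as
plain Python values. On static sizes the two forms agree:)

import numpy as np

slice_sizes = (1, 5, 7)
collapsed_slice_dims = (0,)

old = tuple(np.delete(slice_sizes, collapsed_slice_dims))  # via a numpy array
new = tuple(s for i, s in enumerate(slice_sizes)           # pure Python
            if i not in collapsed_slice_dims)
assert old == new == (5, 7)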
17 changes: 9 additions & 8 deletions jax/_src/numpy/lax_numpy.py
@@ -51,7 +51,6 @@
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator)
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.slicing import _getslice
 from jax._src.numpy.ndarray import ndarray
 from jax._src.numpy.reductions import (  # noqa: F401
     _ensure_optional_axes, _reduction_dims,
@@ -3620,14 +3619,16 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
         (start, stop, step) != (0, n, 1)):
       return lax.slice_in_dim(arr, start, stop, step)
 
-
   # TODO(mattjj,dougalm): expand dynamic shape indexing support
-  if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
-      and (isinstance(idx.start, core.Tracer) or isinstance(idx.stop, core.Tracer))
-      and arr.shape):
-    start = 0 if idx.start is None else idx.start
-    stop = arr.shape[0] if idx.stop is None else idx.stop
-    return _getslice(arr, start, stop)
+  if jax.config.jax_dynamic_shapes and arr.ndim > 0:
+    try: aval = core.get_aval(idx)
+    except: pass
+    else:
+      if (isinstance(aval, core.DShapedArray) and aval.shape == () and
+          dtypes.issubdtype(aval.dtype, np.integer) and
+          not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
+          isinstance(arr.shape[0], int)):
+        return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
 
   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
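(Aside, not part of the diff: the new branch in _rewriting_take sends scalar
integer tracer indices to lax.dynamic_index_in_dim rather than the gather
machinery. For reference, its behavior on static shapes:)

import jax.numpy as jnp
from jax import lax

x = jnp.arange(8.)
print(lax.dynamic_index_in_dim(x, 3, keepdims=False))  # 3.0, same as x[3]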
5 changes: 3 additions & 2 deletions jax/core.py
@@ -1692,7 +1692,6 @@ def as_value(self, d: DimSize):
 def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
   if isinstance(dim, Tracer) and not config.jax_dynamic_shapes:
     return None
-  # TODO: look up DynamicJaxprTracer
   return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))
 
 def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
@@ -1801,7 +1800,9 @@ def dimension_as_value(d: DimSize):
   return handler.as_value(*ds)
 
 def _canonicalize_dimension(dim: DimSize) -> DimSize:
-  if is_special_dim_size(dim):
+  if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
+    return dim
+  elif is_special_dim_size(dim):
     return dim
   else:
     return operator.index(dim)
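(Aside, not part of the diff: what the _canonicalize_dimension change
permits. With dynamic shapes on, a tracer can stand directly as a dimension
size instead of being forced through operator.index; this is how shapes like
the (b, 1) broadcast_in_dim in the new test arise. A rough sketch, assuming
the staging machinery accepts the traced size here:)

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

def f(n):
  return jnp.zeros((n, 1))  # `n` is a Tracer used as a dimension size

print(jax.make_jaxpr(f)(3))  # stages a broadcast_in_dim with a dynamic axis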
32 changes: 30 additions & 2 deletions tests/dynamic_api_test.py
@@ -754,7 +754,7 @@ def f(n):
     self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
     self.assertEqual(count, 1)
 
-  @jtu.skip_on_devices('iree')  # TODO(mattjj): update getslice, no bints
+  @unittest.skip("revising slicing logic")
   def test_slicing_basic(self):
     f = jax.jit(lambda x, n: jnp.sum(x[:n]))
     # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@@ -765,7 +765,7 @@ def test_slicing_basic(self):
 
   # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
   # operation 'mhlo.while' that was explicitly marked illegal"
-  @jtu.skip_on_devices('iree')
+  @unittest.skip("revising slicing logic")
   def test_scan_basic(self):
     def cumsum(x):
       def body(i, _):
@@ -1299,6 +1299,34 @@ def foo(x, y):
     f, = jaxpr.outvars
     self.assertEqual(f.aval.shape, (a,))
 
+  def test_vmap_of_indexing_basic(self):
+    x = jnp.arange(3.)
+
+    def f(idxs):
+      return jax.vmap(lambda i: x[i])(idxs)
+
+    idxs = jnp.arange(3)
+    jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr
+    # { lambda a:f32[3]; b:i32[] c:i32[b]. let
+    #     d:bool[b] = lt c 0
+    #     e:i32[b] = add c 3
+    #     f:i32[b] = select_n d c e
+    #     g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b
+    #     h:f32[b,1] = gather[
+    #       dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,))
+    #       fill_value=None
+    #       indices_are_sorted=False
+    #       mode=GatherScatterMode.PROMISE_IN_BOUNDS
+    #       slice_sizes=(1,)
+    #       unique_indices=False
+    #     ] a g
+    #     i:f32[b] = squeeze[dimensions=(1,)] h
+    #   in (i,) }
+    b, _ = jaxpr.invars
+    e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather')
+    h, = e.outvars
+    self.assertEqual(h.aval.shape, (b, 1))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
