Skip to content

Commit

Permalink
address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 25, 2020
1 parent 50cb604 commit ebd51e1
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ def _batched_reduction_collective2(
prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name,
axis_index_groups):
assert not prim.multiple_results # cf. _batched_reduction_collective
if axis_index_groups is not None:
raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
"Please open a feature request!")
(v,), (d,) = vals_in, dims_in
val_out = (if_mapped(v, d) if d is not batching.not_mapped
else if_unmapped(v, frame.size))
Expand Down Expand Up @@ -538,6 +541,7 @@ def _ppermute_transpose_rule(t, perm, axis_name):

def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
(v,), (d,) = vals_in, dims_in
assert d is not batching.not_mapped
perm_indices = [None] * frame.size
Expand Down

0 comments on commit ebd51e1

Please sign in to comment.