Skip to content

Commit

Permalink
fix bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
mattjj committed Nov 25, 2020
1 parent 3053f4b commit 0965bc4
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
     n = len(dtype_args)
     if is_complex and prim is lax.add_p:
       # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
-      # special case because it's not currently handled by XLA:GPU or XLA:CPU
+      # special case because it's not currently handled by XLA:GPU
       dtype_args = ([xops.Real(x) for x in dtype_args] +
                     [xops.Imag(x) for x in dtype_args])
     scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
Expand All @@ -422,23 +422,22 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
 # tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
 def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
                                         axis_index_groups, platform):
-  def _translate(val):
-    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
-    dtype = c.get_shape(val).numpy_dtype()
-    scalar = ShapedArray((), dtype)
+  def all_reduce(x):
+    replica_groups_protos = xc.make_replica_groups(
+        _replica_groups(axis_env, axis_name, axis_index_groups))
+    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
     computation = xla.primitive_subcomputation(prim, scalar, scalar)
-    replica_groups_protos = xc.make_replica_groups(replica_groups)
-    all_reduce = lambda x: xops.AllReduce(x, computation, replica_groups_protos,
-                                          None, None)
+    return xops.AllReduce(x, computation, replica_groups_protos, None, None)

-    if dtypes.issubdtype(dtype, np.complexfloating) and prim is lax.add_p:
-      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
-      # special case because it's not currently handled by XLA:GPU or XLA:CPU
-      return xops.Complex(all_reduce(xops.Real(val)),
-                          all_reduce(xops.Imag(val)))
-    else:
-      return all_reduce(val)
-  return xops.Tuple(c, list(map(_translate, args)))
+  if prim is not lax.add_p:
+    outs = [all_reduce(x) for x in args]
+  else:
+    # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
+    # special case because it's not currently handled by XLA:GPU
+    outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
+            if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
+            else all_reduce(x) for x in args]
+  return xops.Tuple(c, outs)

 def _psum_transpose_rule(cts, axis_name, axis_index_groups):
   nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
Expand Down

0 comments on commit 0965bc4

Please sign in to comment.