[pmap] Add support for nested pmaps on multihost platforms via axis_size (jax-ml#2002)

One issue with nested pmaps on multihost platforms is inferring the global
pmap axis size without communication. This commit sidesteps the issue by adding
an `axis_size` argument to manually provide this information.

This change only enables a single cross-host pmap; all inner pmaps must be
single-host.

Addressing: jax-ml#1753
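
For illustration only (not part of the commit), here is a minimal sketch of the pattern this enables, assuming a pod where `jax.device_count()` and `jax.local_device_count()` report the global and per-host device counts. The axis names, shapes, and function are invented; the outer pmap is the single cross-host pmap and gets `axis_size` explicitly, the inner pmap stays on one host, and no collective is run over the cross-host axis since that is not assumed to be supported here.

```python
# Minimal sketch, not from the commit: a nested pmap on a multi-host pod.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, pmap

num_local = jax.local_device_count()          # devices on this host
num_hosts = jax.device_count() // num_local   # hosts in the pod (assumed layout)

@partial(pmap, axis_name='hosts', axis_size=num_hosts)  # the one cross-host pmap
@partial(pmap, axis_name='cores')                       # inner pmap is single-host
def normalize(x):
  # Collective over the host-local axis only; cross-host collectives are a
  # separate feature and are not assumed to work at this point.
  return x / lax.psum(x, 'cores')

# Each host feeds only its local shard: the outer leading axis is 1 per host,
# and axis_size=num_hosts supplies the global size that can't be inferred locally.
local_input = jnp.ones((1, num_local, 128))
out = normalize(local_input)   # shape (1, num_local, 128) on each host
```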
trevorcai authored and skye committed Jan 15, 2020
1 parent a5644ed commit 12975bb
Showing 2 changed files with 57 additions and 31 deletions.
28 changes: 21 additions & 7 deletions jax/api.py
@@ -717,7 +717,7 @@ def _flatten_axes(treedef, axis_tree):
return axes


def pmap(fun, axis_name=None, devices=None, backend=None):
def pmap(fun, axis_name=None, devices=None, backend=None, axis_size=None):
"""Parallel map with support for collectives.
The purpose of ``pmap`` is to express single-program multiple-data (SPMD)
@@ -868,16 +868,28 @@ def pmap(fun, axis_name=None, devices=None, backend=None):
_check_callable(fun)
axis_name = _TempAxisName(fun) if axis_name is None else axis_name

# axis_size is an optional integer representing the global axis size.
# The aggregate size (across all hosts) of the mapped axis must match
# the given value. This argument is mutually exclusive with ``devices``.
if axis_size is not None and devices is not None:
msg = "pmap got devices and axis_size. They're mutually exclusive."
raise ValueError(msg)

@wraps(fun)
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
args, in_tree = tree_flatten((args, kwargs))
axis_size = _pmap_axis_size(args)
local_axis_size = _pmap_axis_size(args)
_check_args(args)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = pxla.xla_pmap(flat_fun, *args, axis_name=axis_name, axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
backend=backend)
out = pxla.xla_pmap(
flat_fun,
*args,
axis_name=axis_name,
axis_size=local_axis_size,
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
backend=backend)
return tree_unflatten(out_tree(), out)

namestr = "pmap({}, axis_name={})".format
@@ -932,7 +944,8 @@ def f_pmapped(*args, **kwargs):
soft_mapped_fun = pxla.split_axis(flat_fun, axis_name, chunk_size)
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args,
axis_name=axis_name, axis_size=num_chunks,
devices=None, backend=backend)
global_axis_size=None, devices=None,
backend=backend)
outs = [_reshape_merge(out) for out in reshaped_outs]
return tree_unflatten(out_tree(), outs)

@@ -990,7 +1003,8 @@ def pfun(*args):
f, out_axes = parallel.papply_transform(f, axis_name, axis_size)
f = pxla.split_axis(f, axis_name, chunk_size)
outs = pxla.xla_pmap(f, *reshaped_args, axis_name=axis_name,
axis_size=num_chunks, devices=None)
axis_size=num_chunks, global_axis_size=None,
devices=None, backend=None)
outs = map(_reshape_merge, outs)
outs = [batching.matchaxis(axis_size, 0, dst, x)
for dst, x in zip(out_axes(), outs)]
60 changes: 36 additions & 24 deletions jax/interpreters/pxla.py
@@ -408,34 +408,47 @@ def __getitem__(self, idx):
def xla_pmap_impl(fun, *args, **params):
axis_name = params.pop('axis_name')
axis_size = params.pop('axis_size')
global_axis_size = params.pop('global_axis_size')
devices = params.pop('devices')
backend = params.pop('backend', None)
backend = params.pop('backend')
assert not params

abstract_args = map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size, devices,
*abstract_args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, *abstract_args)
return compiled_fun(*args)

@lu.cache
def parallel_callable(fun, backend, axis_name, axis_size, devices, *avals):
def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
devices, *avals):
if devices is not None and len(devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

# Determine global_axis_size for use in AxisEnv.
if devices:
assert global_axis_size is None # Checked in api.py
global_axis_size = len(devices)
elif xb.host_count() > 1:
# TODO(skye): relax this constraint or provide functionality for
# automatically passing appropriate `devices`.
if axis_size != xb.local_device_count():
raise ValueError(
"On multi-host platforms, the input to pmapped functions must have "
"leading axis size equal to the number of local devices if no "
"`devices` argument is specified. Got axis_size=%d, "
"num_local_devices=%d" % (axis_size, xb.local_device_count()))
global_axis_size = xb.device_count()
if global_axis_size is None:
# TODO(skye): relax this constraint or provide functionality for
# automatically passing appropriate `devices`.
# TODO(trevorcai): This check forces us to provide global_axis_size for
# all pmaps on pmap-on-pod. Can we do it after tracing?
if axis_size != xb.local_device_count():
raise ValueError(
"On multi-host platforms, the input to pmapped functions must have "
"leading axis size equal to the number of local devices if no "
"`devices` argument is specified. Got axis_size=%d, "
"num_local_devices=%d" % (axis_size, xb.local_device_count()))
global_axis_size = xb.device_count()
else:
global_axis_size = axis_size
if global_axis_size is not None:
if global_axis_size != axis_size:
raise ValueError(
"Specified axis_size {} doesn't match received axis_size {}.".format(
global_axis_size, axis_size))
else:
global_axis_size = axis_size

log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
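
Restated as a standalone sketch (a paraphrase of the branching above, with the `xb.*` lookups passed in as plain arguments rather than called), the global axis size is resolved roughly as follows:

```python
def resolve_global_axis_size(axis_size, global_axis_size, devices,
                             host_count, local_device_count, device_count):
  """Sketch of the branching above; `axis_size` is the host-local size."""
  if devices:
    # An explicit device list fixes the global size; api.py already rejected
    # combining it with an explicit axis_size argument.
    assert global_axis_size is None
    return len(devices)
  if host_count > 1:
    if global_axis_size is not None:
      return global_axis_size        # trust the user-provided axis_size
    # Without help, only the "one row per local device" case can be handled.
    if axis_size != local_device_count:
      raise ValueError("need axis_size= when the local shard is not one row "
                       "per local device")
    return device_count              # one replica per device in the pod
  # Single host: the local size is the global size, and an explicit
  # axis_size must agree with it.
  if global_axis_size is not None and global_axis_size != axis_size:
    raise ValueError("specified axis_size does not match the mapped axis")
  return axis_size
```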
@@ -512,7 +525,7 @@ def dynamic_fun(dummy, *args):
# violating pmap's semantics where data is sharded across replicas in
# row-major order. Instead, manually create a device assignment that ensures
# each host is responsible for a contiguous set of replicas.
if xb.host_count() > 1:
if num_global_replicas > num_local_replicas:
# TODO(skye): use a locality-aware assignment that satisfies the above
# constraint.
devices = [d for host_id in xb.host_ids()
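
As a toy illustration of that constraint (invented host and device names, not the commit's code), a host-major ordering keeps each host's replicas contiguous, so the row-major sharding of the leading axis lands each host's shard on its own local devices:

```python
# Toy example: 2 hosts x 4 local devices, names are made up.
host_ids = [0, 1]
local_devices = {0: ["h0:d0", "h0:d1", "h0:d2", "h0:d3"],
                 1: ["h1:d0", "h1:d1", "h1:d2", "h1:d3"]}

# Host-major ordering: replicas 0-3 live on host 0, replicas 4-7 on host 1.
device_assignment = [d for h in host_ids for d in local_devices[h]]
assert device_assignment[:4] == local_devices[0]
```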
@@ -544,7 +557,7 @@ def dynamic_fun(dummy, *args):
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,
out_pvals, compiled.local_devices(),
backend)
return partial(execute_replicated, compiled, backend, num_local_replicas, handle_args, handle_outs)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)

multi_host_supported_collectives = set()

@@ -625,11 +638,7 @@ def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
else:
return aval_to_result_handler(axis_size, nrep, pv)

def execute_replicated(compiled, backend, nrep, in_handler, out_handler, *args):
if nrep > xb.device_count(backend):
msg = ("executing pmap computation that requires {} replicas, but only {} "
"XLA devices are available")
raise ValueError(msg.format(nrep, xb.device_count(backend)))
def execute_replicated(compiled, backend, in_handler, out_handler, *args):
input_bufs = in_handler(args)
out_bufs = compiled.ExecutePerReplica(list(input_bufs))
return out_handler(out_bufs)
@@ -642,12 +651,15 @@ def execute_replicated(compiled, backend, nrep, in_handler, out_handler, *args):
xla_pmap_p.def_impl(xla_pmap_impl)

def _pmap_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes,
in_nodes, axis_name, axis_size, devices, backend=None):
in_nodes, axis_name, axis_size, global_axis_size,
devices, backend=None):
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if axis_env.devices is not None or (axis_env.names and devices is not None):
raise ValueError("Nested pmaps with explicit devices argument.")
new_env = xla.extend_axis_env(axis_env, axis_name, axis_size)
if global_axis_size is None:
global_axis_size = axis_size
new_env = xla.extend_axis_env(axis_env, axis_name, global_axis_size)
in_nodes_sharded = list(map(partial(_xla_shard, c, new_env), in_nodes))
sharded_outs = xla.jaxpr_subcomp(c, jaxpr, backend, new_env, const_nodes,
freevar_nodes, *in_nodes_sharded)
@@ -822,7 +834,7 @@ def process_call(self, call_primitive, f, tracers, params):
def process_map(self, map_primitive, f, tracers, params):
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
if all(name is not_mapped for name in names):
return map_primitive.bind(f, *vals, **params)
return map_primitive.bind(f, *vals, **params)
else:
# because the map primitive maps over leading axes, we need to transpose
# the software-mapped axis on any mapped arguments to be the second axis;