Merge pull request jax-ml#1211 from levskaya/multibackend
multibackend jit
skye authored Aug 26, 2019
2 parents 8940322 + 0cc21c8 commit 2d26ac3
Showing 9 changed files with 307 additions and 121 deletions.
31 changes: 19 additions & 12 deletions jax/api.py
@@ -79,7 +79,7 @@ def __init__(self):

_thread_local_state = _ThreadLocalState()

-def jit(fun, static_argnums=(), device_assignment=None):
+def jit(fun, static_argnums=(), device_assignment=None, backend=None):
"""Sets up `fun` for just-in-time compilation with XLA.
Args:
@@ -101,6 +101,8 @@ def jit(fun, static_argnums=(), device_assignment=None):
change. Optional, an int specifying the device ordinal for which to compile the
function. The default is inherited from XLA's DeviceAssignment logic and is
usually to use device 0.
+backend: This is an experimental feature and the API is likely to change.
+Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A wrapped version of `fun`, set up for just-in-time compilation.
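Not part of the diff: a minimal usage sketch of the new argument, assuming the 'cpu' backend is available (which it is on any host):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, backend='cpu')  # pin compilation and execution to the CPU backend
def double(x):
  return x * 2

print(double(jnp.arange(4)))  # runs on CPU even if a GPU/TPU is present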
@@ -141,7 +143,7 @@ def f_jitted(*args, **kwargs):
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
_check_args(args_flat)
flat_fun, out_tree = flatten_fun(f, in_tree)
-out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment)
+out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)
return tree_unflatten(out_tree(), out)

jitted_name = "jit({}, static_argnums={})"
@@ -191,7 +193,7 @@ def disable_jit():
_thread_local_state.jit_is_disabled = prev_val


-def xla_computation(fun, static_argnums=(), axis_env=None):
+def xla_computation(fun, static_argnums=(), axis_env=None, backend=None):
"""Creates a function that produces its XLA computation given example args.
Args:
@@ -203,6 +205,8 @@ def xla_computation(fun, static_argnums=(), axis_env=None):
functions that involve parallel communication collectives, and it
specifies the axis name/size environment that would be set up by
applications of ``jax.pmap``. See the examples below.
+backend: This is an experimental feature and the API is likely to change.
+Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns a
@@ -285,7 +289,7 @@ def computation_maker(*args, **kwargs):
pvals = map(pv_like, jax_args)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
-return xla.build_jaxpr(jaxpr, axis_env_, consts,
+return xla.build_jaxpr(jaxpr, backend, axis_env_, consts,
*map(xla.abstractify, jax_args))
return computation_maker
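Not part of the diff: a sketch of xla_computation with the backend argument. The function f is made up, and GetHloText is assumed to be available on the returned computation object, as in the xla_computation docstring of that era:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(jnp.cos(x))

# Build, but do not execute, the XLA computation for the CPU backend.
c = jax.xla_computation(f, backend='cpu')(3.0)
print(c.GetHloText())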

@@ -624,7 +628,7 @@ class _NoneProxy(object): pass
_none_proxy = _NoneProxy()


-def pmap(fun, axis_name=None):
+def pmap(fun, axis_name=None, backend=None):
"""Parallel map with support for collectives.
The purpose of ``pmap`` is to express single-program multiple-data (SPMD)
@@ -647,6 +651,8 @@ def pmap(fun, axis_name=None):
fun: Function to be mapped over argument axes.
axis_name: Optional, a hashable Python object used to identify the mapped
axis so that parallel collectives can be applied.
+backend: This is an experimental feature and the API is likely to change.
+Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
@@ -718,7 +724,7 @@ def f_pmapped(*args, **kwargs):
axis_size = _pmap_axis_size(args)
_check_args(args)
flat_fun, out_tree = flatten_fun(f, in_tree)
-out = pxla.xla_pmap(flat_fun, *args, axis_name=axis_name, axis_size=axis_size)
+out = pxla.xla_pmap(flat_fun, *args, axis_name=axis_name, axis_size=axis_size, backend=backend)
return tree_unflatten(out_tree(), out)

namestr = "pmap({}, axis_name={})".format
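Not part of the diff: a usage sketch of pmap with the backend argument. The mapped axis has size 1 so the example fits any backend with at least one device:

import jax
import jax.numpy as jnp

squared = jax.pmap(lambda x: x ** 2, axis_name='i', backend='cpu')(jnp.arange(1.0))
print(squared)  # -> [0.]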
@@ -740,7 +746,7 @@ def __repr__(self):
return '<axis {}>'.format(hex(id(self)))


-def soft_pmap(fun, axis_name=None):
+def soft_pmap(fun, axis_name=None, backend=None):
_check_callable(fun)
axis_name = _TempAxisName() if axis_name is None else axis_name

@@ -752,9 +758,9 @@ def f_pmapped(*args, **kwargs):
_check_args(args_flat)
flat_fun, out_tree = flatten_fun(f, in_tree)

-chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
+chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
if chunk_size == 0 and leftover:
-return pmap(fun, axis_name)(*args) # can map directly onto hardware
+return pmap(fun, axis_name, backend)(*args) # can map directly onto hardware
elif leftover:
msg = ("soft_pmap mapped axis size must be divisble by the number of "
"XLA devices (or be less than or equal to that number), but got "
@@ -765,7 +771,8 @@ def f_pmapped(*args, **kwargs):
reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
soft_mapped_fun = pxla.split_axis(flat_fun, axis_name, chunk_size)
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args,
-axis_name=axis_name, axis_size=num_chunks)
+axis_name=axis_name, axis_size=num_chunks,
+backend=backend)
outs = [_reshape_merge(out) for out in reshaped_outs]
return tree_unflatten(out_tree(), outs)
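Not part of the diff: the chunking arithmetic soft_pmap relies on, spelled out with made-up numbers (the device count is an assumption):

axis_size, n_devices = 24, 8
print(divmod(axis_size, n_devices))  # (3, 0): each device handles a chunk of 3
axis_size = 5
print(divmod(axis_size, n_devices))  # (0, 5): chunk_size == 0, so fall back to a plain pmap
axis_size = 20
print(divmod(axis_size, n_devices))  # (2, 4): nonzero leftover, so soft_pmap raises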

@@ -1084,8 +1091,8 @@ def jaxpr_maker(*args, **kwargs):
return jaxpr_maker


-def device_put(x, device_num=0):
-return tree_map(lambda y: xla.device_put_p.bind(y, device_num=device_num), x)
+def device_put(x, device_num=0, backend=None):
+return tree_map(lambda y: xla.device_put_p.bind(y, device_num=device_num, backend=backend), x)


# TODO(mattjj): consider revising
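Not part of the diff: a usage sketch of device_put with the new backend argument, using the device_num/backend signature shown above; device 0 on the 'cpu' backend is assumed to exist, which it does on any host:

import jax
import jax.numpy as jnp

y = jax.device_put(jnp.ones((2, 2)), device_num=0, backend='cpu')
print(y)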
47 changes: 24 additions & 23 deletions jax/interpreters/pxla.py
@@ -45,7 +45,7 @@

def identity(x): return x

-def shard_args(device_ordinals, assignments, axis_size, args):
+def shard_args(backend, device_ordinals, assignments, axis_size, args):
"""Shard an argument data array arg along its leading axis.
Args:
@@ -74,18 +74,18 @@ def shard_args(device_ordinals, assignments, axis_size, args):
else buf.copy_to_device(device_ordinals[r]))
else:
for r, buf in enumerate(arg.device_buffers):
-buffers[r][a] = xla.device_put(x[assignments[r]], ordinals[r])
+buffers[r][a] = xla.device_put(x[assignments[r]], ordinals[r], backend=backend)
else:
-bufs = shard_arg_handlers[type(arg)](arg, device_ordinals, assignments)
+bufs = shard_arg_handlers[type(arg)](arg, device_ordinals, assignments, backend=backend)
for r, buf in enumerate(bufs):
buffers[r][a] = buf
return buffers
shard_arg_handlers = {}
shard_arg_handlers[core.Unit] = \
-lambda x, ordinals, _: [xla.device_put(core.unit, d) for d in ordinals]
-def _shard_array(x, ordinals, assignments):
+lambda x, ordinals, _, backend=None: [xla.device_put(core.unit, d, backend=backend) for d in ordinals]
+def _shard_array(x, ordinals, assignments, backend=None):
nrep = len(ordinals)
-return (xla.device_put(x[assignments[r]], ordinals[r]) for r in range(nrep))
+return (xla.device_put(x[assignments[r]], ordinals[r], backend=backend) for r in range(nrep))
for _t in it.chain(array_types, [xla.DeviceArray]):
shard_arg_handlers[_t] = _shard_array
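Not part of the diff: a plain-NumPy sketch of what the sharding above does; the ordinals and assignments values are illustrative:

import numpy as np

x = np.arange(8).reshape(4, 2)   # leading axis of size 4 is the mapped axis
ordinals = [0, 1, 2, 3]          # one device ordinal per replica
assignments = [0, 1, 2, 3]       # replica r receives shard assignments[r]
shards = [x[assignments[r]] for r in range(len(ordinals))]
# shard r would then be placed on device ordinals[r] via xla.device_put(..., backend=backend)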

@@ -159,17 +159,17 @@ def replica_groups(nrep, mesh_spec, mesh_axes):

### the main pmap machinery lowers SPMD jaxprs to multi-replica XLA computations

-def compile_replicated(jaxpr, axis_name, axis_size, consts, *abstract_args):
+def compile_replicated(jaxpr, backend, axis_name, axis_size, consts, *abstract_args):
num_replicas = axis_size * xla.jaxpr_replicas(jaxpr)
-if num_replicas > xb.device_count():
+if num_replicas > xb.device_count(backend):
msg = ("compiling computation that requires {} replicas, but only {} XLA "
"devices are available")
-raise ValueError(msg.format(num_replicas, xb.device_count()))
+raise ValueError(msg.format(num_replicas, xb.device_count(backend)))
axis_env = xla.AxisEnv(num_replicas, [axis_name], [axis_size])
arg_shapes = list(map(aval_to_xla_shape, abstract_args))
-built_c = xla.jaxpr_computation(jaxpr, axis_env, consts, (), *arg_shapes)
+built_c = xla.jaxpr_computation(jaxpr, backend, axis_env, consts, (), *arg_shapes)
compiled = built_c.Compile(arg_shapes, xb.get_compile_options(num_replicas),
-backend=xb.get_backend())
+backend=xb.get_backend(backend))
return compiled, num_replicas
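Not part of the diff: the replica accounting in compile_replicated, with made-up numbers (the device count is an assumption):

axis_size = 4        # size of the pmapped axis
inner_replicas = 2   # what xla.jaxpr_replicas(jaxpr) might report for nested maps
num_replicas = axis_size * inner_replicas  # 8 replicas requested
available = 8        # what xb.device_count(backend) might report
assert num_replicas <= available  # otherwise compile_replicated raises ValueError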


@@ -222,10 +222,10 @@ def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
yield
dynamic_axis_env.pop()

-def unmapped_device_count():
+def unmapped_device_count(backend=None):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
mapped = prod(frame.hard_size for frame in dynamic_axis_env)
-unmapped, ragged = divmod(xb.device_count(), mapped)
+unmapped, ragged = divmod(xb.device_count(backend), mapped)
assert not ragged and unmapped > 0
return unmapped
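Not part of the diff: the arithmetic behind unmapped_device_count, with made-up numbers:

# 8 devices on the chosen backend, already inside an outer pmap over 2 devices:
total_devices, mapped = 8, 2
unmapped, ragged = divmod(total_devices, mapped)  # (4, 0): 4 devices remain for the inner map
assert not ragged and unmapped > 0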

@@ -394,13 +394,14 @@ def __getitem__(self, idx):
def xla_pmap_impl(fun, *args, **params):
axis_name = params.pop('axis_name')
axis_size = params.pop('axis_size')
+backend = params.pop('backend', None)
assert not params
abstract_args = map(xla.abstractify, args)
-compiled_fun = parallel_callable(fun, axis_name, axis_size, *abstract_args)
+compiled_fun = parallel_callable(fun, backend, axis_name, axis_size, *abstract_args)
return compiled_fun(*args)

@lu.cache
-def parallel_callable(fun, axis_name, axis_size, *avals):
+def parallel_callable(fun, backend, axis_name, axis_size, *avals):
avals = tuple(map(partial(shard_aval, axis_size), avals))
pvals = [PartialVal((aval, core.unit)) for aval in avals]
pval = PartialVal([core.abstract_unit, core.unit]) # dummy value
@@ -426,12 +427,12 @@ def dynamic_fun(dummy, *args):
results = [handler(None) for handler in handlers]
return lambda *_: results
else:
-compiled, nrep = compile_replicated(jaxpr, axis_name, axis_size, consts, *avals)
+compiled, nrep = compile_replicated(jaxpr, backend, axis_name, axis_size, consts, *avals)
device_ordinals = compiled.DeviceOrdinals()
assignments = assign_shards_to_replicas(nrep, axis_size)
-handle_args = partial(shard_args, device_ordinals, assignments, axis_size)
+handle_args = partial(shard_args, backend, device_ordinals, assignments, axis_size)
handle_outs = _pvals_to_results_handler(axis_size, nrep, out_pvals)
-return partial(execute_replicated, compiled, nrep, handle_args, handle_outs)
+return partial(execute_replicated, compiled, backend, nrep, handle_args, handle_outs)

def _pvals_to_results_handler(size, nrep, out_pvals):
nouts = len(out_pvals)
@@ -452,11 +453,11 @@ def _pval_to_result_handler(size, nrep, pval):
else:
return aval_to_result_handler(size, nrep, pv)

-def execute_replicated(compiled, nrep, in_handler, out_handler, *args):
-if nrep > xb.device_count():
+def execute_replicated(compiled, backend, nrep, in_handler, out_handler, *args):
+if nrep > xb.device_count(backend):
msg = ("executing pmap computation that requires {} replicas, but only {} "
"XLA devices are available")
-raise ValueError(msg.format(nrep, xb.device_count()))
+raise ValueError(msg.format(nrep, xb.device_count(backend)))
input_bufs = in_handler(args)
out_bufs = compiled.ExecutePerReplica(list(input_bufs))
return out_handler(out_bufs)
@@ -469,10 +470,10 @@ def execute_replicated(compiled, nrep, in_handler, out_handler, *args):
xla_pmap_p.def_impl(xla_pmap_impl)

def _xla_pmap_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
-axis_name, axis_size):
+axis_name, axis_size, backend=None):
new_env = xla.extend_axis_env(axis_env, axis_name, axis_size)
in_nodes_sharded = list(map(partial(_xla_shard, c, new_env.sizes), in_nodes))
-subc = xla.jaxpr_computation(jaxpr, new_env, (),
+subc = xla.jaxpr_computation(jaxpr, backend, new_env, (),
tuple(map(c.GetShape, env_nodes)),
*map(c.GetShape, in_nodes_sharded))
sharded_result = c.Call(subc, env_nodes + in_nodes_sharded)
[The remaining 7 of the 9 changed files are not rendered in this view.]
