Some cleanup and reformatting in xla.py.
- Make creation of a few dictionaries more readable.
- Use f-strings where possible.
- Remove unused imports and function parameters.
- Don't format string before passing to `log` function.
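Illustration (not part of the commit): a minimal sketch of the `.format()` → f-string rewrite that recurs throughout the diff below. The names `prim` and `nreps` are hypothetical stand-ins, not values from the patch.

```python
# Hypothetical before/after of the .format() -> f-string cleanup.
prim, nreps = "add", 2

# Before: placeholders and their arguments are separated.
msg_old = "compiling `{}` requires {} replicas".format(prim, nreps)

# After: each expression sits where its value appears.
msg_new = f"compiling `{prim}` requires {nreps} replicas"

assert msg_old == msg_new
```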
chr1sj0nes committed Apr 16, 2020
1 parent 0e29bd4 commit 903b50e
Showing 1 changed file with 69 additions and 78 deletions.
147 changes: 69 additions & 78 deletions jax/interpreters/xla.py
@@ -32,14 +32,13 @@
                          abstract_token)
 from ..core import Literal, pp_eqn_compact
 from ..pprint_util import pp
-from ..util import (partial, partialmethod, cache, safe_map, prod, unzip2,
-                    memoize, extend_name_stack, wrap_name)
+from ..util import (partial, partialmethod, cache, prod, unzip2, memoize,
+                    extend_name_stack, wrap_name)
 from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 from . import partial_eval as pe
 from . import ad
 from . import masking
-from typing import Callable
 
 FLAGS = flags.FLAGS
 flags.DEFINE_bool('jax_debug_nans',
@@ -52,6 +51,7 @@
 def _map(f, *xs): return tuple(map(f, *xs))
 def identity(x): return x
 
+_scalar_types = dtypes.python_scalar_dtypes.keys()
 
 # unit representation
 def _make_unit(c): return c.Constant(onp.zeros((), dtype=onp.dtype('bool')))
@@ -70,47 +70,45 @@ def aval_to_xla_shape(aval):
   try:
     return xla_shape_handlers[type(aval)](aval)
   except KeyError as err:
-    raise TypeError("No xla_shape_handler for type: {}".format(type(aval))
-                    ) from err
-xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {}
-xla_shape_handlers[core.AbstractUnit] = _make_abstract_unit
+    raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
 
-xla_shape_handlers[ShapedArray] = _make_array_shape
-xla_shape_handlers[ConcreteArray] = _make_array_shape
+xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
+    core.AbstractUnit: _make_abstract_unit,
+    ShapedArray: _make_array_shape,
+    ConcreteArray: _make_array_shape,
+}
 
 def aval_to_result_handler(device, aval):
   try:
     return xla_result_handlers[type(aval)](device, aval)
   except KeyError as err:
-    raise TypeError("No xla_result_handler for type: {}".format(type(aval))
-                    ) from err
-xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {}
-xla_result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
+    raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
 
 def array_result_handler(device, aval):
   return partial(DeviceArray, raise_to_shaped(aval), device, lazy.array(aval.shape))
-xla_result_handlers[ShapedArray] = array_result_handler
-xla_result_handlers[ConcreteArray] = array_result_handler
+
+xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
+    core.AbstractUnit: lambda _, __: lambda _: core.unit,
+    ShapedArray: array_result_handler,
+    ConcreteArray: array_result_handler,
+}
 
 def device_put(x, device=None):
   x = canonicalize_dtype(x)
   try:
     return device_put_handlers[type(x)](x, device)
   except KeyError as err:
-    raise TypeError("No device_put handler for type: {}".format(type(x))
-                    ) from err
+    raise TypeError(f"No device_put handler for type: {type(x)}") from err
 
-device_put_handlers: Dict[Any, Callable] = {}
-device_put_handlers[core.Unit] = _device_put_unit
 def _device_put_array(x, device):
   return xc.Buffer.from_pyval(x, device, backend=xb.get_device_backend(device))
-for _t in array_types:
-  device_put_handlers[_t] = _device_put_array
 
 def _device_put_scalar(x, device):
-  return xc.Buffer.from_pyval(dtypes.coerce_to_array(x), device,
-                              backend=xb.get_device_backend(device))
-for _t in dtypes.python_scalar_dtypes.keys():
-  device_put_handlers[_t] = _device_put_array
+  return _device_put_array(dtypes.coerce_to_array(x), device)
+
+device_put_handlers: Dict[Any, Callable] = {core.Unit: _device_put_unit}
+device_put_handlers.update((t, _device_put_array) for t in array_types)
+device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
 
 # TODO(mattjj): try to remove this canonicalize_dtype stuff
 def canonicalize_dtype(x):
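Aside (illustration, not part of the commit): the hunk above replaces item-by-item registry assignment with a dict literal plus `update()` over generator expressions. A self-contained sketch of that pattern, using toy stand-ins for `core.Unit`, `array_types`, and `_scalar_types`:

```python
# Minimal sketch of the dict-literal-plus-update() registry pattern.
# All types and handlers here are toy stand-ins, not JAX's.
from typing import Any, Callable, Dict

array_types = (list, tuple)            # stand-in for JAX's array_types
scalar_types = (int, float, complex)   # stand-in for _scalar_types

def _put_array(x, device):
    return ("array", x, device)

def _put_scalar(x, device):
    return _put_array([x], device)     # scalars reuse the array path

# One literal for the special case, one update() per type family.
put_handlers: Dict[Any, Callable] = {type(None): lambda x, d: ("unit", d)}
put_handlers.update((t, _put_array) for t in array_types)
put_handlers.update((t, _put_scalar) for t in scalar_types)

print(put_handlers[int](3, "cpu:0"))        # ('array', [3], 'cpu:0')
print(put_handlers[list]([1, 2], "cpu:0"))  # ('array', [1, 2], 'cpu:0')
```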
@@ -120,19 +118,20 @@ def canonicalize_dtype(x):
   for typ in typ.mro():
     handler = canonicalize_dtype_handlers.get(typ)
     if handler: return handler(x)
-  raise TypeError("No canonicalize_dtype handler for type: {}".format(type(x)))
+  raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
 
-canonicalize_dtype_handlers: Dict[Any, Callable] = {}
-canonicalize_dtype_handlers[core.Unit] = identity
 def _canonicalize_ndarray_dtype(x):
   return onp.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
-for _t in array_types:
-  canonicalize_dtype_handlers[_t] = _canonicalize_ndarray_dtype
 
 def _canonicalize_python_scalar_dtype(typ, x):
   return onp.asarray(
-      x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))
-for _t in dtypes.python_scalar_dtypes.keys():
-  canonicalize_dtype_handlers[_t] = partial(_canonicalize_python_scalar_dtype, _t)
+      x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))
+
+canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
+canonicalize_dtype_handlers.update(
+    (t, _canonicalize_ndarray_dtype) for t in array_types)
+canonicalize_dtype_handlers.update(
+    (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
 
 def abstractify(x) -> core.AbstractValue:
   typ = type(x)
@@ -141,19 +140,17 @@ def abstractify(x) -> core.AbstractValue:
   for typ in typ.mro():
     aval_fn = pytype_aval_mappings.get(typ)
     if aval_fn: return aval_fn(x)
-  raise TypeError("No abstraction handler for type: {}".format(type(x)))
-
-pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
-pytype_aval_mappings[core.Unit] = lambda _: core.abstract_unit
-for _t in array_types:
-  pytype_aval_mappings[_t] = make_shaped_array
+  raise TypeError(f"No abstraction handler for type: {type(x)}")
 
 def _make_abstract_python_scalar(typ, _):
   return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)
 
-for _t in dtypes.python_scalar_dtypes.keys():
-  pytype_aval_mappings[_t] = partial(_make_abstract_python_scalar, _t)
-
+pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
+    core.Unit: lambda _: core.abstract_unit,
+}
+pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
+pytype_aval_mappings.update(
+    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
 
 ### op-by-op execution
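Aside (illustration, not part of the commit): the lookups in `abstractify` and `canonicalize_dtype` above walk `type(x).mro()`, so a handler registered for a base class also covers its subclasses. A toy sketch of that dispatch, with hypothetical handlers rather than JAX's mappings:

```python
# Sketch of mro()-based dispatch: fall back through base classes before
# giving up. Toy mappings, not JAX's pytype_aval_mappings.
mappings = {
    int: lambda x: f"abstract_int({x})",
    object: lambda x: f"abstract_object({x!r})",
}

def lookup(x):
    for typ in type(x).mro():
        fn = mappings.get(typ)
        if fn:
            return fn(x)
    raise TypeError(f"No handler for type: {type(x)}")

class MyInt(int):
    pass

print(lookup(MyInt(5)))  # abstract_int(5): found via MyInt -> int
print(lookup("s"))       # abstract_object('s'): via the object fallback
```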

@@ -186,10 +183,10 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   else:
     nreps = 1
   if nreps > xb.device_count(backend):
-    msg = ("compiling a primitive computation `{}` that requires {} replicas, "
-           "but only {} XLA devices are available on backend {}.")
-    raise ValueError(msg.format(prim, nreps, xb.device_count(backend),
-                                backend.platform))
+    raise ValueError(
+        f"compiling a primitive computation `{prim}` that requires {nreps} "
+        f"replicas, but only {xb.device_count(backend)} XLA devices are "
+        f"available on backend {backend.platform}.")
   built_c = primitive_computation(prim, AxisEnv(nreps), backend, tuple_args,
                                   *avals, **params)
   options = xb.get_compile_options(
@@ -199,11 +196,9 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   options.tuple_arguments = tuple_args
   compiled = built_c.Compile(compile_options=options, backend=backend)
   if nreps == 1:
-    return partial(_execute_compiled_primitive, prim, compiled, backend,
-                   handle_result)
+    return partial(_execute_compiled_primitive, prim, compiled, handle_result)
   else:
-    return partial(_execute_replicated_primitive, prim, compiled, backend,
-                   handle_result)
+    return partial(_execute_replicated_primitive, prim, compiled, handle_result)
 
 def _device_from_arg_devices(devices):
   """Given devices of inputs, determine where to perform a computation.
@@ -224,7 +219,7 @@ def _device_from_arg_devices(devices):
 
 @cache()
 def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
-  c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
+  c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
   c.SetOpMetadata(xc.OpMetadata(
       op_type=prim.name,
       op_name=str(pp_eqn_compact(prim.name, params))))
@@ -242,7 +237,7 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
     ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
                *xla_args, **params)
   else:
-    raise NotImplementedError("XLA translation rule for {} not found".format(prim))
+    raise NotImplementedError(f"XLA translation rule for {prim} not found")
   assert isinstance(ans, xc._xla.XlaOp)
   c.ClearOpMetadata()
   try:
@@ -256,16 +251,15 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
 def primitive_subcomputation(prim, *avals, **params):
   return primitive_computation(prim, AxisEnv(1), None, False, *avals, **params)
 
-def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
+def _execute_compiled_primitive(prim, compiled, result_handler, *args):
   device, = compiled.local_devices()
   input_bufs = [device_put(x, device) for x in args if x is not token]
   out_bufs = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_bufs)
   return result_handler(out_bufs if prim.multiple_results else out_bufs[0])
 
-def _execute_replicated_primitive(prim, compiled, backend, result_handler,
-                                  *args):
+def _execute_replicated_primitive(prim, compiled, result_handler, *args):
   input_bufs = [
       [device_put(x, device) for x in args if x is not token]
       for device in compiled.local_devices()]
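Aside (illustration, not part of the commit): the executors above are returned pre-bound via `partial`, which is why the unused `backend` slot could simply be dropped from their signatures. A toy sketch of that pre-binding, with a hypothetical executor standing in for `_execute_compiled_primitive`:

```python
# Toy version of the pre-binding pattern: partial() freezes the arguments
# known at compile time; callers later supply only the runtime *args.
from functools import partial

def _execute(prim_name, result_handler, *args):
    return result_handler(f"{prim_name}({', '.join(map(str, args))})")

# "Compile time": bind the primitive name and result handler once.
runner = partial(_execute, "add", str.upper)

# "Run time": only the actual arguments are passed.
print(runner(1, 2))  # ADD(1, 2)
```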
@@ -282,8 +276,7 @@ def _check_nans(name, xla_shape, buf):
   assert not xla_shape.is_tuple()
   if dtypes.issubdtype(xla_shape.element_type(), onp.inexact):
     if onp.any(onp.isnan(buf.to_py())):
-      msg = "invalid value (nan) encountered in {}"
-      raise FloatingPointError(msg.format(name))
+      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
 
 ### compiling jaxprs
 
@@ -353,8 +346,8 @@ def write(v, node):
       ans = rule(c, axis_env, in_nodes,
                  name_stack, backend=backend, **new_params)
     else:
-      msg = "XLA translation rule for primitive '{}' not found"
-      raise NotImplementedError(msg.format(eqn.primitive.name))
+      raise NotImplementedError(
+          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
 
     assert isinstance(ans, xc._xla.XlaOp)
     c.GetShape(ans)  # force xla to do shape error checking
@@ -372,10 +365,9 @@ def check_backend_params(params, outer_backend):
   # it's an error if the inner call has a conflicting explicit backend spec.
   inner_backend = params.get('backend', None)
   if inner_backend and inner_backend != outer_backend:
-    msg = (
-        "Outer-jit backend specification {} must match explicit inner-jit "
-        "backend specification {}.")
-    raise ValueError(msg.format(outer_backend, inner_backend))
+    raise ValueError(
+        f"Outer-jit backend specification {outer_backend} must match explicit "
+        f"inner-jit backend specification {inner_backend}.")
   return {k: params[k] for k in params if k != 'backend'}
 
 
@@ -495,13 +487,13 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
     return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
 
   log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
-  logging.log(log_priority,
-              "Compiling {} for args {}.".format(fun.__name__, abstract_args))
+  logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)
 
   if nreps > xb.device_count(backend):
-    msg = ("compiling computation that requires {} replicas, but only {} XLA "
-           "devices are available")
-    raise ValueError(msg.format(nreps, xb.device_count(backend)))
+    raise ValueError(
+        f"compiling computation that requires {nreps} replicas, but only "
+        f"{xb.device_count(backend)} XLA devices are available")
 
   if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
     raise NotImplementedError(
         "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
@@ -525,15 +517,15 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
   compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
 
   if nreps == 1:
-    return partial(_execute_compiled, compiled, backend, result_handlers)
+    return partial(_execute_compiled, compiled, result_handlers)
   else:
-    return partial(_execute_replicated, compiled, backend, result_handlers)
+    return partial(_execute_replicated, compiled, result_handlers)
 
 def _xla_callable_device(nreps, backend, device, arg_devices):
   if nreps > 1:
     if device is not None or backend is not None:
-      raise ValueError("can't specify device or backend for jit-of-pmap, "
-                       "got device={} and backend={}".format(device, backend))
+      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
+                       f"got device={device} and backend={backend}")
     return None
   else:
     if device is None and backend is None:
@@ -567,14 +559,14 @@ def _pval_to_result_handler(device, pval):
   else:
     return aval_to_result_handler(device, pv)
 
-def _execute_compiled(compiled, backend, handlers, *args):
+def _execute_compiled(compiled, handlers, *args):
   device, = compiled.local_devices()
   input_bufs = [device_put(x, device) for x in args if x is not token]
   out_bufs = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
   return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
 
-def _execute_replicated(compiled, backend, handlers, *args):
+def _execute_replicated(compiled, handlers, *args):
   input_bufs = [
       [device_put(x, device) for x in args if x is not token]
       for device in compiled.local_devices()]
@@ -615,7 +607,7 @@ def _xla_call_translation_rule(c, axis_env,
                                in_nodes, name_stack, backend, name,
                                call_jaxpr, device=None):
   del device  # Ignored.
-  subc = xb.make_computation_builder("jit_{}".format(name))
+  subc = xb.make_computation_builder(f"jit_{name}")
   args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
   out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                             extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
@@ -947,7 +939,6 @@ def _force(x: DeviceArray) -> DeviceArray:
   return force_fun(x)
 
 @cache()
-
 def _lazy_force_computation(sticky, aval, device, lexpr) -> Callable[[DeviceArray], DeviceArray]:
   c = xb.make_computation_builder("lazy_force")
   if lazy.is_constant(lexpr):
@@ -988,8 +979,8 @@ def _device_put_impl(x, device=None):
   try:
     a = abstractify(x)
   except TypeError as err:
-    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-                    .format(x, type(x))) from err
+    raise TypeError(
+        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
   handler = aval_to_result_handler(device, a)
   return handler(device_put(x, device))