Skip to content

Commit

Permalink
[JAX] Prune unused inputs in jit.
Browse files Browse the repository at this point in the history
- Python part based on: jax-ml#6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
  • Loading branch information
zhangqiaorjc authored and jax authors committed May 3, 2021
1 parent e6bdcbb commit 850bd66
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 14 deletions.
26 changes: 25 additions & 1 deletion benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,37 @@ def jit_simple_many_args(n, state):
while state:
f(args).block_until_ready()

def jit_simple_pruned_args_dispatch(n, state):
  """Benchmarks repeated calls to a jitted function that reads one of n inputs.

  Args:
    n: how many device-resident scalars are passed on every call; the jitted
      function only reads the first one.
    state: benchmark state object; the measured loop runs while it is truthy.
  """
  device_inputs = [jax.device_put(value) for value in range(n)]
  first_plus_one = jax.jit(lambda *operands: operands[0] + 1)

  # One call outside the measured loop, waited on before measuring starts.
  result = first_plus_one(*device_inputs)
  result.block_until_ready()

  # NOTE(review): the loop body only issues the call; the wait happens once
  # after the loop. This is what distinguishes the "dispatch" variant from
  # `jit_simple_pruned_args` -- confirm against the upstream file, since the
  # extracted text lost its indentation.
  while state:
    result = first_plus_one(*device_inputs)
  result.block_until_ready()


def jit_simple_pruned_args(n, state):
  """Benchmarks call-and-wait of a jitted function that reads one of n inputs.

  Args:
    n: how many device-resident scalars are passed on every call; the jitted
      function only reads the first one.
    state: benchmark state object; the measured loop runs while it is truthy.
  """
  device_inputs = [jax.device_put(value) for value in range(n)]
  first_plus_one = jax.jit(lambda *operands: operands[0] + 1)

  # One call outside the measured loop, waited on before measuring starts.
  warmup = first_plus_one(*device_inputs)
  warmup.block_until_ready()

  # Each iteration issues the call and waits for its result.
  while state:
    first_plus_one(*device_inputs).block_until_ready()

benchmarks = []
for n in [10, 100, 1000, 2000]:
benchmarks += [
google_benchmark.register(partial(jit_simple_many_args_dispatch, n),
name=f"jit_simple_many_args_dispatch_{n}"),
google_benchmark.register(partial(jit_simple_many_args, n),
name=f"jit_simple_many_args_{n}")
name=f"jit_simple_many_args_{n}"),
google_benchmark.register(partial(jit_simple_pruned_args_dispatch, n),
name=f"jit_simple_pruned_args_dispatch_{n}"),
google_benchmark.register(partial(jit_simple_pruned_args, n),
name=f"jit_simple_pruned_args_{n}")
]


Expand Down
18 changes: 16 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ class _BackendAndDeviceInfo(NamedTuple):
default_device: xc.Device
committed_to_device: bool

class _FastpathData(NamedTuple):
  """Information handed back to the caller when the jit fast path is usable.

  Field order matters: instances are also plain tuples, so positional
  consumers see the fields in the order declared here.
  """
  # Executable taken from the compiled `execute` partial.
  xla_executable: xla.XlaExecutable
  # Tree definition associated with the outputs (presumably used to
  # unflatten them -- confirm against the consumer).
  out_pytree_def: Any
  # Device reported by the result handlers; may be None.
  sticky_device: xc.Device
  # One abstract value per flat output.
  avals: Iterable[Any]
  # One entry per result handler; the visible producer fills it with None.
  lazy_exprs: Iterable[Any]
  # One bool per flat input argument: True when that argument index was kept
  # (not pruned) in the compiled computation.
  kept_var_bitvec: Iterable[bool]

if lib._xla_extension_version >= 16:
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
Expand Down Expand Up @@ -442,15 +449,22 @@ def cache_miss(*args, **kwargs):
all(xla.type_is_device_array(x) for x in out_flat))
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
xla_executable, _, result_handlers = execute.args
xla_executable, _, result_handlers, kept_var_idx = execute.args
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
for result_handler in result_handlers:
aval, sticky_device = result_handler.args
avals.append(aval)
assert len(avals) == len(out_flat)
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
if xla._ALLOW_ARG_PRUNING:
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
fastpath_data = _FastpathData(xla_executable, out_pytree_def,
sticky_device, avals, lazy_exprs,
kept_var_bitvec)
else:
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals,
lazy_exprs)
else:
fastpath_data = None

Expand Down
77 changes: 66 additions & 11 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import _xla_extension_version
from . import partial_eval as pe
from . import ad
from . import masking
Expand Down Expand Up @@ -647,10 +648,17 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))

abstract_args, arg_devices = unzip2(arg_specs)
abstract_args, _ = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
abstract_args, arg_devices = unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)

Expand All @@ -663,7 +671,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
return partial(_execute_trivial, jaxpr, device, consts, out_avals,
result_handlers, kept_var_idx)

if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
Expand Down Expand Up @@ -714,9 +723,12 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
options.parameter_is_tupled_arguments = tuple_args
compiled = backend_compile(backend, built, options)
if nreps == 1:
return partial(_execute_compiled, compiled, out_avals, result_handlers)
return partial(_execute_compiled, compiled, out_avals, result_handlers,
kept_var_idx)
else:
return partial(_execute_replicated, compiled, out_avals, result_handlers)
return partial(_execute_replicated, compiled, out_avals, result_handlers,
kept_var_idx)


def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
"""Configures input/output "must" aliasing based on `donated_args`."""
Expand Down Expand Up @@ -746,6 +758,33 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):

return tuple(out_donated_args)


# Pruning unused JIT arguments requires jaxlib 0.1.66 or newer.
# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum.
_ALLOW_ARG_PRUNING = _xla_extension_version >= 18


def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  """Drops constvars and invars that nothing in `jaxpr` refers to.

  A variable counts as used when it is one of the jaxpr's outputs or an input
  to any of its equations. Equations themselves are never removed.

  Args:
    jaxpr: the jaxpr whose inputs should be pruned.

  Returns:
    A triple `(new_jaxpr, kept_const_idx, kept_var_idx)`: the jaxpr with only
    the used constvars/invars, and the sets of positions (into the original
    `constvars` and `invars`) that were kept. When `_ALLOW_ARG_PRUNING` is
    false the input jaxpr is returned unchanged with every position kept.
  """
  if not _ALLOW_ARG_PRUNING:
    every_const = set(range(len(jaxpr.constvars)))
    every_var = set(range(len(jaxpr.invars)))
    return jaxpr, every_const, every_var

  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
  # applications that do not produce used outputs. Must handle side-effecting
  # primitives and nested jaxpr.
  live = set()
  for outvar in jaxpr.outvars:
    if isinstance(outvar, core.Var):
      live.add(outvar)
  for eqn in jaxpr.eqns:
    for operand in eqn.invars:
      if isinstance(operand, core.Var):
        live.add(operand)

  def keep_live(variables):
    # Splits the surviving (position, variable) pairs into positions and vars.
    positions, survivors = unzip2(
        (pos, var) for pos, var in enumerate(variables) if var in live)
    return set(positions), survivors

  kept_const_idx, new_constvars = keep_live(jaxpr.constvars)
  kept_var_idx, new_invars = keep_live(jaxpr.invars)
  new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return new_jaxpr, kept_const_idx, kept_var_idx


def _xla_callable_device(nreps, backend, device, arg_devices):
if nreps > 1:
if device is not None or backend is not None:
Expand Down Expand Up @@ -823,27 +862,43 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_prot
else:
return with_sharding(builder, partitions, make_param)

def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):

def _execute_compiled(compiled: XlaExecutable, avals, handlers, kept_var_idx,
*args):
device, = compiled.local_devices()
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
input_bufs = list(
it.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not token and i in kept_var_idx))
out_bufs = compiled.execute(input_bufs)
check_special(xla_call_p.name, out_bufs)
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):

def _execute_replicated(compiled: XlaExecutable, avals, handlers, kept_var_idx,
*args):
input_bufs = [
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
for device in compiled.local_devices()]
list(
it.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not token and i in kept_var_idx))
for device in compiled.local_devices()
]
out_bufs = [
buf[0] for buf in compiled.execute_sharded_on_local_devices(
list(zip(*input_bufs)))
]
check_special(xla_call_p.name, out_bufs)
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):

def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
kept_var_idx, *args):
env = {core.unitvar: core.unit}
map(env.setdefault, jaxpr.invars, args)
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
for v in jaxpr.outvars]
Expand Down
17 changes: 17 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,23 @@ def f(x, y):
x_is_tracer, y_is_tracer = False, True
assert f_mixed(x='foo', y=3) == 1

# TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
# applications.
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_num_args={num_args}", "num_args": num_args}
    for num_args in [2, 3, 4]))
def test_jit_with_pruned_args(self, num_args):
  """A jitted function reading one argument should trigger one device_put."""
  def second_plus_two(*operands):
    offset = np.array(2)
    return operands[1] + offset

  jitted = self.jit(second_plus_two)
  inputs = range(num_args)
  with jtu.count_device_put() as transfers:
    # inputs[1] == 1, so the expected result is 1 + 2.
    np.testing.assert_allclose(jitted(*inputs), 3)
  # Only the single argument the function reads should have been transferred.
  self.assertEqual(transfers[0], 1)


class PythonJitTest(CPPJitTest):

Expand Down
40 changes: 40 additions & 0 deletions tests/xla_interpreter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from absl.testing import absltest

from jax import test_util as jtu
from jax._src import api
from jax.interpreters import xla


class XlaInterpreterTest(jtu.JaxTestCase):
  """Unit tests for helpers in `jax.interpreters.xla`."""

  @unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
  def test_prune_jit_args(self):
    """Pruning a jaxpr that returns only its first input keeps just that input."""
    def f(*args):
      return args[0]

    closed_jaxpr = api.make_jaxpr(f)(*range(10))
    pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
        closed_jaxpr.jaxpr)
    # Use unittest assertions rather than bare `assert`: they are not stripped
    # under `python -O` and they report the mismatching values on failure.
    self.assertEqual(len(pruned_jaxpr.invars), 1)
    self.assertEqual(kept_const_idx, set())
    self.assertEqual(kept_var_idx, {0})


# When executed as a script, run this module's tests via absltest using the
# JAX test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 850bd66

Please sign in to comment.