Make eager pmap take advantage of pmap cache
The current strategy of creating a fresh `partial(primitive.bind, **params)` for each primitive
defeats the pmap cache: every `partial` object hashes and compares as a distinct function, so every
primitive triggers a new compilation. Replacing it with a `HashableFunction` keyed on the primitive
and its params should fix this.

Also, pmap_test is now 2x faster!

PiperOrigin-RevId: 490749153
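
To see why the cache missed, here is a minimal sketch (not the code from this change; `SimpleHashableFunction` is an illustrative stand-in for `jax._src.util.HashableFunction`, assumed to hash and compare by an explicit key): fresh `partial` objects never compare equal, so a cache keyed on the callable gets a new entry every time, while key-based wrappers collide as intended.

```python
from functools import lru_cache, partial

def bind(*args, **params):            # stand-in for primitive.bind
  return args, params

# Two partials built from the same bind/params are still distinct objects:
p1 = partial(bind, axis_name="i")
p2 = partial(bind, axis_name="i")
print(p1 == p2)                       # False -> a cache keyed on the callable misses
print(hash(p1) == hash(p2))           # False in practice (identity-based hashes)

class SimpleHashableFunction:
  """Callable whose cache identity is an explicit key, not object identity."""
  def __init__(self, fn, key):
    self.fn, self.key = fn, key
  def __call__(self, *args, **kwargs):
    return self.fn(*args, **kwargs)
  def __eq__(self, other):
    return isinstance(other, SimpleHashableFunction) and self.key == other.key
  def __hash__(self):
    return hash(self.key)

@lru_cache()
def fake_compile(f):                  # stand-in for the pmap/_multi_pmap caches
  return object()                     # pretend this is an expensive compilation

params = {"axis_name": "i"}
f1 = SimpleHashableFunction(lambda *a: bind(*a, **params), ("bind", (("axis_name", "i"),)))
f2 = SimpleHashableFunction(lambda *a: bind(*a, **params), ("bind", (("axis_name", "i"),)))
print(fake_compile(f1) is fake_compile(f2))   # True: one "compilation", reused
```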
apaszke authored and jax authors committed Nov 24, 2022
1 parent 8788a94 commit a711166
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions jax/interpreters/pxla.py
@@ -32,15 +32,14 @@

import enum
from contextlib import contextmanager, ContextDecorator
from collections import defaultdict, OrderedDict
from collections import defaultdict, OrderedDict, namedtuple
import dataclasses
from functools import partial, lru_cache
import itertools as it
import logging
import operator as op
import sys
import threading
import types
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
@@ -82,7 +81,7 @@
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
unzip2)
unzip2, HashableFunction)

if TYPE_CHECKING:
from jax._src.sharding import NamedSharding, XLACompatibleSharding
@@ -1047,15 +1046,22 @@ def _emap_impl(fun: lu.WrappedFun, *args,
new_outvals.append(out)
return new_outvals

def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
# the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
return [None if i is None else
i - sum(j is not None and j < i for j in idx[:l])
for l, i in enumerate(idx)]
return tuple(None if i is None else
i - sum(j is not None and j < i for j in idx[:l])
for l, i in enumerate(idx))
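
As a standalone illustration of the schedule above (a runnable sketch mirroring `_map_schedule`, not part of the diff), the comment's example of mapping axes 0, 3, and 4 indeed yields in_axes of 0, 2, 2:

```python
from typing import Optional, Tuple

def map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
  # Each enclosing map removes one input axis, so a later axis shifts left by
  # the number of already-mapped axes that sat in front of it.
  return tuple(None if i is None else
               i - sum(j is not None and j < i for j in idx[:l])
               for l, i in enumerate(idx))

print(map_schedule((0, 3, 4)))       # (0, 2, 2), the example from the comment
print(map_schedule((1, None, 3)))    # (1, None, 2): unmapped (None) entries cause no shift
```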


# We're often creating `f`s on the fly and we try to carefully make them have
# the right __hash__ and __eq__. However, despite our attempts pmap's caching
# still ends up not working, because it has a separate cache per
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache()
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
@@ -1074,6 +1080,8 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
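
Why cache `_multi_pmap` itself (an illustrative sketch, not JAX's actual cache): pmap's compilation cache is effectively per function object, so even callables that compare equal would each get their own entry; memoizing the factory returns the very same wrapped callable for equivalent inputs, which lets that downstream per-object cache hit as well.

```python
from functools import lru_cache

per_object_cache = {}                    # toy stand-in for a per-function-object cache

def fake_pmap(fn):
  if fn not in per_object_cache:         # plain functions hash by identity
    print("compiling", fn)
    per_object_cache[fn] = lambda *args: fn(*args)
  return per_object_cache[fn]

@lru_cache()                             # analogous to the cache on _multi_pmap
def make_mapped(key):
  def mapped(*args):
    return ("mapped", key, args)
  return mapped

f1 = make_mapped(("add", (0, None)))
f2 = make_mapped(("add", (0, None)))
assert f1 is f2                          # same object, so the per-object cache can hit
fake_pmap(f1)
fake_pmap(f2)                            # no second "compiling": cache hit
```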

_FakePrimitive = namedtuple("_FakePrimitive", ["multiple_results", "bind"])

class MapTrace(core.Trace):

def __init__(self, *args, emap_info):
@@ -1089,11 +1097,12 @@ def sublift(self, tracer):
def process_primitive(self, primitive, tracers, params):
info = self.main.payload["emap_info"]
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
names = [f.name for f in core.thread_local_state.trace_state.axis_env
if f.main_trace is self.main]
all_axes = [_map_schedule(map(s.get, names)) for s in shard_axes]
f_mapped, out_shard_axes = _multi_pmap(partial(primitive.bind, **params),
info, names, all_axes)
names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env
if f.main_trace is self.main)
all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes)
f = HashableFunction(lambda *args: primitive.bind(*args, **params),
(primitive, tuple(params.items())))
f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
with core.eval_context(), jax.disable_jit(False):
outvals = f_mapped(*vals)
if primitive.multiple_results:
@@ -1102,16 +1111,20 @@ def process_primitive(self, primitive, tracers, params):

def process_call(self, call_primitive, fun, tracers, params):
if call_primitive is not xla.xla_call_p: raise NotImplementedError
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(call_primitive.bind, fun))
bind = HashableFunction(
lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
(call_primitive, fun))
fake_primitive = _FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, params)

def process_map(self, call_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if not config.jax_disable_jit:
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(call_primitive.bind, fun))
bind = HashableFunction(
lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
(call_primitive, fun))
fake_primitive = _FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, params)
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
params["in_axes"], params["out_axes_thunk"], params["axis_size"])
@@ -1131,20 +1144,26 @@ def process_map(self, call_primitive, fun, tracers, params):
return map(partial(MapTracer, self), out, outaxes)

def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(primitive.bind, fun, jvp))
bind = HashableFunction(
lambda *args, **kwargs: primitive.bind(fun, jvp, *args, **kwargs),
(primitive, fun, jvp))
fake_primitive = _FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, {})

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees):
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(primitive.bind, fun, fwd, bwd,
out_trees=out_trees))
bind = HashableFunction(
lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args,
out_trees=out_trees, **kwargs),
(primitive, fun, fwd, bwd))
fake_primitive = _FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, {})

def process_axis_index(self, frame):
fake_primitive = types.SimpleNamespace(
multiple_results=False, bind=lambda _: jax.lax.axis_index(frame.name))
bind = HashableFunction(
lambda _: jax.lax.axis_index(frame.name),
(jax.lax.axis_index, frame.name))
fake_primitive = _FakePrimitive(multiple_results=False, bind=bind)
with core.eval_context():
range = jax.lax.iota(np.int32, frame.size)
dummy_tracer = MapTracer(self, range, {frame.name: 0})