Skip to content

Commit

Permalink
Add a new jax.spmd_mode config for preventing unintentional hangs a…
Browse files Browse the repository at this point in the history
…nd incorrect results when users pass `jax.Array`s that span across multiple processes (i.e. not fully addressable) to `jit` or jnp operations (that are jitted by default).

Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array, since it's a new behavior for `jit` (previously jit only worked on single device arrays).
* Over time (via docs), and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.

PiperOrigin-RevId: 485075693
  • Loading branch information
yashk2810 authored and jax authors committed Oct 31, 2022
1 parent f3ddd56 commit ca1f58e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
1 change: 1 addition & 0 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
transfer_guard_host_to_device as transfer_guard_host_to_device,
transfer_guard_device_to_device as transfer_guard_device_to_device,
transfer_guard_device_to_host as transfer_guard_device_to_host,
spmd_mode as spmd_mode,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
Expand Down
17 changes: 17 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,23 @@ def _update_jax_array_thread_local(val):
'used.'))


# Enum config state (`jax_spmd_mode`) selecting which kinds of computations
# may execute on `jax.Array`s that are not fully addressable. The accepted
# values, from least to most restrictive, are 'allow_all', 'allow_jit' and
# 'allow_pjit'; each one is described in the help text below.
spmd_mode = config.define_enum_state(
    name='jax_spmd_mode',
    enum_values=['allow_all', 'allow_jit', 'allow_pjit'],
    # TODO(yashkatariya): Default to `allow_jit` when the training wheels come
    # off.
    default='allow_pjit',
    # The option list starts on its own line so that every bullet in the
    # rendered help text begins at the start of a line.
    help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
          "(i.e. spans across multiple processes) is allowed. The options are:\n"
          "* allow_pjit: Default, only `pjit` computations are allowed to "
          " execute on non-fully addressable `jax.Array`s\n"
          "* allow_jit: `pjit` and `jax.jit` computations are allowed to "
          " execute on non-fully addressable `jax.Array`s\n"
          "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
          " `jax.jit` and all other operations are allowed to "
          " execute on non-fully addressable `jax.Array`s."))


distributed_debug = config.define_bool_state(
name='jax_distributed_debug',
default=False,
Expand Down
18 changes: 13 additions & 5 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import logging
import operator as op
import sys
import warnings
import threading
import types
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Expand Down Expand Up @@ -2810,17 +2809,26 @@ def lower_sharding_computation(
if d.process_index == d.client.process_index()]
if len(device_assignment) != len(local_device_assignment):
check_multihost_collective_allowlist(jaxpr)
# TODO(yashkatariya): Raise an error here and add a context manager.
if config.jax_array and api_name == 'jit':
warnings.warn(
# TODO(yashkatariya): Once jit and pjit's frontend is merged, use the
# argument on jit `_allow_multiprocess` (which will be added later) instead
# of the `api_name` check here.
# Furthermore, `allow_jit` is not allowed yet because `allow_jit` only
# allows explicit `jax.jit` to work but not implicitly jitted `jnp`
# operations. This restriction will be relaxed in the future when the
# default value of `spmd_mode` config changes to `allow_jit`.
if (config.jax_array and api_name == 'jit' and
config.jax_spmd_mode != 'allow_all'):
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. It’s very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs.\n"
"If you’re not already familiar with JAX’s multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html.")
"https://jax.readthedocs.io/en/latest/multi_process.html\n"
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")

has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
Expand Down

0 comments on commit ca1f58e

Please sign in to comment.