From ca1f58e37b80a0c6a318cc8a15f3dbbed554d12c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 31 Oct 2022 09:46:46 -0700 Subject: [PATCH] Add a new `jax.spmd_mode` config for preventing unintentional hangs and incorrect results when users pass `jax.Array`s that span across multiple processes (i.e. not fully addressable) to `jit` or jnp operations (that are jitted by default). Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array. Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since it's a new behavior for `jit` (previously jit only worked on single device arrays). * Over time (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`. PiperOrigin-RevId: 485075693 --- jax/__init__.py | 1 + jax/_src/config.py | 17 +++++++++++++++++ jax/interpreters/pxla.py | 18 +++++++++++++----- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 9ae5f61930a2..83ac1bacc61c 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -57,6 +57,7 @@ transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, + spmd_mode as spmd_mode, ) from .core import eval_context as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info diff --git a/jax/_src/config.py b/jax/_src/config.py index 965cdde88547..1411bc710b20 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -710,6 +710,23 @@ def _update_jax_array_thread_local(val): 'used.')) +spmd_mode = config.define_enum_state( + name='jax_spmd_mode', + enum_values=['allow_all', 'allow_jit', 
'allow_pjit'], + # TODO(yashkatariya): Default to `allow_jit` when the training wheels come + # off. + default='allow_pjit', + help=("Decides whether Math on `jax.Array`'s that are not fully addressable " + "(i.e. spans across multiple processes) is allowed. The options are: " + "* allow_pjit: Default, only `pjit` computations are allowed to " + " execute on non-fully addressable `jax.Array`s\n" + "* allow_jit: `pjit` and `jax.jit` computations are allowed to " + " execute on non-fully addressable `jax.Array`s\n" + "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " + " `jax.jit` and all other operations are allowed to " + " execute on non-fully addresable `jax.Array`s.")) + + distributed_debug = config.define_bool_state( name='jax_distributed_debug', default=False, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index d4f35c229736..85f06160ae02 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -39,7 +39,6 @@ import logging import operator as op import sys -import warnings import threading import types from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, @@ -2810,9 +2809,16 @@ def lower_sharding_computation( if d.process_index == d.client.process_index()] if len(device_assignment) != len(local_device_assignment): check_multihost_collective_allowlist(jaxpr) - # TODO(yashkatariya): Raise an error here and add a context manager. - if config.jax_array and api_name == 'jit': - warnings.warn( + # TODO(yashkatariya): Once jit and pjit's frontend is merged, use the + # argument on jit `_allow_multiprocess` (which will be added later) instead + # of the `api_name` check here. + # Furthermore, `allow_jit` is not allowed yet because `allow_jit` only + # allows explicit `jax.jit` to work but not implicitly jitted `jnp` + # operations. This restriction will be relaxed in the future when the + # default value of `spmd_mode` config changes to `allow_jit`. 
+ if (config.jax_array and api_name == 'jit' and + config.jax_spmd_mode != 'allow_all'): + raise RuntimeError( "Running operations on `Array`s that are not fully addressable by this " "process (i.e. `Array`s with data sharded across multiple devices and " "processes.) is dangerous. It’s very important that all processes run " @@ -2820,7 +2826,9 @@ def lower_sharding_computation( "can lead to hangs.\n" "If you’re not already familiar with JAX’s multi-process " "programming model, please read " - "https://jax.readthedocs.io/en/latest/multi_process.html.") + "https://jax.readthedocs.io/en/latest/multi_process.html\n" + "To fix this error, run your `jitted` computation inside " + "`with jax.spmd_mode('allow_all'):` context manager.") has_outfeed = core.jaxpr_uses_outfeed(jaxpr) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)