Remove `experimental_cpp_jit` since that flag is unused and also remove `experimental_cpp_pjit`.

For dynamic-shapes experimentation and normal debugging, `python_pjit` still exists, so that problem does not arise, which leaves us free to remove these two flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
yashk2810 authored and jax authors committed Apr 7, 2023
1 parent b15ebb1 commit 694e43a
Showing 5 changed files with 30 additions and 66 deletions.
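With these flags gone, `jax.jit` and `pjit` always dispatch through the C++ fast path. A minimal sketch of the debugging fallback that remains (my illustration, not part of the commit; `jax.disable_jit` is jax's standard escape hatch for op-by-op debugging):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      return jnp.sin(x) + 1.0

    x = jnp.arange(4.0)
    f(x)  # dispatched via the C++ fast path; there is no flag to toggle anymore

    # Instead of the removed JAX_CPP_JIT / JAX_CPP_PJIT environment variables,
    # fall back to eager, op-by-op execution for debugging:
    with jax.disable_jit():
      f(x)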
CHANGELOG.md (4 additions, 0 deletions)
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.9

+* Changes
+  * The flags `experimental_cpp_jit` and `experimental_cpp_pjit` have been
+    removed; the C++ fast paths they controlled are now always on.
+
* Deprecations
  * `jax.experimental.gda_serialization` is deprecated and has been renamed to
    `jax.experimental.array_serialization`.
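The deprecation note above is an import rename; the migration is a one-line change (a sketch, assuming the module contents are otherwise unchanged):

    # deprecated spelling:
    #   from jax.experimental import gda_serialization
    # new spelling:
    from jax.experimental import array_serialization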
benchmarks/api_benchmark.py (24 additions, 48 deletions)
@@ -20,7 +20,6 @@
import google_benchmark
import jax
from jax import lax
-from jax._src import config as jax_config
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify  # technically not an api fn
from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
@@ -693,9 +692,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):

  x = [x for _ in range(num_args)]

-  prev_state = jax_config.FLAGS.experimental_cpp_pjit
-  jax_config.FLAGS.experimental_cpp_pjit = cpp_jit
-
  in_axis_resources = jax.sharding.NamedSharding(mesh, spec)
  out_axis_resources = jax.sharding.NamedSharding(mesh, spec)

@@ -713,54 +709,40 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
  while state:
    x = f(x)

-  jax_config.FLAGS.experimental_cpp_pjit = prev_state
-

@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_simple_1_device(state):
  pjit_simple_benchmark(
      state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))

@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_simple_4_device(state):
  pjit_simple_benchmark(
      state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))

@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_simple_4000_device(state):
  pjit_simple_benchmark(
      state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))


@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_aot_1_device(state):
  pjit_simple_benchmark(
      state,
@@ -771,13 +753,10 @@ def pjit_aot_1_device(state):


@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_aot_4_device(state):
  pjit_simple_benchmark(
      state,
@@ -788,13 +767,10 @@ def pjit_aot_4_device(state):


@google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
def pjit_aot_4000_device(state):
  pjit_simple_benchmark(
      state,
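For readers unfamiliar with the harness used in this file: the `google_benchmark` Python bindings time the body of the `while state:` loop, and `option.arg_names`/`option.args` parameterize each registered benchmark, read back via `state.range(i)`. A standalone sketch of the same pattern (hypothetical file; `jit_identity` is an invented example, not code from the commit):

    import google_benchmark
    import jax
    import jax.numpy as jnp

    @google_benchmark.register
    @google_benchmark.option.arg_names(['num_args'])
    @google_benchmark.option.args([1])
    @google_benchmark.option.args([10])
    def jit_identity(state):
      f = jax.jit(lambda xs: xs)
      x = [jnp.array(0.0) for _ in range(state.range(0))]
      x = jax.block_until_ready(f(x))  # warm up: compile outside the timed loop
      while state:                     # each iteration of this loop is timed
        x = f(x)
      jax.block_until_ready(x)

    if __name__ == '__main__':
      google_benchmark.main()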
jax/_src/api_util.py (0 additions, 10 deletions)
@@ -40,21 +40,11 @@

FLAGS = flags.FLAGS

-flags.DEFINE_bool(
-    "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
-    "A flag enabling the C++ jax.jit fast path."
-    "Set this to `False` only if it crashes otherwise and report "
-    "the error to the jax-team.")
flags.DEFINE_bool(
    "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
    "A flag enabling the C++ jax.pmap fast path. Until the default "
    "is switched to True, the feature is not supported and possibly broken "
    "(e.g. it may use unreleased code from jaxlib.")
-flags.DEFINE_bool(
-    "experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", True),
-    "A flag enabling the C++ pjit fast path. Until the default "
-    "is switched to True, the feature is not supported and possibly broken "
-    "(e.g. it may use unreleased code from jaxlib.")

map = safe_map

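The deleted definitions follow absl's flag-with-environment-default pattern: the default comes from an environment variable, so users could flip the fast path without touching code. A self-contained sketch of that pattern (my reimplementation of `bool_env` for illustration; jax's own helper lives in `jax._src.config`, and the flag name here is a stand-in to avoid clashing with the real one):

    import os
    from absl import flags

    def bool_env(varname: str, default: bool) -> bool:
      # Read an environment variable and interpret it as a boolean.
      val = os.getenv(varname, str(default)).lower()
      return val in ('y', 'yes', 't', 'true', 'on', '1')

    flags.DEFINE_bool(
        'demo_cpp_pmap', bool_env('JAX_CPP_PMAP', True),
        'A flag enabling the C++ jax.pmap fast path.')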
jax/_src/interpreters/pxla.py (0 additions, 6 deletions)
@@ -51,7 +51,6 @@
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
-from jax._src.config import flags
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -99,8 +98,6 @@ class WeakRefList(list):
ShardingSpec = sharding_specs.ShardingSpec


-
-
### util

def identity(x): return x
@@ -2811,9 +2808,6 @@ def create_cpp_call(self, no_kwargs, in_tree, out_tree):
        not self.unsafe_call.has_host_callbacks):
      return None

-    if not flags.FLAGS.experimental_cpp_pjit:
-      return None
-
    def aot_cache_miss(*args, **kwargs):
      params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
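For context on the hunk above: `create_cpp_call` backs jax's ahead-of-time path, which is reachable through the public lower/compile API. A sketch of standard AOT usage (not code from the commit):

    import jax
    import jax.numpy as jnp

    def f(x):
      return x @ x.T

    x = jnp.ones((4, 4))
    compiled = jax.jit(f).lower(x).compile()  # a stages.Compiled object
    y = compiled(x)  # with this change, always eligible for the C++ fast path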
jax/_src/pjit.py (2 additions, 2 deletions)
@@ -37,7 +37,7 @@
from jax._src.api_util import (
    argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
    donation_vector, shaped_abstractify, check_callable, resolve_argnums,
-    argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS)
+    argnames_partial_except, debug_info, result_paths, jaxpr_debug_info)
from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
@@ -354,7 +354,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
                      donate_argnums, abstracted_axes,
                      pjit_has_explicit_sharding):
-  if FLAGS.experimental_cpp_pjit and abstracted_axes is None:
+  if abstracted_axes is None:
    wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
                        donate_argnums, pjit_has_explicit_sharding)
  else:
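The net effect of the hunk above, as a hedged sketch (the stub functions stand in for jax internals; only the branch condition mirrors the real code):

    def _cpp_pjit_stub(fun):     # stand-in for the C++ fast-path wrapper
      return fun

    def _python_pjit_stub(fun):  # stand-in for the pure-Python dispatch wrapper
      return fun

    def post_infer_params_sketch(fun, abstracted_axes=None):
      # The FLAGS.experimental_cpp_pjit check is gone: the C++ path is taken
      # whenever abstracted_axes is None, i.e. for all ordinary jit/pjit use.
      # abstracted_axes (dynamic-shapes experimentation) still routes through
      # the Python path.
      if abstracted_axes is None:
        return _cpp_pjit_stub(fun)
      return _python_pjit_stub(fun)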
