Skip to content

Commit

Permalink
Check for jax.Sharding's number of devices instead of `py_array.num…
Browse files Browse the repository at this point in the history
…_shards` which looks at IFRT sharding's num_devices to check against `global_devices` and deciding whether to fall back to python shard_arg.

This is because IFRT sharding's `num_shards` method is busted. It doesn't return the global shards (in some cases) which leads to JAX program unnecessarily falling back to python.

PiperOrigin-RevId: 673067095
  • Loading branch information
yashk2810 authored and jax authors committed Sep 10, 2024
1 parent 02ab741 commit 90892f5
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import jax
from jax import lax
from jax._src import api
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
Expand Down Expand Up @@ -383,6 +384,25 @@ def mlir_lower_and_count(*args, **kwargs):
mlir.lower_jaxpr_to_module = mlir_lower


@contextmanager
def count_jax_array_shard_arg_calls():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.

array_shard_arg = array._array_shard_arg
count = [0]

def array_shard_arg_and_count(*args, **kwargs):
count[0] += 1
return array_shard_arg(*args, **kwargs)

pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count
try:
yield count
finally:
pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg


@contextmanager
def count_jit_compilation_cache_miss():
# No need to clear any caches since we generally jit and pmap fresh callables
Expand Down Expand Up @@ -1965,7 +1985,7 @@ def arcsin(self, x):
# On branch cut, mpmath.mp.asin returns different value compared
# to mpmath.fp.asin and numpy.arcsin (see
# mpmath/mpmath#786). The following if-block ensures
# compatibiliy with numpy.arcsin.
# compatibility with numpy.arcsin.
if x.real > 1 and x.imag == 0:
return ctx.asin(x).conjugate()

Expand Down Expand Up @@ -1997,7 +2017,7 @@ def arccos(self, x):
return ctx.make_mpc((real._mpf_, (-sign_imag * inf)._mpf_))
# On branch cut, mpmath.mp.acos returns different value
# compared to mpmath.fp.acos and numpy.arccos. The
# following if-block ensures compatibiliy with
# following if-block ensures compatibility with
# numpy.arccos.
if x.imag == 0 and x.real > 1:
return -ctx.acos(x)
Expand Down Expand Up @@ -2026,7 +2046,7 @@ def arcsinh(self, x):
# On branch cut, mpmath.mp.asinh returns different value
# compared to mpmath.fp.asinh and numpy.arcsinh (see
# mpmath/mpmath#786). The following if-block ensures
# compatibiliy with numpy.arcsinh.
# compatibility with numpy.arcsinh.
if x.real == 0 and x.imag < -1:
return (-ctx.asinh(x)).conjugate()
return ctx.asinh(x)
Expand Down

0 comments on commit 90892f5

Please sign in to comment.