Skip to content

Commit

Permalink
Disable some tests with jax.Array that are failing in OSS due to usin…
Browse files Browse the repository at this point in the history
…g minimum_jaxlib_version. I will bump the version again this week.

PiperOrigin-RevId: 488708528
  • Loading branch information
yashk2810 authored and jax authors committed Nov 15, 2022
1 parent d742e6a commit eca1241
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,7 @@ def f(k):
python_should_be_executing = False
self.assertEqual(x, f(x))

@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def test_hitting_cpp_path(self):
if not self.use_cpp_jit:
raise unittest.SkipTest("this test only applies to _cpp_jit")
Expand Down Expand Up @@ -1197,6 +1198,7 @@ def test_caches_dont_depend_on_unnamed_axis_env(self):
self.assertEqual(count[0], 0) # no compiles
self.assertArraysAllClose(ans, expected, check_dtypes=True)

@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def test_cache_key_defaults(self):
# https://github.com/google/jax/discussions/11875
if not self.use_cpp_jit:
Expand Down
4 changes: 4 additions & 0 deletions tests/jax_jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from functools import partial
import inspect
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -23,6 +24,7 @@
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.config import config
import numpy as np

Expand Down Expand Up @@ -69,6 +71,7 @@ def test_device_put_on_numpy_arrays(self, device_put_function):
dtype=dtype))

@parameterized.parameters([jax.device_put, _cpp_device_put])
@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def test_device_put_on_buffers(self, device_put_function):
device = jax.devices()[0]
jitted_f = jax.jit(lambda x: x + 1)
Expand All @@ -83,6 +86,7 @@ def test_device_put_on_buffers(self, device_put_function):
np.testing.assert_array_equal(output_buffer, np.array(value + 1))

@parameterized.parameters([jax.device_put, _cpp_device_put])
@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def test_device_put_on_sharded_device_array(self, device_put_function):
device = jax.devices()[0]

Expand Down
4 changes: 4 additions & 0 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from functools import partial
import itertools
import operator
import unittest
from unittest import SkipTest

from absl.testing import absltest
Expand All @@ -32,6 +33,7 @@

from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -490,6 +492,7 @@ def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
name=[rec.name for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2],
othertype=[dict, list, tuple, set],
)
@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def testOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
Expand All @@ -509,6 +512,7 @@ def testOperatorOverloadErrors(self, name, othertype):
name=[rec.name for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2],
othertype=[dict, list, tuple, set],
)
@unittest.skipIf(xla_extension_version < 99, "C++ jax.Array is not available")
def testRightOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
Expand Down

0 comments on commit eca1241

Please sign in to comment.