Skip to content

Commit

Permalink
Add deprecation warning to JaxTestCase and JaxTestLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 17, 2022
1 parent e545daa commit da3aaa1
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 15 deletions.
29 changes: 18 additions & 11 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.1 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.0...main).
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``,
``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``,
``jax.scipy.signal.welch``.

* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
The suggested replacement is to use `parametrized.TestCase` directly. For tests that
rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement
is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`,
which work directly with JAX arrays ({jax-issue}`#9620`).
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default
({jax-issue}`#9562`). To recover the previous behavior, use the new
`jax.test_util.with_config` decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`,
{func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`,
{func}`jax.scipy.signal.welch`.

## jaxlib 0.3.1 (Unreleased)
* Changes
Expand Down
27 changes: 25 additions & 2 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# flake8: noqa: F401
# TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp.
from jax._src.test_util import (
JaxTestCase,
JaxTestLoader,
JaxTestCase as _PrivateJaxTestCase,
JaxTestLoader as _PrivateJaxTestLoader,
cases_from_list,
check_close,
check_eq,
Expand All @@ -31,3 +31,26 @@
xla_bridge,
_default_tolerance
)

class JaxTestCase(_PrivateJaxTestCase):
def __init__(self, *args, **kwargs):
import warnings
import textwrap
warnings.warn(textwrap.dedent("""\
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
The suggested replacement is to use parametrized.TestCase directly.
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
the suggested replacement is to use standard numpy testing utilities such
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
category=DeprecationWarning)
super().__init__(*args, **kwargs)

class JaxTestLoader(_PrivateJaxTestLoader):
def __init__(self, *args, **kwargs):
import warnings
warnings.warn(
"jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.",
category=DeprecationWarning)
super().__init__(*args, **kwargs)

del _PrivateJaxTestCase, _PrivateJaxTestLoader
2 changes: 1 addition & 1 deletion tests/mesh_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from jax import test_util
from jax.experimental import mesh_utils
from jax.experimental.maps import Mesh
from jax._src import test_util


@dataclasses.dataclass
Expand Down
2 changes: 1 addition & 1 deletion tests/svd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import functools

import jax
from jax import test_util as jtu
from jax.config import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import svd
from jax._src import test_util as jtu

from absl.testing import absltest
from absl.testing import parameterized
Expand Down

0 comments on commit da3aaa1

Please sign in to comment.