From c0577f70f99bc2912a2ac88fa8bee664d650f454 Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Thu, 12 Jan 2023 22:42:06 +0000
Subject: [PATCH] Migrate pytestmark usage to new `@jtu.pytest_mark_if_available` decorator.

See discussion in https://github.com/google/jax/pull/13977.

Marking entire modules is magical and verbose, plus less precise than
marking individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
---
 tests/multiprocess_gpu_test.py | 10 +---------
 tests/pjit_test.py             | 11 ++++++-----
 tests/pmap_test.py             | 19 +++++++++++++------
 tests/remote_transfer_test.py  |  6 +-----
 tests/xmap_test.py             | 19 ++++++++++++++-----
 5 files changed, 35 insertions(+), 30 deletions(-)

diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index fc62608c5b79..2a753da60224 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -41,11 +41,6 @@
 except ImportError:
   portpicker = None
 
-try:
-  import pytest
-except ImportError:
-  pytest = None
-
 config.parse_flags_with_absl()
 
 @unittest.skipIf(not portpicker, "Test requires portpicker")
@@ -230,12 +225,9 @@ def test_gpu_ompi_distributed_initialize(self):
 @unittest.skipIf(
     os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
     "Slurm environment with at least two nodes needed!")
-@unittest.skipIf(not pytest, "Test requires pytest markers")
+@jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest')
 class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
 
-  if pytest is not None:
-    pytestmark = pytest.mark.SlurmMultiNodeGpuTest
-
   def sorted_devices(self):
     devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
     if len(devices) != 16:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 2f252c3d4458..2051876a38c7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 import os
 import re
 from functools import partial, lru_cache
@@ -57,10 +56,6 @@
 
 prev_xla_flags = None
 
-with contextlib.suppress(ImportError):
-  import pytest
-  pytestmark = pytest.mark.multiaccelerator
-
 
 def setUpModule():
   global prev_xla_flags
@@ -136,6 +131,7 @@ def check_1d_2d_mesh(f, set_mesh):
 
 
 # TODO(skye): make the buffer donation utils part of JaxTestCase
+@jtu.pytest_mark_if_available('multiaccelerator')
 class PJitTest(jtu.BufferDonationTestCase):
 
   @jtu.with_mesh([('x', 1)])
@@ -1119,6 +1115,7 @@ def f(x, y):
     self.assertArraysEqual(result0, result1)
     self.assertArraysEqual(result1, result2)
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class GDAPjitTest(jtu.JaxTestCase):
 
   def setUp(self):
@@ -1511,6 +1508,7 @@ def make_keys(seeds):
     out.unsafe_raw_array()  # doesn't crash
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):
 
   def setUp(self):
@@ -1713,6 +1711,7 @@ def test_pjit_array_error(self):
       f(*inputs)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class ArrayPjitTest(jtu.JaxTestCase):
 
   def setUp(self):
@@ -3151,6 +3150,7 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
 
   @check_1d_2d_mesh(set_mesh=True)
@@ -3439,6 +3439,7 @@ def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
       _ = f(x)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class UtilTest(jtu.JaxTestCase):
 
   def testOpShardingRoundTrip(self):
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 1ffcc7912baf..9dc178665909 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -14,7 +14,6 @@
 
 from concurrent.futures import ThreadPoolExecutor
-import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -56,11 +55,6 @@
 
 prev_xla_flags = None
 
-
-with contextlib.suppress(ImportError):
-  import pytest
-  pytestmark = pytest.mark.multiaccelerator
-
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
 
 def all_bdims(*shapes, pmap):
@@ -137,6 +131,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
       input_shape, pmap_sharding, lambda idx: input_data[idx]), input_data
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class PythonPmapTest(jtu.JaxTestCase):
 
   @property
@@ -2064,6 +2059,7 @@ def test_remat_of_pmap_policy(self, remat):
     self.assertEqual(jaxpr_text.count(' cos '), 2)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class CppPmapTest(PythonPmapTest):
 
   @property
@@ -2107,6 +2103,7 @@ def test_cache_uses_jax_key(self):
     self.assertEqual(pmaped_f._cache_size, 1)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class VmapOfPmapTest(jtu.JaxTestCase):
 
   # TODO(apaszke)
@@ -2149,6 +2146,7 @@ def args_slice(vi, pi):
     self.assertAllClose(ans, expected)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class VmapPmapCollectivesTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
@@ -2346,6 +2344,7 @@ def testVsVmap(self, prim, tiled):
     self.assertAllClose(vmap(f, axis_name='i')(x), pmap(f, axis_name='i')(x))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class PmapWithDevicesTest(jtu.JaxTestCase):
 
   def testAllDevices(self):
@@ -2602,6 +2601,7 @@ def h(y):
                         jax.grad(mk_case(vmap))(x, y))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class ShardedDeviceArrayTest(jtu.JaxTestCase):
 
   def testThreadsafeIndexing(self):
@@ -2883,6 +2883,7 @@ def _spec_str(spec):
           f"{spec.mesh_mapping},)")
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class ShardArgsTest(jtu.JaxTestCase):
 
   def numpy_array(x):
@@ -2949,6 +2950,7 @@ def testShardArgs(self, shape, spec, make_arg):
      self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class ArrayPmapTest(jtu.JaxTestCase):
 
   def test_pmap_input_array_output_array(self):
@@ -3136,6 +3138,7 @@ def tearDown(self):
     config.update('jax_disable_jit', self.jit_disabled)
     super().tearDown()
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class EagerPythonPmapTest(EagerPmapMixin, PythonPmapTest):
 
   def test_custom_jvp(self):
@@ -3171,15 +3174,19 @@ def foo_bwd(_, g):
     self.assertAllClose(self.pmap(f)(x), jax.vmap(f)(x))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class EagerCppPmapTest(EagerPmapMixin, CppPmapTest):
   pass
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
   pass
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
   pass
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
   pass
 
diff --git a/tests/remote_transfer_test.py b/tests/remote_transfer_test.py
index fb667c09532f..33e34e40cd11 100644
--- a/tests/remote_transfer_test.py
+++ b/tests/remote_transfer_test.py
@@ -14,7 +14,6 @@
 """Tests for cross host device transfer."""
 
 from absl.testing import absltest
-import contextlib
 import unittest
 
 import numpy as np
@@ -25,11 +24,8 @@
 
 config.parse_flags_with_absl()
 
-with contextlib.suppress(ImportError):
-  import pytest
-  pytestmark = pytest.mark.multiaccelerator
-
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class RemoteTransferTest(jtu.JaxTestCase):
 
   # TODO(jheek): this test crashes on multi-GPU.
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 0e773e5a47a1..c14bc312e328 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 import functools
 import itertools as it
 import os
@@ -55,10 +54,6 @@
 from jax.config import config
 config.parse_flags_with_absl()
 
-with contextlib.suppress(ImportError):
-  import pytest
-  pytestmark = pytest.mark.multiaccelerator
-
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
 
 # Run all tests with 8 CPU devices.
@@ -240,6 +235,7 @@ def divisors2(n: int) -> Iterator[Tuple[int, int]]:
       yield axis_resources, mesh_data
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapTestCase(jtu.BufferDonationTestCase):
   pass
 
@@ -267,6 +263,7 @@ def tearDown(self):
     jtu.restore_spmd_lowering_flag()
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapTest(XMapTestCase):
 
   def testBasic(self):
@@ -801,6 +798,7 @@ def testNewCheckpointNonlinearWithPolicy(self):
     jax.grad(lambda x: f(x).sum())(jnp.arange(3.))  # TODO crashes!
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapTestSPMD(SPMDTestMixin, XMapTest):
   """Re-executes all basic tests with the SPMD partitioner enabled"""
 
@@ -856,6 +854,7 @@ def testConstantsInLowering(self):
                         jnp.concatenate([yp, yp]))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
   @jtu.with_mesh([('x', 2)])
   def testBasic(self):
@@ -976,6 +975,7 @@ def testConstantsInLowering(self):
                         jnp.concatenate([yp, yp]))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NamedNumPyTest(XMapTestCase):
 
   @jtu.sample_product(
@@ -1000,6 +1000,7 @@ def testReductions(self, reduction, axes, mapped_axis):
     self.assertAllClose(ref_red(x), xmap_red(x))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NamedRandomTest(XMapTestCase):
 
   SAMPLERS = [
@@ -1038,6 +1039,7 @@ def sample(axis_resources):
     self.assertAllClose(sample({}), sample(dict(axis_resources)))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NamedNNTest(XMapTestCase):
 
   def testOneHot(self):
@@ -1090,6 +1092,7 @@ def testVarianceScaling(self, map_in, map_out, fan, distr):
                     atol=1e-4, rtol=2e-2)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapGDATest(XMapTestCase):
 
   def setUp(self):
@@ -1250,6 +1253,7 @@ def test_gda_from_pjit_with_xmap_sharding_mismatch(self):
 
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapArrayTest(XMapTestCase):
 
   def test_basic(self):
@@ -1327,6 +1331,7 @@ def test_xmap_array_sharding_mismatch(self):
       f(input_array)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NewPrimitiveTest(XMapTestCase):
 
   def testGatherPositional(self):
@@ -1351,6 +1356,7 @@ def testGather(self, mesh, axis_resources):
     self.assertAllClose(f(x, y), f_ref(x, y))
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NewPrimitiveTestSPMD(SPMDTestMixin, NewPrimitiveTest):
   pass
 
@@ -1454,6 +1460,7 @@ def schedules_from_pdot_spec(
   yield from schedules(logical_sizes)
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class PDotTests(XMapTestCase):
 
   @jtu.with_mesh([('r1', 2)])
@@ -1764,6 +1771,7 @@ def check(spec):
     check('jk{i,b}->k{b}')
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class XMapErrorTest(jtu.JaxTestCase):
 
   @jtu.with_mesh([('x', 2)])
@@ -2011,6 +2019,7 @@ def testAxesMismatch(self):
      xmap(lambda x: x, (p,), (p, ['x']))([x, x, x])  # Error, we raise a generic tree mismatch message
 
 
+@jtu.pytest_mark_if_available('multiaccelerator')
 class NamedAutodiffTests(jtu.JaxTestCase):
 
   def testVjpReduceAxes(self):
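
Note for readers unfamiliar with the helper: `jtu.pytest_mark_if_available` can be
thought of along the lines of the sketch below. This is an illustrative
approximation only, not the actual implementation in JAX's test_util, which may
differ. The idea is that the decorator applies the named pytest marker when pytest
is importable and is a no-op otherwise, so the marked classes still run under plain
absltest/unittest:

    # Minimal sketch, assuming only the behavior described in the commit message;
    # not the actual jax test_util implementation.
    def pytest_mark_if_available(marker_name: str):
      try:
        import pytest  # pytest is an optional dependency of the test suite
      except ImportError:
        return lambda obj: obj  # no-op decorator when pytest is absent
      # pytest.mark.<marker_name> is itself a class/function decorator.
      return getattr(pytest.mark, marker_name)

Usage then mirrors the diff above, e.g. `@pytest_mark_if_available('multiaccelerator')`
placed directly on a test class, which replaces the module-level `pytestmark`
assignment guarded by an ImportError check.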