Skip to content

Commit

Permalink
Migrate pytestmark usage to new @jtu.pytest_mark_if_available decor…
Browse files Browse the repository at this point in the history
…ator.

See discussion in jax-ml#13977. Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
  • Loading branch information
skye committed Jan 12, 2023
1 parent 15ec37c commit c0577f7
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 30 deletions.
10 changes: 1 addition & 9 deletions tests/multiprocess_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@
except ImportError:
portpicker = None

try:
import pytest
except ImportError:
pytest = None

config.parse_flags_with_absl()

@unittest.skipIf(not portpicker, "Test requires portpicker")
Expand Down Expand Up @@ -230,12 +225,9 @@ def test_gpu_ompi_distributed_initialize(self):
@unittest.skipIf(
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
"Slurm environment with at least two nodes needed!")
@unittest.skipIf(not pytest, "Test requires pytest markers")
@jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest')
class SlurmMultiNodeGpuTest(jtu.JaxTestCase):

if pytest is not None:
pytestmark = pytest.mark.SlurmMultiNodeGpuTest

def sorted_devices(self):
devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
if len(devices) != 16:
Expand Down
11 changes: 6 additions & 5 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
import re
from functools import partial, lru_cache
Expand Down Expand Up @@ -57,10 +56,6 @@

prev_xla_flags = None

with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator


def setUpModule():
global prev_xla_flags
Expand Down Expand Up @@ -136,6 +131,7 @@ def check_1d_2d_mesh(f, set_mesh):


# TODO(skye): make the buffer donation utils part of JaxTestCase
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitTest(jtu.BufferDonationTestCase):

@jtu.with_mesh([('x', 1)])
Expand Down Expand Up @@ -1119,6 +1115,7 @@ def f(x, y):
self.assertArraysEqual(result0, result1)
self.assertArraysEqual(result1, result2)

@jtu.pytest_mark_if_available('multiaccelerator')
class GDAPjitTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -1511,6 +1508,7 @@ def make_keys(seeds):
out.unsafe_raw_array() # doesn't crash


@jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -1713,6 +1711,7 @@ def test_pjit_array_error(self):
f(*inputs)


@jtu.pytest_mark_if_available('multiaccelerator')
class ArrayPjitTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -3151,6 +3150,7 @@ def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):

@check_1d_2d_mesh(set_mesh=True)
Expand Down Expand Up @@ -3439,6 +3439,7 @@ def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
_ = f(x)


@jtu.pytest_mark_if_available('multiaccelerator')
class UtilTest(jtu.JaxTestCase):

def testOpShardingRoundTrip(self):
Expand Down
19 changes: 13 additions & 6 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial
import itertools as it
import gc
Expand Down Expand Up @@ -56,11 +55,6 @@

prev_xla_flags = None


with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator

compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]

def all_bdims(*shapes, pmap):
Expand Down Expand Up @@ -137,6 +131,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
input_shape, pmap_sharding, lambda idx: input_data[idx]), input_data


@jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapTest(jtu.JaxTestCase):

@property
Expand Down Expand Up @@ -2064,6 +2059,7 @@ def test_remat_of_pmap_policy(self, remat):
self.assertEqual(jaxpr_text.count(' cos '), 2)


@jtu.pytest_mark_if_available('multiaccelerator')
class CppPmapTest(PythonPmapTest):

@property
Expand Down Expand Up @@ -2107,6 +2103,7 @@ def test_cache_uses_jax_key(self):
self.assertEqual(pmaped_f._cache_size, 1)


@jtu.pytest_mark_if_available('multiaccelerator')
class VmapOfPmapTest(jtu.JaxTestCase):

# TODO(apaszke)
Expand Down Expand Up @@ -2149,6 +2146,7 @@ def args_slice(vi, pi):
self.assertAllClose(ans, expected)


@jtu.pytest_mark_if_available('multiaccelerator')
class VmapPmapCollectivesTest(jtu.JaxTestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -2346,6 +2344,7 @@ def testVsVmap(self, prim, tiled):
self.assertAllClose(vmap(f, axis_name='i')(x), pmap(f, axis_name='i')(x))


@jtu.pytest_mark_if_available('multiaccelerator')
class PmapWithDevicesTest(jtu.JaxTestCase):

def testAllDevices(self):
Expand Down Expand Up @@ -2602,6 +2601,7 @@ def h(y):
jax.grad(mk_case(vmap))(x, y))


@jtu.pytest_mark_if_available('multiaccelerator')
class ShardedDeviceArrayTest(jtu.JaxTestCase):

def testThreadsafeIndexing(self):
Expand Down Expand Up @@ -2883,6 +2883,7 @@ def _spec_str(spec):
f"{spec.mesh_mapping},)")


@jtu.pytest_mark_if_available('multiaccelerator')
class ShardArgsTest(jtu.JaxTestCase):

def numpy_array(x):
Expand Down Expand Up @@ -2949,6 +2950,7 @@ def testShardArgs(self, shape, spec, make_arg):
self.assertAllClose(np.asarray(buf), x[idx], check_dtypes=False)


@jtu.pytest_mark_if_available('multiaccelerator')
class ArrayPmapTest(jtu.JaxTestCase):

def test_pmap_input_array_output_array(self):
Expand Down Expand Up @@ -3136,6 +3138,7 @@ def tearDown(self):
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()

@jtu.pytest_mark_if_available('multiaccelerator')
class EagerPythonPmapTest(EagerPmapMixin, PythonPmapTest):

def test_custom_jvp(self):
Expand Down Expand Up @@ -3171,15 +3174,19 @@ def foo_bwd(_, g):
self.assertAllClose(self.pmap(f)(x), jax.vmap(f)(x))


@jtu.pytest_mark_if_available('multiaccelerator')
class EagerCppPmapTest(EagerPmapMixin, CppPmapTest):
pass

@jtu.pytest_mark_if_available('multiaccelerator')
class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
pass

@jtu.pytest_mark_if_available('multiaccelerator')
class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
pass

@jtu.pytest_mark_if_available('multiaccelerator')
class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
pass

Expand Down
6 changes: 1 addition & 5 deletions tests/remote_transfer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for cross host device transfer."""

from absl.testing import absltest
import contextlib
import unittest
import numpy as np

Expand All @@ -25,11 +24,8 @@

config.parse_flags_with_absl()

with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator


@jtu.pytest_mark_if_available('multiaccelerator')
class RemoteTransferTest(jtu.JaxTestCase):

# TODO(jheek): this test crashes on multi-GPU.
Expand Down
19 changes: 14 additions & 5 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import itertools as it
import os
Expand Down Expand Up @@ -55,10 +54,6 @@
from jax.config import config
config.parse_flags_with_absl()

with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator


# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
# Run all tests with 8 CPU devices.
Expand Down Expand Up @@ -240,6 +235,7 @@ def divisors2(n: int) -> Iterator[Tuple[int, int]]:
yield axis_resources, mesh_data


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapTestCase(jtu.BufferDonationTestCase):
pass

Expand Down Expand Up @@ -267,6 +263,7 @@ def tearDown(self):
jtu.restore_spmd_lowering_flag()


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapTest(XMapTestCase):

def testBasic(self):
Expand Down Expand Up @@ -801,6 +798,7 @@ def testNewCheckpointNonlinearWithPolicy(self):
jax.grad(lambda x: f(x).sum())(jnp.arange(3.)) # TODO crashes!


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""

Expand Down Expand Up @@ -856,6 +854,7 @@ def testConstantsInLowering(self):
jnp.concatenate([yp, yp]))


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
@jtu.with_mesh([('x', 2)])
def testBasic(self):
Expand Down Expand Up @@ -976,6 +975,7 @@ def testConstantsInLowering(self):
jnp.concatenate([yp, yp]))


@jtu.pytest_mark_if_available('multiaccelerator')
class NamedNumPyTest(XMapTestCase):

@jtu.sample_product(
Expand All @@ -1000,6 +1000,7 @@ def testReductions(self, reduction, axes, mapped_axis):
self.assertAllClose(ref_red(x), xmap_red(x))


@jtu.pytest_mark_if_available('multiaccelerator')
class NamedRandomTest(XMapTestCase):

SAMPLERS = [
Expand Down Expand Up @@ -1038,6 +1039,7 @@ def sample(axis_resources):
self.assertAllClose(sample({}), sample(dict(axis_resources)))


@jtu.pytest_mark_if_available('multiaccelerator')
class NamedNNTest(XMapTestCase):

def testOneHot(self):
Expand Down Expand Up @@ -1090,6 +1092,7 @@ def testVarianceScaling(self, map_in, map_out, fan, distr):
atol=1e-4, rtol=2e-2)


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapGDATest(XMapTestCase):

def setUp(self):
Expand Down Expand Up @@ -1250,6 +1253,7 @@ def test_gda_from_pjit_with_xmap_sharding_mismatch(self):



@jtu.pytest_mark_if_available('multiaccelerator')
class XMapArrayTest(XMapTestCase):

def test_basic(self):
Expand Down Expand Up @@ -1327,6 +1331,7 @@ def test_xmap_array_sharding_mismatch(self):
f(input_array)


@jtu.pytest_mark_if_available('multiaccelerator')
class NewPrimitiveTest(XMapTestCase):

def testGatherPositional(self):
Expand All @@ -1351,6 +1356,7 @@ def testGather(self, mesh, axis_resources):
self.assertAllClose(f(x, y), f_ref(x, y))


@jtu.pytest_mark_if_available('multiaccelerator')
class NewPrimitiveTestSPMD(SPMDTestMixin, NewPrimitiveTest):
pass

Expand Down Expand Up @@ -1454,6 +1460,7 @@ def schedules_from_pdot_spec(
yield from schedules(logical_sizes)


@jtu.pytest_mark_if_available('multiaccelerator')
class PDotTests(XMapTestCase):

@jtu.with_mesh([('r1', 2)])
Expand Down Expand Up @@ -1764,6 +1771,7 @@ def check(spec):
check('jk{i,b}->k{b}')


@jtu.pytest_mark_if_available('multiaccelerator')
class XMapErrorTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 2)])
Expand Down Expand Up @@ -2011,6 +2019,7 @@ def testAxesMismatch(self):
xmap(lambda x: x, (p,), (p, ['x']))([x, x, x]) # Error, we raise a generic tree mismatch message


@jtu.pytest_mark_if_available('multiaccelerator')
class NamedAutodiffTests(jtu.JaxTestCase):

def testVjpReduceAxes(self):
Expand Down

0 comments on commit c0577f7

Please sign in to comment.