Skip to content

Commit

Permalink
Remove the duplicate _with_context_manager tests now that Mesh is…
Browse files Browse the repository at this point in the history
… the default

way to create a mesh.

PiperOrigin-RevId: 436333536
  • Loading branch information
yashk2810 authored and jax authors committed Mar 21, 2022
1 parent c9a9e56 commit 8b1cb8a
Showing 1 changed file with 0 additions and 35 deletions.
35 changes: 0 additions & 35 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,6 @@ def f(x, y):

class GDAPjitTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_single_output(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
Expand All @@ -823,40 +822,6 @@ def cb(index):
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

with jax._src.config.parallel_functions_output_gda(True):
@partial(pjit, in_axis_resources=FROM_GDA, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
expected_matrix_mul = input_data @ input_data.T

out = f(gda_obj)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

out2 = f(out)
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)

with self.assertRaisesRegex(
ValueError, ('For a non-GDA input, the corresponding resource in '
'in_axis_resources cannot be `pjit.FROM_GDA`.')):
f(input_data)

def test_pjit_gda_single_output_with_mesh_context_manager(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

with jax._src.config.parallel_functions_output_gda(True):
with global_mesh:
@partial(pjit, in_axis_resources=FROM_GDA, out_axis_resources=P('x', 'y'))
Expand Down

0 comments on commit 8b1cb8a

Please sign in to comment.