
Commit

clean up multibackend tests
levskaya committed Aug 24, 2019
1 parent a57a1a3 commit 91a2311
Showing 1 changed file with 25 additions and 88 deletions.
diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py
--- a/tests/multibackend_test.py
+++ b/tests/multibackend_test.py
@@ -24,6 +24,7 @@
 import numpy as onp
 import numpy.random as npr
 import six
+from unittest import SkipTest
 
 from jax import api
 from jax import test_util as jtu
@@ -42,10 +43,11 @@ class MultiBackendTest(jtu.JaxTestCase):
         {"testcase_name": "_backend={}".format(backend),
          "backend": backend,
         }
-        for backend in ['cpu', 'gpu']
+        for backend in ['cpu', 'gpu', 'tpu', None]
         ))
-  @jtu.skip_on_devices('cpu', 'tpu')
-  def testGpuMultiBackend(self, backend):
+  def testMultiBackend(self, backend):
+    if backend not in ('cpu', jtu.device_under_test(), None):
+      raise SkipTest()
     @partial(api.jit, backend=backend)
     def fun(x, y):
       return np.matmul(x, y)
@@ -54,33 +56,17 @@ def fun(x, y):
     z_host = onp.matmul(x, y)
     z = fun(x, y)
     self.assertAllClose(z, z_host, check_dtypes=True)
-    self.assertEqual(z.device_buffer.platform(), backend)
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-        {"testcase_name": "_backend={}".format(backend),
-         "backend": backend,
-        }
-        for backend in ['cpu', 'tpu']
-        ))
-  @jtu.skip_on_devices('cpu', 'gpu')
-  def testTpuMultiBackend(self, backend):
-    @partial(api.jit, backend=backend)
-    def fun(x, y):
-      return np.matmul(x, y)
-    x = npr.uniform(size=(10,10))
-    y = npr.uniform(size=(10,10))
-    z_host = onp.matmul(x, y)
-    z = fun(x, y)
-    self.assertAllClose(z, z_host, check_dtypes=True)
-    self.assertEqual(z.device_buffer.platform(), backend)
+    correct_platform = backend if backend else jtu.device_under_test()
+    self.assertEqual(z.device_buffer.platform(), correct_platform)
 
   @parameterized.named_parameters(jtu.cases_from_list(
         {"testcase_name": "_ordering={}".format(ordering),
          "ordering": ordering,}
-      for ordering in [('cpu', None), ('tpu', None)]))
-  @jtu.skip_on_devices('cpu', 'gpu')
-  def testTpuMultiBackendNestedJit(self, ordering):
+      for ordering in [('cpu', None), ('gpu', None), ('tpu', None), (None, None)]))
+  def testMultiBackendNestedJit(self, ordering):
     outer, inner = ordering
+    if outer not in ('cpu', jtu.device_under_test(), None):
+      raise SkipTest()
     @partial(api.jit, backend=outer)
     def fun(x, y):
       @partial(api.jit, backend=inner)
@@ -92,56 +78,23 @@ def infun(x, y):
     z_host = onp.matmul(x, y) + onp.ones_like(x)
     z = fun(x, y)
     self.assertAllClose(z, z_host, check_dtypes=True)
-    self.assertEqual(z.device_buffer.platform(), outer)
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-        {"testcase_name": "_ordering={}".format(ordering),
-         "ordering": ordering,}
-      for ordering in [('cpu', None), ('gpu', None)]))
-  @jtu.skip_on_devices('cpu', 'tpu')
-  def testGpuMultiBackendNestedJit(self, ordering):
-    outer, inner = ordering
-    @partial(api.jit, backend=outer)
-    def fun(x, y):
-      @partial(api.jit, backend=inner)
-      def infun(x, y):
-        return np.matmul(x, y)
-      return infun(x, y) + np.ones_like(x)
-    x = npr.uniform(size=(10,10))
-    y = npr.uniform(size=(10,10))
-    z_host = onp.matmul(x, y) + onp.ones_like(x)
-    z = fun(x, y)
-    self.assertAllClose(z, z_host, check_dtypes=True)
-    self.assertEqual(z.device_buffer.platform(), outer)
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-        {"testcase_name": "_ordering={}".format(ordering),
-         "ordering": ordering,}
-      for ordering in [
-          ('cpu', 'tpu'), ('tpu', 'cpu'), (None, 'cpu'), (None, 'tpu'),
-      ]))
-  @jtu.skip_on_devices('cpu', 'gpu')
-  def testTpuMultiBackendNestedJitConflict(self, ordering):
-    outer, inner = ordering
-    @partial(api.jit, backend=outer)
-    def fun(x, y):
-      @partial(api.jit, backend=inner)
-      def infun(x, y):
-        return np.matmul(x, y)
-      return infun(x, y) + np.ones_like(x)
-    x = npr.uniform(size=(10,10))
-    y = npr.uniform(size=(10,10))
-    self.assertRaises(ValueError, lambda: fun(x, y))
+    correct_platform = outer if outer else jtu.device_under_test()
+    self.assertEqual(z.device_buffer.platform(), correct_platform)
 
   @parameterized.named_parameters(jtu.cases_from_list(
         {"testcase_name": "_ordering={}".format(ordering),
          "ordering": ordering,}
       for ordering in [
-          ('cpu', 'gpu'), ('gpu', 'cpu'), (None, 'cpu'), (None, 'gpu'),
+          ('cpu', 'gpu'), ('gpu', 'cpu'),
+          ('cpu', 'tpu'), ('tpu', 'cpu'),
+          (None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
       ]))
-  @jtu.skip_on_devices('cpu', 'tpu')
-  def testGpuMultiBackendNestedJitConflict(self, ordering):
+  def testMultiBackendNestedJitConflict(self, ordering):
     outer, inner = ordering
+    if outer not in ('cpu', jtu.device_under_test(), None):
+      raise SkipTest()
+    if inner not in ('cpu', jtu.device_under_test(), None):
+      raise SkipTest()
     @partial(api.jit, backend=outer)
     def fun(x, y):
       @partial(api.jit, backend=inner)
@@ -155,10 +108,11 @@ def infun(x, y):
   @parameterized.named_parameters(jtu.cases_from_list(
         {"testcase_name": "_backend={}".format(backend),
          "backend": backend,}
-        for backend in ['cpu', 'gpu']
+        for backend in ['cpu', 'gpu', 'tpu']
         ))
-  @jtu.skip_on_devices('cpu', 'tpu')
   def testGpuMultiBackendOpByOpReturn(self, backend):
+    if backend not in ('cpu', jtu.device_under_test()):
+      raise SkipTest()
     @partial(api.jit, backend=backend)
     def fun(x, y):
       return np.matmul(x, y)
@@ -167,24 +121,7 @@ def fun(x, y):
     z = fun(x, y)
     w = np.sin(z)
     self.assertEqual(z.device_buffer.platform(), backend)
-    self.assertEqual(w.device_buffer.platform(), 'gpu')
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-        {"testcase_name": "_backend={}".format(backend),
-         "backend": backend,}
-        for backend in ['cpu', 'tpu']
-        ))
-  @jtu.skip_on_devices('cpu', 'gpu')
-  def testTpuMultiBackendOpByOpReturn(self, backend):
-    @partial(api.jit, backend=backend)
-    def fun(x, y):
-      return np.matmul(x, y)
-    x = npr.uniform(size=(10,10))
-    y = npr.uniform(size=(10,10))
-    z = fun(x, y)
-    w = np.sin(z)
-    self.assertEqual(z.device_buffer.platform(), backend)
-    self.assertEqual(w.device_buffer.platform(), 'tpu')
+    self.assertEqual(w.device_buffer.platform(), jtu.device_under_test())
 
 
 if __name__ == "__main__":
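For reference, here is a minimal self-contained sketch of the runtime-skip pattern the diff converges on: one test parameterized over every backend (plus None for the default) that raises SkipTest for combinations the attached device cannot execute, rather than keeping per-backend copies behind @jtu.skip_on_devices. The class and test names below are hypothetical; the helpers (jtu.cases_from_list, jtu.device_under_test, the backend argument to api.jit) are the ones used in the diff itself.

from functools import partial
from unittest import SkipTest

import numpy as onp
import numpy.random as npr
from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import test_util as jtu
import jax.numpy as np


class BackendSkipSketch(jtu.JaxTestCase):
  # Hypothetical test class, illustrating the pattern only.

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_backend={}".format(backend), "backend": backend}
      for backend in ['cpu', 'gpu', 'tpu', None]))
  def testMatmulOnBackend(self, backend):
    # Skip at runtime instead of excluding cases with a decorator:
    # only the host ('cpu'), the device actually attached, or the
    # default (None) can execute the jitted computation.
    if backend not in ('cpu', jtu.device_under_test(), None):
      raise SkipTest()

    @partial(api.jit, backend=backend)
    def fun(x, y):
      return np.matmul(x, y)

    x = npr.uniform(size=(10, 10))
    y = npr.uniform(size=(10, 10))
    self.assertAllClose(fun(x, y), onp.matmul(x, y), check_dtypes=True)


if __name__ == "__main__":
  absltest.main()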
