Skip to content

Commit

Permalink
Cleanup: convert uses of 'import numpy as onp' in tests (jax-ml#3756)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Jul 14, 2020
1 parent 58aba9b commit 512ed18
Show file tree
Hide file tree
Showing 10 changed files with 469 additions and 469 deletions.
32 changes: 16 additions & 16 deletions tests/infeed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

from absl.testing import absltest
import jax
from jax import lax, numpy as np
from jax import lax, numpy as jnp
from jax.config import config
from jax.experimental import host_callback as hcb
from jax.lib import xla_client
import jax.test_util as jtu
import numpy as onp
import numpy as np

config.parse_flags_with_absl()
FLAGS = config.FLAGS
Expand All @@ -34,14 +34,14 @@ def testInfeed(self):
def f(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
token, shape=(jax.ShapedArray((3, 4), np.float32),))
token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
(z,), _ = lax.infeed(
token, shape=(jax.ShapedArray((3, 1, 1), np.float32),))
token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
return x + y + z

x = onp.float32(1.5)
y = onp.reshape(onp.arange(12, dtype=onp.float32), (3, 4)) # onp.random.randn(3, 4).astype(onp.float32)
z = onp.random.randn(3, 1, 1).astype(onp.float32)
x = np.float32(1.5)
y = np.reshape(np.arange(12, dtype=np.float32), (3, 4)) # np.random.randn(3, 4).astype(np.float32)
z = np.random.randn(3, 1, 1).astype(np.float32)
device = jax.local_devices()[0]
device.transfer_to_infeed((y,))
device.transfer_to_infeed((z,))
Expand All @@ -53,27 +53,27 @@ def testInfeedThenOutfeed(self):
def f(x):
token = lax.create_token(x)
y, token = lax.infeed(
token, shape=jax.ShapedArray((3, 4), np.float32))
token = lax.outfeed(token, y + onp.float32(1))
token, shape=jax.ShapedArray((3, 4), jnp.float32))
token = lax.outfeed(token, y + np.float32(1))
return lax.tie_in(token, x - 1)

x = onp.float32(7.5)
y = onp.random.randn(3, 4).astype(onp.float32)
x = np.float32(7.5)
y = np.random.randn(3, 4).astype(np.float32)
execution = threading.Thread(target=lambda: f(x))
execution.start()
device = jax.local_devices()[0]
device.transfer_to_infeed((y,))
out, = device.transfer_from_outfeed(
xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
execution.join()
self.assertAllClose(out, y + onp.float32(1))
self.assertAllClose(out, y + np.float32(1))

def testInfeedThenOutfeedInALoop(self):
hcb.stop_outfeed_receiver()
def doubler(_, token):
y, token = lax.infeed(
token, shape=jax.ShapedArray((3, 4), np.float32))
return lax.outfeed(token, y * onp.float32(2))
token, shape=jax.ShapedArray((3, 4), jnp.float32))
return lax.outfeed(token, y * np.float32(2))

@jax.jit
def f(n):
Expand All @@ -86,11 +86,11 @@ def f(n):
execution = threading.Thread(target=lambda: f(n))
execution.start()
for _ in range(n):
x = onp.random.randn(3, 4).astype(onp.float32)
x = np.random.randn(3, 4).astype(np.float32)
device.transfer_to_infeed((x,))
y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
.with_major_to_minor_layout_if_absent())
self.assertAllClose(y, x * onp.float32(2))
self.assertAllClose(y, x * np.float32(2))
execution.join()


Expand Down
148 changes: 74 additions & 74 deletions tests/lax_autodiff_test.py

Large diffs are not rendered by default.

268 changes: 134 additions & 134 deletions tests/lax_numpy_indexing_test.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy as np
import scipy.special as osp_special

from jax import api
Expand Down Expand Up @@ -188,12 +188,12 @@ def lax_fun(a):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={onp.float32: 1e-3, onp.float64: 1e-14})
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CompileAndCheck(lax_fun, args_maker)

def testIssue980(self):
x = onp.full((4,), -1e20, dtype=onp.float32)
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
x = np.full((4,), -1e20, dtype=np.float32)
self.assertAllClose(np.zeros((4,), dtype=np.float32),
lsp_special.expit(x))

def testXlogyShouldReturnZero(self):
Expand Down
Loading

0 comments on commit 512ed18

Please sign in to comment.