Cleanup: de-lint tests directory & add flake8 to travis (jax-ml#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
jakevdp authored Jun 3, 2020
1 parent 177e7cf commit 9ee4ef1
Showing 18 changed files with 74 additions and 109 deletions.
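
Most of the test-file changes below are two mechanical, pyflakes-driven fixes: locals that are assigned but never used (flake8 code F841) are rebound to _, and unused imports (F401) are deleted outright. A minimal before/after sketch of the pattern (illustrative names, not lines from the diff):

    import jax.numpy as jnp   # an unused "import jax" next to this would trip F401

    def smoke_test(f, x):
      # Before: "y = f(x)" trips F841 ("local variable 'y' is assigned to but never used").
      # After: binding to _ signals the result is deliberately discarded; these
      # tests only check that the call does not crash.
      _ = f(x)
      return jnp.sum(x)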
12 changes: 7 additions & 5 deletions .travis.yml
@@ -6,7 +6,7 @@ os: linux
jobs:
include:
- python: "3.6"
env: JAX_ONLY_CHECK_TYPES=true
env: JAX_ONLY_LINT_AND_TYPECHECK=true
- python: "3.6"
env: JAX_ENABLE_X64=0 JAX_NUM_GENERATED_CASES=25
- python: "3.6"
@@ -30,8 +30,8 @@ before_install:
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip absl-py opt_einsum numpy scipy pytest-xdist pytest-benchmark mypy=0.770
- pip install msgpack
- if [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
pip install pytype ;
- if [ "$JAX_ONLY_LINT_AND_TYPECHECK" = true ]; then
pip install pytype flake8;
fi
# The jaxlib version should match the minimum jaxlib version in
# jax/lib/__init__.py. This tests JAX PRs against the oldest permitted
@@ -53,9 +53,11 @@ script:
sphinx-build -b html -D nbsphinx_execute=always docs docs/build/html &&
pytest docs &&
pytest --doctest-modules jax/api.py ;
elif [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
elif [ "$JAX_ONLY_LINT_AND_TYPECHECK" = true ]; then
echo "===== Checking with mypy ====" &&
time mypy --config-file=mypy.ini jax ;
time mypy --config-file=mypy.ini jax &&
echo "===== Checking lint with flake8 ====" &&
time flake8 . ;
elif [ "$JAX_TO_TF" = true ]; then
pytest jax/experimental/jax_to_tf/tests ;
else
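
The lint stage now runs flake8 after mypy in the same Travis job. To reproduce that stage locally, a sketch along these lines should work (an assumption, not a script from the repository; it needs pip-installed mypy and flake8 and a checkout containing mypy.ini):

    import subprocess

    # Mirrors the JAX_ONLY_LINT_AND_TYPECHECK branch above: mypy first, then flake8.
    for cmd in (["mypy", "--config-file=mypy.ini", "jax"],
                ["flake8", "."]):
      print("=====", " ".join(cmd), "=====")
      subprocess.run(cmd, check=True)   # raises CalledProcessError if a check fails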
11 changes: 11 additions & 0 deletions setup.cfg
@@ -0,0 +1,11 @@
[flake8]
max-line-length = 88
ignore =
C901 # object names too complex
E111, E114 # four-space indents
E121 # line continuations
W503, W504 # line breaks around binary operators
max-complexity = 18
select = B,C,F,W,T4,B9
filename =
./tests/*.py
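
Relative to flake8's defaults, this config raises the line limit to 88 (presumably for flake8-bugbear's B950 check, if that plugin is present), selects only the B/C/F/W/T4/B9 families, ignores the indentation codes that two-space indents would otherwise trip, and scopes linting to tests/*.py. As a hypothetical illustration, the snippet below is clean under this config but fails a default flake8 run with E111:

    def scale(x):
      y = 2 * x   # two-space indent: E111 by default, not enforced here
      return y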
9 changes: 3 additions & 6 deletions tests/api_test.py
@@ -656,7 +656,7 @@ def f(pt):
return 0. if pt.is_zero() else jnp.sqrt(pt.x ** 2 + pt.y ** 2)

f(pt) # doesn't crash
g = api.grad(f)(pt)
_ = api.grad(f)(pt)
self.assertIsInstance(pt, ZeroPoint)

@parameterized.parameters(1, 2, 3)
@@ -966,7 +966,7 @@ def pmapped_multi_step(state):
return pmapped_multi_step(state)

u = jnp.ones((device_count, 100))
u_final = multi_step_pmap(u) # doesn't crash
_ = multi_step_pmap(u) # doesn't crash

def test_concurrent_device_get_and_put(self):
def f(x):
@@ -1455,8 +1455,6 @@ def binom_checkpoint(funs):

def test_remat_symbolic_zeros(self):
# code from https://github.com/google/jax/issues/1907
test_remat = True
test_scan = True

key = jax.random.PRNGKey(0)
key, split = jax.random.split(key)
@@ -2206,8 +2204,7 @@ def f(x):
def g(y):
return x + y
def g_jvp(primals, tangents):
(y,), (t,) = primals, tangents
return g(x), 2 * y
return g(x), 2 * primals[0]
g.defjvp(g_jvp)
return g(1.)

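
The last hunk above touches JAX's custom-derivative API: the JVP rule unpacked a tangent t that was never used, so the rule is rewritten directly in terms of primals. For context, a minimal self-contained example of the same custom_jvp/defjvp pattern (separate from the test, which deliberately closes over a tracer):

    import jax

    @jax.custom_jvp
    def g(y):
      return y ** 2

    @g.defjvp
    def g_jvp(primals, tangents):
      (y,), (t,) = primals, tangents
      return g(y), 2. * y * t   # derivative of y**2 is 2y

    print(jax.grad(g)(3.))   # 6.0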
48 changes: 14 additions & 34 deletions tests/batching_test.py
@@ -78,10 +78,6 @@ def loss(params, data):
params = [(R(m, n), R(m))
for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

input_vec = R(3)
target_vec = R(4)
datum = (input_vec, target_vec)

input_batch = R(5, 3)
target_batch = R(5, 4)
batch = (input_batch, target_batch)
@@ -651,8 +647,7 @@ def testLaxLinalgTriangularSolve(self):
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
slice_sizes),
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
"slice_sizes": slice_sizes, "rng_factory": rng_factory,
"rng_idx_factory": rng_idx_factory}
"slice_sizes": slice_sizes, "rng_factory": rng_factory}
for dtype in [np.float32, np.int32]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
@@ -670,12 +665,10 @@ def testLaxLinalgTriangularSolve(self):
start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
ans = vmap(fun, (axis, None))(operand, idxs)
@@ -688,8 +681,7 @@ def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
slice_sizes),
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
"slice_sizes": slice_sizes, "rng_factory": rng_factory,
"rng_idx_factory": rng_idx_factory}
"slice_sizes": slice_sizes, "rng_factory": rng_factory}
for dtype in [np.float32, np.float64]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
@@ -706,12 +698,10 @@ def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)),
(1, 3)), ]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
operand = rng(shape, dtype)
@@ -725,8 +715,7 @@ def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
slice_sizes),
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
"slice_sizes": slice_sizes, "rng_factory": rng_factory,
"rng_idx_factory": rng_idx_factory}
"slice_sizes": slice_sizes, "rng_factory": rng_factory}
for dtype in [np.float32, np.int32]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
@@ -741,12 +730,10 @@ def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
ans = vmap(fun, (None, axis))(operand, idxs)
@@ -759,8 +746,7 @@ def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
slice_sizes),
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
"slice_sizes": slice_sizes, "rng_factory": rng_factory,
"rng_idx_factory": rng_idx_factory}
"slice_sizes": slice_sizes, "rng_factory": rng_factory}
for dtype in [np.float32, np.float64]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
@@ -775,12 +761,10 @@ def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
operand = rng(shape, dtype)
@@ -795,7 +779,7 @@ def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
dnums, slice_sizes),
"op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
"rng_factory": rng_factory}
for dtype in [np.float32, np.int32]
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
(0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
@@ -816,12 +800,10 @@ def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
assert operand.shape[op_axis] == idxs.shape[idxs_axis]
@@ -837,7 +819,7 @@ def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
dnums, slice_sizes),
"op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
"rng_factory": rng_factory}
for dtype in [np.float32]
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
(0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
@@ -858,12 +840,10 @@ def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
slice_sizes, rng_factory):
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
operand = rng(shape, dtype)
@@ -945,10 +925,10 @@ def dist_sq(R):

@jit
def f(R):
dr = dist_sq(R)
_ = dist_sq(R)
return jnp.sum(R ** 2)

H = hessian(f)(R) # don't crash on UnshapedArray
_ = hessian(f)(R) # don't crash on UnshapedArray

def testIssue489(self):
def f(key):
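
Every gather hunk above is the same fix repeated: each test bound rng_idx = rng_idx_factory(self.rng()) without ever using it (F841), so the binding, the method parameter, and the rng_idx_factory case entry are all dropped. A reduced sketch of the parameterized-test shape involved (hypothetical names, not from the suite):

    from absl.testing import absltest, parameterized

    def make_rng():   # stand-in for jtu.rand_default(self.rng()); an assumption
      return lambda: 1.0

    class GatherStyleTest(parameterized.TestCase):
      # Before the cleanup each case dict also carried "rng_idx_factory"; once the
      # unused binding in the body went away, the parameter became dead everywhere.
      @parameterized.named_parameters(
          {"testcase_name": "_case0", "rng_factory": make_rng})
      def test_pattern(self, rng_factory):
        rng = rng_factory()
        self.assertEqual(rng(), 1.0)

    if __name__ == "__main__":
      absltest.main()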
10 changes: 5 additions & 5 deletions tests/debug_nans_test.py
@@ -34,21 +34,21 @@ def tearDown(self):

def testSingleResultPrimitiveNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
B = jnp.tanh(A)
_ = jnp.tanh(A)

def testMultipleResultPrimitiveNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
D, V = jnp.linalg.eig(A)
_, _ = jnp.linalg.eig(A)

def testJitComputationNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
B = jax.jit(jnp.tanh)(A)
_ = jax.jit(jnp.tanh)(A)

def testSingleResultPrimitiveNaN(self):
A = jnp.array(0.)
with self.assertRaises(FloatingPointError):
B = 0. / A
_ = 0. / A


if __name__ == '__main__':
absltest.main()
absltest.main()
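
The debug-NaNs fixes keep each test's shape while discarding results that only existed to exercise the machinery. A reduced sketch of the behavior under test, assuming the suite enables the flag roughly as below (the exact config call is an assumption):

    from jax.config import config
    config.update("jax_debug_nans", True)

    import jax.numpy as jnp

    A = jnp.array(0.)
    try:
      _ = 0. / A   # 0/0 yields NaN, which debug-nans promotes to an error
    except FloatingPointError as err:
      print("caught:", err)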
1 change: 0 additions & 1 deletion tests/fft_test.py
@@ -20,7 +20,6 @@
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
14 changes: 7 additions & 7 deletions tests/host_callback_test.py
@@ -215,7 +215,7 @@ def func(x):

with self.assertRaises(hcb.TapFunctionException):
with hcb.outfeed_receiver():
res = func(0)
_ = func(0)

# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
@@ -563,7 +563,7 @@ def test_jit_types(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
a_new_test="************",
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = jit_fun1(args)
_ = jit_fun1(args)
# self.assertAllClose(args, res)

def test_jit_large(self):
@@ -656,7 +656,7 @@ def func(x):
return x3

with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = api.jit(func)(0)
_ = api.jit(func)(0)

assert False # It seems that the previous jit blocks above

@@ -885,7 +885,7 @@ def test_vmap(self):
g = integer_pow[ y=2 ] f
in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
with hcb.outfeed_receiver():
res_vmap = vmap_fun1(vargs)
_ = vmap_fun1(vargs)
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
@@ -910,7 +910,7 @@ def func(y):
d = add c 3.00
in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
with hcb.outfeed_receiver():
res_vmap = vmap_func(vargs)
_ = vmap_func(vargs)
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
Expand Down Expand Up @@ -940,7 +940,7 @@ def sum_all(xv, yv):
transforms=(('batch', (0,)), ('batch', (0,))) ] e
in (f,) }""", str(api.make_jaxpr(sum_all)(xv, yv)))
with hcb.outfeed_receiver():
res_vmap = sum_all(xv, yv)
_ = sum_all(xv, yv)
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
[[0 1 2 3 4]
Expand Down Expand Up @@ -1044,7 +1044,7 @@ def padded_sum(x):
h = reduce_sum[ axes=(0,) ] g
in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

res = padded_sum(*args)
_ = padded_sum(*args)
self.assertMultiLineStrippedEqual("""
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
5 changes: 2 additions & 3 deletions tests/lax_control_flow_test.py
@@ -1164,7 +1164,6 @@ def harmonic_bond(conf, params):

def minimize_structure(test_params):
energy_fn = partial(harmonic_bond, params=test_params)
grad_fn = api.grad(energy_fn)

def apply_carry(carry, _):
i, x = carry
@@ -1950,7 +1949,7 @@ def test_while_loop_of_pmap(self):
def body(i, x):
result = api.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x)
return result + x
f_loop = lambda x: lax.fori_loop(0, 3, body, x)
f_loop = lambda x: lax.fori_loop(0, 3, body, x) # noqa: F821
ans = f_loop(jnp.ones(api.device_count()))
del body, f_loop

@@ -1999,7 +1998,7 @@ def fn(t):
fn = api.vmap(fn)

with api.disable_jit():
outputs = fn(jnp.array([1])) # doesn't crash
_ = fn(jnp.array([1])) # doesn't crash

def test_disable_jit_while_loop_with_vmap(self):
# https://github.com/google/jax/issues/2823
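
One fix goes the other way: instead of rewriting the lambda, the test keeps it and silences a single flake8 check inline, apparently because the later del body, f_loop makes pyflakes treat the lambda's reference to body as an undefined name (F821). A generic sketch of that escape hatch (error code chosen for illustration):

    def compute():
      return 42

    def check_only():
      # "# noqa: <code>" silences exactly one flake8 code on one line;
      # a bare "# noqa" would silence every check on that line.
      result = compute()   # noqa: F841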
(diff for the remaining 10 changed files not shown)
