Skip to content

Commit

Permalink
deflake jax/lax & add to flake8 check (jax-ml#3310)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Jun 4, 2020
1 parent 9c0a58a commit b187663
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 27 deletions.
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from .lax import (
ConvDimensionNumbers,
DotDimensionNumbers,
Expand Down
17 changes: 4 additions & 13 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,32 @@


import builtins
import collections
import enum
import functools
import itertools
import operator
import string
from typing import (Any, Callable, List, NamedTuple, Optional, Sequence, Union,
Tuple, Type)
from typing import (Any, Callable, List, NamedTuple, Optional, Sequence, Union, Tuple)
import warnings

import numpy as onp

from ..util import partial, prod

from .. import core
from .. import ad_util
from .. import api
from .. import linear_util as lu
from .. import dtypes
from .. import lazy
from .. import lib
from ..config import flags
from ..core import Primitive, _canonicalize_dimension
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
AbstractToken, array_types, make_shaped_array,
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray, array_types,
raise_to_shaped, abstract_token, canonicalize_shape)
from ..interpreters import partial_eval as pe
from ..interpreters import xla
from ..interpreters import pxla
from ..interpreters import ad
from ..interpreters import batching
from ..interpreters import masking
from ..util import curry, cache, safe_zip, unzip2, prod, safe_map
from ..tree_util import build_tree, tree_unflatten, tree_map
from ..util import cache, safe_zip, partial, prod, safe_map
from ..tree_util import tree_map
from ..lib import pytree
from ..lib import xla_bridge
from ..lib import xla_client
Expand Down Expand Up @@ -1359,7 +1351,6 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
Returns:
An array containing the convolution result.
"""
pads = padtype_to_pads(lhs.shape[2:], rhs.shape[2:], window_strides, padding)
return conv_general_dilated(lhs, rhs, window_strides, padding,
precision=precision)

Expand Down
16 changes: 4 additions & 12 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import inspect
import itertools
import operator
import threading
from typing import Callable, Sequence

import numpy as onp
Expand All @@ -34,7 +33,7 @@
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.api_util import flatten_fun_nokwargs
from jax.core import get_aval, typecheck, typematch
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
Expand All @@ -46,8 +45,7 @@
from jax.util import (partial, unzip2, unzip4, safe_map, safe_zip, split_list,
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_leaves,
tree_map, tree_multimap)
treedef_children, treedef_tuple, tree_multimap)
from jax import ad_util

xops = xla_client.ops
Expand Down Expand Up @@ -951,7 +949,6 @@ def pad_jaxpr_res_avals(i, jaxpr):
return tuple(pad_jaxpr_res_avals(i, jaxpr) for i, jaxpr in enumerate(jaxprs))

def _transpose_cond_jaxpr(jaxpr, num_res):
num_non_res = len(jaxpr.in_avals) - num_res
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)

Expand Down Expand Up @@ -1220,7 +1217,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
nonzeros = const_nz + carry_nz + xs_nz
jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
carry_nz_out, ys_nz = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
if carry_nz_out == carry_nz:
break
else:
Expand Down Expand Up @@ -1268,7 +1265,6 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
"num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
return trace.default_process_primitive(scan_p, tracers, params)

num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry

unknowns = [t.pval[0] is not None for t in tracers]
Expand All @@ -1285,7 +1281,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys,
trace_type=trace.master.trace_type)
carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
carry_uk_out = out_uk[:num_carry]
if carry_uk_out == carry_uk:
break
else:
Expand Down Expand Up @@ -1523,10 +1519,8 @@ def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
assert len(linear) == len(args)
consts, init, xs = split_list(args, [num_consts, num_carry])
consts_avals, init_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
assert all(_map(typecheck, consts_avals, consts)), (consts, consts_avals)
assert all(_map(typecheck, init_avals, init))
# assert all(_map(typecheck, xs_avals, xs))
carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
assert all(_map(typematch, init_avals, carry_avals))
core.check_jaxpr(jaxpr.jaxpr)
Expand Down Expand Up @@ -1776,8 +1770,6 @@ def _check_shapes(func_name, expected_name, actual, expected, tree):
actual_shapes = _map(onp.shape, actual)
expected_shapes = _map(onp.shape, expected)
if actual_shapes != expected_shapes:
actual_shape_tree = tree_unflatten(tree, actual_shapes)
act_shape_tree = tree_unflatten(tree, actual_shapes)
raise ValueError('{}() output shapes must match {}, got {} and {}'
.format(func_name, expected_name,
tree_unflatten(tree, actual_shapes),
Expand Down
1 change: 0 additions & 1 deletion jax/lax/lax_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,3 @@ def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
xla.translations[fft_p] = fft_translation_rule
ad.deflinear(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule

3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ max-complexity = 18
select = B,C,F,W,T4,B9
filename =
./tests/*.py
./jax/numpy/*.py
./jax/lax/*.py
./jax/numpy/*.py

0 comments on commit b187663

Please sign in to comment.