Skip to content

Commit

Permalink
deflake jax.numpy and add to flake8 check (jax-ml#3312)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Jun 3, 2020
1 parent d1dbf7c commit c77c083
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 12 deletions.
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from . import fft
from . import linalg

Expand Down
7 changes: 1 addition & 6 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@

import builtins
import collections
from collections.abc import Sequence
import itertools
import operator
import os
import re
import string
import types
from typing import Sequence, Set, Tuple, Union
import warnings
Expand All @@ -50,10 +46,9 @@
from ..interpreters.masking import Poly
from .. import lax
from .. import ops
from ..util import (partial, get_module_functions, unzip2, prod as _prod,
from ..util import (partial, unzip2, prod as _prod,
subvals, safe_zip)
from ..lib import pytree
from ..lib import xla_client

FLAGS = flags.FLAGS
flags.DEFINE_enum(
Expand Down
3 changes: 1 addition & 2 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .vectorize import vectorize
from . import lax_numpy as jnp
from ..util import get_module_functions
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve # noqa: F401

_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conj(jnp.swapaxes(x, -1, -2))
Expand Down Expand Up @@ -196,7 +196,6 @@ def _cofactor_solve(a, b):
a_shape = jnp.shape(a)
b_shape = jnp.shape(b)
a_ndims = len(a_shape)
b_ndims = len(b_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
msg = ("The arguments to _cofactor_solve must have shapes "
Expand Down
4 changes: 1 addition & 3 deletions jax/numpy/vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.
import functools
import re
import textwrap
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Tuple

from .. import api
from .. import lax
from . import lax_numpy as jnp
from ..util import safe_map as map, safe_zip as zip
from ._util import _wraps


# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ignore =
max-complexity = 18
select = B,C,F,W,T4,B9
filename =
./tests/*.py
./tests/*.py
./jax/numpy/*.py

0 comments on commit c77c083

Please sign in to comment.