From bc51e9c7f699811914cb9602fbba1e1460fffc77 Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Thu, 4 Jun 2020 14:38:41 -0700 Subject: [PATCH] deflake jax/scipy/* and add to setup.cfg (#3316) --- jax/scipy/__init__.py | 1 + jax/scipy/sparse/__init__.py | 1 + jax/scipy/sparse/linalg.py | 2 -- jax/scipy/special.py | 2 -- jax/scipy/stats/__init__.py | 1 + jax/scipy/stats/beta.py | 1 - jax/scipy/stats/cauchy.py | 1 - jax/scipy/stats/geom.py | 2 +- jax/scipy/stats/logistic.py | 1 - jax/scipy/stats/multivariate_normal.py | 2 +- jax/scipy/stats/uniform.py | 3 +-- setup.cfg | 1 + 12 files changed, 7 insertions(+), 11 deletions(-) diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index 392641c0d564..347188211848 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# flake8: noqa: F401 from . import linalg from . import ndimage from . import signal diff --git a/jax/scipy/sparse/__init__.py b/jax/scipy/sparse/__init__.py index 1d7527c91d3b..13ac5929c40c 100644 --- a/jax/scipy/sparse/__init__.py +++ b/jax/scipy/sparse/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +# flake8: noqa: F401 from . import linalg diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py index 9fe21835a710..4ece1331a2e1 100644 --- a/jax/scipy/sparse/linalg.py +++ b/jax/scipy/sparse/linalg.py @@ -14,9 +14,7 @@ from functools import partial import operator -import textwrap -import scipy.sparse.linalg import numpy as np import jax.numpy as jnp from jax import lax, device_put diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 0adf42fbe2a9..3dddf6ea86d8 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -17,7 +17,6 @@ import numpy as np import scipy.special as osp_special -from .. import util from .. import lax from .. import api from ..numpy import lax_numpy as jnp @@ -289,7 +288,6 @@ def ndtri(p): Raises: TypeError: if `p` is not floating-type. """ - x = jnp.asarray(p) dtype = lax.dtype(p) if dtype not in (jnp.float32, jnp.float64): raise TypeError( diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index fd14e5377568..3f224054a9c1 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# flake8: noqa: F401 from . import bernoulli from . import poisson from . import beta diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index a869e592c582..3d6b701c4e8a 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -36,4 +36,3 @@ def logpdf(x, a, b, loc=0, scale=1): @_wraps(osp_stats.beta.pdf, update_doc=False) def pdf(x, a, b, loc=0, scale=1): return lax.exp(logpdf(x, a, b, loc, scale)) - diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index 9b0e07b09cd8..7f437f0ae190 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -24,7 +24,6 @@ @_wraps(osp_stats.cauchy.logpdf, update_doc=False) def logpdf(x, loc=0, scale=1): x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale) - one = _constant_like(x, 1) pi = _constant_like(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) normalize_term = lax.log(lax.mul(pi, scale)) diff --git a/jax/scipy/stats/geom.py b/jax/scipy/stats/geom.py index 33907e3cdc51..4e4b29f95319 100644 --- a/jax/scipy/stats/geom.py +++ b/jax/scipy/stats/geom.py @@ -17,7 +17,7 @@ from ... import lax from ...numpy import lax_numpy as jnp from ...numpy._util import _wraps -from ..special import xlogy, xlog1py +from ..special import xlog1py @_wraps(osp_stats.geom.logpmf, update_doc=False) def logpmf(k, p, loc=0): diff --git a/jax/scipy/stats/logistic.py b/jax/scipy/stats/logistic.py index 219b01006427..193fae9fbc19 100644 --- a/jax/scipy/stats/logistic.py +++ b/jax/scipy/stats/logistic.py @@ -17,7 +17,6 @@ from ... import lax from ...numpy._util import _wraps -from ...numpy.lax_numpy import _promote_args_inexact @_wraps(osp_stats.logistic.logpdf, update_doc=False) diff --git a/jax/scipy/stats/multivariate_normal.py b/jax/scipy/stats/multivariate_normal.py index 33ee3b17aef0..4154f7b53647 100644 --- a/jax/scipy/stats/multivariate_normal.py +++ b/jax/scipy/stats/multivariate_normal.py @@ -20,7 +20,7 @@ from ...lax_linalg import cholesky, triangular_solve from ... import numpy as jnp from ...numpy._util import _wraps -from ...numpy.lax_numpy import _promote_dtypes_inexact, _constant_like +from ...numpy.lax_numpy import _promote_dtypes_inexact @_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False) diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py index 7ac177a0d3f0..0685c96b1557 100644 --- a/jax/scipy/stats/uniform.py +++ b/jax/scipy/stats/uniform.py @@ -17,8 +17,7 @@ from ... import lax from ...numpy._util import _wraps -from ...numpy.lax_numpy import (_constant_like, _promote_args_inexact, - where, inf, logical_or) +from ...numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or @_wraps(osp_stats.uniform.logpdf, update_doc=False) diff --git a/setup.cfg b/setup.cfg index e8d9b1e06b60..628962f0ba11 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,4 @@ filename = ./tests/*.py ./jax/lax/*.py ./jax/numpy/*.py + ./jax/scipy/*.py