Skip to content

Commit

Permalink
deflake jax/scipy/* and add to setup.cfg (jax-ml#3316)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Jun 4, 2020
1 parent b187663 commit bc51e9c
Show file tree
Hide file tree
Showing 12 changed files with 7 additions and 11 deletions.
1 change: 1 addition & 0 deletions jax/scipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from . import linalg
from . import ndimage
from . import signal
Expand Down
1 change: 1 addition & 0 deletions jax/scipy/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from . import linalg
2 changes: 0 additions & 2 deletions jax/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

from functools import partial
import operator
import textwrap

import scipy.sparse.linalg
import numpy as np
import jax.numpy as jnp
from jax import lax, device_put
Expand Down
2 changes: 0 additions & 2 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import scipy.special as osp_special

from .. import util
from .. import lax
from .. import api
from ..numpy import lax_numpy as jnp
Expand Down Expand Up @@ -289,7 +288,6 @@ def ndtri(p):
Raises:
TypeError: if `p` is not floating-type.
"""
x = jnp.asarray(p)
dtype = lax.dtype(p)
if dtype not in (jnp.float32, jnp.float64):
raise TypeError(
Expand Down
1 change: 1 addition & 0 deletions jax/scipy/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from . import bernoulli
from . import poisson
from . import beta
Expand Down
1 change: 0 additions & 1 deletion jax/scipy/stats/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ def logpdf(x, a, b, loc=0, scale=1):
@_wraps(osp_stats.beta.pdf, update_doc=False)
def pdf(x, a, b, loc=0, scale=1):
return lax.exp(logpdf(x, a, b, loc, scale))

1 change: 0 additions & 1 deletion jax/scipy/stats/cauchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)
def logpdf(x, loc=0, scale=1):
x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
one = _constant_like(x, 1)
pi = _constant_like(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.mul(pi, scale))
Expand Down
2 changes: 1 addition & 1 deletion jax/scipy/stats/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ... import lax
from ...numpy import lax_numpy as jnp
from ...numpy._util import _wraps
from ..special import xlogy, xlog1py
from ..special import xlog1py

@_wraps(osp_stats.geom.logpmf, update_doc=False)
def logpmf(k, p, loc=0):
Expand Down
1 change: 0 additions & 1 deletion jax/scipy/stats/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from ... import lax
from ...numpy._util import _wraps
from ...numpy.lax_numpy import _promote_args_inexact


@_wraps(osp_stats.logistic.logpdf, update_doc=False)
Expand Down
2 changes: 1 addition & 1 deletion jax/scipy/stats/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ...lax_linalg import cholesky, triangular_solve
from ... import numpy as jnp
from ...numpy._util import _wraps
from ...numpy.lax_numpy import _promote_dtypes_inexact, _constant_like
from ...numpy.lax_numpy import _promote_dtypes_inexact


@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False)
Expand Down
3 changes: 1 addition & 2 deletions jax/scipy/stats/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

from ... import lax
from ...numpy._util import _wraps
from ...numpy.lax_numpy import (_constant_like, _promote_args_inexact,
where, inf, logical_or)
from ...numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or


@_wraps(osp_stats.uniform.logpdf, update_doc=False)
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ filename =
./tests/*.py
./jax/lax/*.py
./jax/numpy/*.py
./jax/scipy/*.py

0 comments on commit bc51e9c

Please sign in to comment.