Commit

remove some trailing whitespace (jax-ml#3287)
mattjj authored Jun 3, 2020
1 parent ea4277b commit c42a7f7
Showing 16 changed files with 35 additions and 35 deletions.
2 changes: 1 addition & 1 deletion examples/differentially_private_sgd.py
@@ -18,7 +18,7 @@
Gradient Descent (https://arxiv.org/abs/1607.00133). DPSGD requires clipping
the per-example parameter gradients, which is non-trivial to implement
efficiently for convolutional neural networks. The JAX XLA compiler shines in
-this setting by optimizing the minibatch-vectorized computation for
+this setting by optimizing the minibatch-vectorized computation for
convolutional architectures. Train time takes a few seconds per epoch on a
commodity GPU.
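
The per-example gradient clipping mentioned above is typically expressed with `vmap` over `grad`. A minimal, hedged sketch of that pattern follows (the loss and the flat parameter vector here are illustrative, not this example script's actual model):

    import jax
    import jax.numpy as jnp

    def loss(params, example):
      x, y = example
      return jnp.mean((jnp.dot(x, params) - y) ** 2)

    def clipped_grad(params, batch, l2_norm_clip=1.0):
      # One gradient per example, computed in a single vectorized pass.
      per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(params, batch)
      # Clip each per-example gradient to the given L2 norm, then average.
      norms = jnp.sqrt(jnp.sum(per_example_grads ** 2, axis=-1))
      scale = jnp.minimum(1.0, l2_norm_clip / (norms + 1e-12))
      return jnp.mean(per_example_grads * scale[:, None], axis=0)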
8 changes: 4 additions & 4 deletions jax/experimental/host_callback.py
@@ -342,11 +342,11 @@ def pp_val(arg) -> ppu.PrettyPrint:
* nr_untapped: how many positional arguments (from the tail) should not be
passed to the tap function.
* arg_treedef: the treedef of the tapped positional arguments.
-* transforms: a tuple of the transformations that have been applied. Each
+* transforms: a tuple of the transformations that have been applied. Each
element of the tuple is itself a tuple with the first element the name
-of the transform. The remaining elements depend on the transform. For
-example, for `batch`, the parameters are the dimensions that have been
-batched, and for `mask` the logical shapes. These are unpacked by
+of the transform. The remaining elements depend on the transform. For
+example, for `batch`, the parameters are the dimensions that have been
+batched, and for `mask` the logical shapes. These are unpacked by
_ConsumerCallable before passing to the user function.
* the remaining parameters are passed to the tap function.
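
As a hedged illustration of how these pieces surface to users of this experimental module (loosely modeled on the tests later in this diff; exact signatures may differ):

    import jax
    import jax.numpy as jnp
    from jax.experimental import host_callback as hcb

    def func(x):
      # id_print taps the value out to the host; extra keyword arguments such as
      # `what` are attached to the printed record.
      y = hcb.id_print(x * 3., what="x * 3")
      return y * 2.

    with hcb.outfeed_receiver():   # receives the tapped values, as in the tests below
      jax.grad(func)(jnp.float32(5.))
    # Under differentiation the printed record also carries the applied transforms,
    # e.g. ({'name': 'jvp'}, {'name': 'transpose'}).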
10 changes: 5 additions & 5 deletions jax/experimental/vectorize.py
@@ -23,9 +23,9 @@
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
-handle non-scalar operations. When a gufuncs is applied to arrays, there are:
+handle non-scalar operations. When a gufuncs is applied to arrays, there are:
-* "core dimensions" over which an operation is defined.
+* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.
A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
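
As a hedged sketch of the signature idea using this module's `vectorize` decorator (defined further down in this file; the import path and support for scalar `()` core outputs are assumptions):

    import jax.numpy as jnp
    from jax.experimental.vectorize import vectorize

    @vectorize('(n),(n)->()')   # one core dimension n per input, scalar core output
    def vdot(a, b):
      return jnp.sum(a * b)     # written only for the 1D core case

    # Leading ("broadcast") dimensions are handled automatically: applying vdot to
    # shapes (4, 3) and (3,) would return a length-4 result.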
@@ -199,7 +199,7 @@ def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
for core_dims in list_of_core_dims]


# adapted from np.vectorize (again authored by shoyer@)
def broadcast_with_core_dims(args, input_core_dims, output_core_dims):
if len(args) != len(input_core_dims):
@@ -245,7 +245,7 @@ def vectorize(signature):
"""Vectorize a function using JAX.
Turns an arbitrary function into a numpy style "gufunc". Once
-you specify the behavior of the core axis, the rest will be
+you specify the behavior of the core axis, the rest will be
broadcast naturally.
Args:
@@ -258,7 +258,7 @@ def vectorize(signature):
which axis should be treated as the core one.
"""
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
4 changes: 2 additions & 2 deletions jax/nn/initializers.py
@@ -78,7 +78,7 @@ def init(key, shape, dtype=dtype):
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
Construct an initializer for uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns
depending on which side is smaller.
"""
@@ -100,7 +100,7 @@ def init(key, shape, dtype=dtype):

def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
-Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
+Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
The shape must be 3D, 4D or 5D.
"""
4 changes: 2 additions & 2 deletions jax/numpy/linalg.py
@@ -170,14 +170,14 @@ def _cofactor_solve(a, b):
If a is rank n-1, then the lower right corner of u will be zero and the
triangular_solve will fail.
Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
-Then y_{n} =
+Then y_{n}
x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
x_{n} * prod_{i=1...n-1}(u_{ii})
So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
we can avoid the triangular_solve failing.
To correctly compute the rest of y_{i} for i != n, we simply multiply
x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
For the second case, a check is done on the matrix to see if `solve`
returns NaN or Inf, and gives a matrix of zeros as a result, as the
gradient of the determinant of a matrix with rank less than n-1 is 0.
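
As a quick, hedged numerical illustration of the identity this derivation relies on, det(a) * solve(a, b) equals adj(a) @ b; the point of `_cofactor_solve` is to compute that quantity stably even when `a` is close to rank n-1 (plain NumPy here, for clarity):

    import numpy as np

    rng = np.random.RandomState(0)
    a = rng.randn(4, 4)
    b = rng.randn(4)
    y = np.linalg.det(a) * np.linalg.solve(a, b)   # what _cofactor_solve computes
    adj = np.linalg.det(a) * np.linalg.inv(a)      # adj(a) = det(a) * inv(a)
    assert np.allclose(y, adj @ b)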
6 changes: 3 additions & 3 deletions jax/numpy/polynomial.py
@@ -54,13 +54,13 @@ def _nonzero_range(arr):

@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
-the polynomial is of degree n - 1 leading to n - 1 roots.
+the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
-has a smaller degree and the number of roots (and thus the output shape)
+has a smaller degree and the number of roots (and thus the output shape)
is value dependent.
The general implementation can therefore not be transformed with jit.
-If the coefficients are guaranteed to have no leading zeros, use the
+If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::
>>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
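
A short, hedged usage sketch following the recipe above: with `strip_zeros=False` the output length is fixed at `len(coeffs) - 1`, which is what makes the call jit-compatible.

    import functools
    import jax
    import jax.numpy as jnp

    roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
    coeffs = jnp.array([1.0, 0.0, -1.0])   # x**2 - 1, no leading zeros
    r = roots_unsafe(coeffs)               # the two roots, 1 and -1 (order unspecified)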
2 changes: 1 addition & 1 deletion jax/numpy/vectorize.py
@@ -52,7 +52,7 @@ def _parse_gufunc_signature(
'not a valid gufunc signature: {}'.format(signature))
args, retvals = ([tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)]
-for arg_list in signature.split('->'))
+for arg_list in signature.split('->'))
return args, retvals


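For illustration, a hedged sketch of what this private helper returns for a typical signature (the direct import of an underscore-prefixed function is an assumption, for demonstration only):

    from jax.numpy.vectorize import _parse_gufunc_signature

    args, retvals = _parse_gufunc_signature('(m,n),(n)->(m)')
    # args    == [('m', 'n'), ('n',)]
    # retvals == [('m',)]
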
2 changes: 1 addition & 1 deletion jax/random.py
@@ -26,7 +26,7 @@
Among other requirements, the JAX PRNG aims to:
(a) ensure reproducibility,
(b) parallelize well, both in terms of vectorization (generating array values)
-and multi-replica, multi-core computation. In particular it should not use
+and multi-replica, multi-core computation. In particular it should not use
sequencing constraints between random function calls.
The approach is based on:
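
A minimal sketch of the explicit-key style these requirements lead to, using the standard `jax.random` API:

    import jax

    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)   # no hidden global state to sequence
    x = jax.random.normal(subkey, (3,))   # same subkey and shape give the same values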
2 changes: 1 addition & 1 deletion jax/scipy/special.py
@@ -23,7 +23,7 @@
from ..numpy import lax_numpy as jnp
from ..numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
-from ..numpy._util import _wraps
+from ..numpy._util import _wraps


@_wraps(osp_special.gammaln)
6 changes: 3 additions & 3 deletions jax/third_party/numpy/linalg.py
@@ -89,7 +89,7 @@ def tensorsolve(a, b, axes=None):
allaxes.insert(an, k)

a = a.transpose(allaxes)

Q = a.shape[-(an - b.ndim):]

prod = 1
@@ -98,10 +98,10 @@ def tensorsolve(a, b, axes=None):

a = a.reshape(-1, prod)
b = b.ravel()

res = jnp.asarray(la.solve(a, b))
res = res.reshape(Q)

return res


2 changes: 1 addition & 1 deletion tests/dtypes_test.py
@@ -71,7 +71,7 @@ def testDefaultTypes(self, type, dtype):

@parameterized.named_parameters(
{"testcase_name": "_swap={}_jit={}".format(swap, jit),
"swap": swap, "jit": jit}
"swap": swap, "jit": jit}
for swap in [False, True] for jit in [False, True])
@jtu.skip_on_devices("tpu") # F16 not supported on TPU
def testBinaryPromotion(self, swap, jit):
8 changes: 4 additions & 4 deletions tests/host_callback_test.py
@@ -774,7 +774,7 @@ def func(x):
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
testing_stream.reset()

with hcb.outfeed_receiver():
res_grad = grad_func(jnp.float32(5.))

@@ -843,7 +843,7 @@ def func(x):
with hcb.outfeed_receiver():
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
-let
+let
in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
# Just making the Jaxpr invokes the id_print twiceonce
assertMultiLineStrippedEqual(self, """
@@ -1100,7 +1100,7 @@ def func(x, z):
in (d, e, h) }
linear=(False, False, False, False, False, False)
true_jaxpr={ lambda ; d g_ a b c h.
-let
+let
in (a, d, h) } ] c d e 1 2 b h
in (f, g, i) }""", func, [y, 5])

@@ -1176,7 +1176,7 @@ def func(x):
in (w, t, u, x) }
body_nconsts=2
cond_jaxpr={ lambda ; j k l m.
-let
+let
in (j,) }
cond_nconsts=0 ] b c h a 1 i
in (d, 5, g) }""", func, [ct_body])
2 changes: 1 addition & 1 deletion tests/lax_numpy_test.py
@@ -2027,7 +2027,7 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory):
rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
{"testcase_name":
f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}",
"arg": arg, "ndmin": ndmin, "dtype": dtype}
for i, (arg, dtypes) in enumerate([
6 changes: 3 additions & 3 deletions tests/linalg_test.py
@@ -103,7 +103,7 @@ def testDet(self, n, dtype, rng_factory):
def testDetOfSingularMatrix(self):
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
@@ -174,10 +174,10 @@ def testTensorsolve(self, m, nq, dtype, rng_factory):
result = jnp.linalg.tensorsolve(*args_maker())
self.assertEqual(result.shape, Q)

-self._CheckAgainstNumpy(np.linalg.tensorsolve,
+self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
tol={np.float32: 1e-2, np.float64: 1e-3})
-self._CompileAndCheck(jnp.linalg.tensorsolve,
+self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker,
rtol={np.float64: 1e-13})

2 changes: 1 addition & 1 deletion tests/nn_test.py
@@ -88,7 +88,7 @@ def testEluGrad(self):
def testEluValue(self):
val = nn.elu(1e4)
self.assertAllClose(val, 1e4, check_dtypes=False)

def testGluValue(self):
val = nn.glu(jnp.array([1.0, 0.0]))
self.assertAllClose(val, jnp.array([0.5]))
4 changes: 2 additions & 2 deletions tests/vectorize_test.py
@@ -134,12 +134,12 @@ def test_center(self):
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
self.assertAllClose(jnp.mean(X, axis=1), b)

b, a = center(X, axis=0)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (4,))
self.assertAllClose(jnp.mean(X, axis=0), b)


if __name__ == "__main__":
absltest.main()
