From a7c2cdea64338653a88fc75456d846fb6250b99a Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Tue, 14 Jul 2020 13:05:31 -0700 Subject: [PATCH] Cleanup: convert uses of `import numpy as onp` in library code (#3754) --- benchmarks/benchmark.py | 6 +- jax/api.py | 136 +++---- jax/core.py | 14 +- jax/experimental/loops.py | 4 +- jax/interpreters/batching.py | 24 +- jax/interpreters/masking.py | 8 +- jax/interpreters/partial_eval.py | 6 +- jax/interpreters/pxla.py | 42 +-- jax/interpreters/xla.py | 52 +-- jax/lax/lax.py | 626 +++++++++++++++---------------- jax/lax/lax_control_flow.py | 46 +-- jax/lax/lax_fft.py | 20 +- jax/lax/lax_parallel.py | 34 +- jax/lib/xla_bridge.py | 33 +- jax/util.py | 4 +- 15 files changed, 527 insertions(+), 528 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 74d8624cbfb7..965fd234f1a1 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -20,7 +20,7 @@ from typing import Any, Optional, Union, Callable, List, Dict from absl import flags -import numpy as onp +import numpy as np from tabulate import tabulate from jax.util import safe_zip @@ -59,7 +59,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None, if iters is None: warmup = 1 else: - warmup = onp.clip(1, iters // 10, 10) + warmup = np.clip(1, iters // 10, 10) for _ in range(warmup): f() @@ -73,7 +73,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None, times.append(end - start) count += 1 - times_arr = onp.array(times) + times_arr = np.array(times) print("---------Benchmark results for %s---------" % (name or f.__name__)) print("mean=%f std=%f %%std=%f total=%f" % (times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum())) diff --git a/jax/api.py b/jax/api.py index f2e9fd9fe3e5..0ccff18e280c 100644 --- a/jax/api.py +++ b/jax/api.py @@ -32,7 +32,7 @@ from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union from warnings import warn -import numpy as onp +import numpy as np from contextlib import contextmanager from . import core @@ -204,10 +204,10 @@ def disable_jit(): debugging, and avoid the tracer too, we can use the :py:func:`disable_jit` context manager: - >>> import jax.numpy as np + >>> import jax >>> >>> with jax.disable_jit(): - ... print(f(np.array([1, 2, 3]))) + ... print(f(jax.numpy.array([1, 2, 3]))) ... 
Value of y is [2 4 6] [5 7 9] @@ -339,7 +339,7 @@ def make_axis_env(nreps): return xla.AxisEnv(nreps, names, sizes) def abstractify(x): - return ShapedArray(onp.shape(x), dtypes.result_type(x)) + return ShapedArray(np.shape(x), dtypes.result_type(x)) @wraps(fun) def computation_maker(*args, **kwargs): @@ -474,7 +474,7 @@ def value_and_grad_f(*args, **kwargs): _check_scalar(ans) dtype = dtypes.result_type(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) - g = vjp_py(onp.ones((), dtype=dtype)) + g = vjp_py(np.ones((), dtype=dtype)) g = g[0] if isinstance(argnums, int) else g if not has_aux: return ans, g @@ -500,12 +500,12 @@ def _check_input_dtype_revderiv(name, holomorphic, x): _check_arg(x) aval = core.get_aval(x) if holomorphic: - if not dtypes.issubdtype(aval.dtype, onp.complexfloating): + if not dtypes.issubdtype(aval.dtype, np.complexfloating): msg = (f"{name} with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") raise TypeError(msg) - elif not (dtypes.issubdtype(aval.dtype, onp.floating) or - dtypes.issubdtype(aval.dtype, onp.complexfloating)): + elif not (dtypes.issubdtype(aval.dtype, np.floating) or + dtypes.issubdtype(aval.dtype, np.complexfloating)): msg = (f"{name} requires real- or complex-valued inputs (input dtype that " "is a sub-dtype of np.floating or np.complexfloating), " f"but got {aval.dtype.name}. ") @@ -515,11 +515,11 @@ def _check_input_dtype_revderiv(name, holomorphic, x): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) if holomorphic: - if not dtypes.issubdtype(aval.dtype, onp.complexfloating): + if not dtypes.issubdtype(aval.dtype, np.complexfloating): msg = (f"{name} with holomorphic=True requires outputs with complex dtype, " f"but got {aval.dtype.name}.") raise TypeError(msg) - elif not dtypes.issubdtype(aval.dtype, onp.floating): + elif not dtypes.issubdtype(aval.dtype, np.floating): msg = (f"{name} requires real-valued outputs (output dtype that is " f"a sub-dtype of np.floating), but got {aval.dtype.name}. " "For holomorphic differentiation, pass holomorphic=True. " @@ -545,13 +545,13 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, ``fun`` using forward-mode automatic differentiation. >>> import jax - >>> import jax.numpy as np + >>> import jax.numpy as jnp >>> >>> def f(x): - ... return jax.numpy.asarray( - ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])]) + ... return jnp.asarray( + ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... - >>> print(jax.jacfwd(f)(np.array([1., 2., 3.]))) + >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] @@ -575,12 +575,12 @@ def _check_input_dtype_jacfwd(holomorphic, x): _check_arg(x) aval = core.get_aval(x) if holomorphic: - if not (dtypes.issubdtype(aval.dtype, onp.complexfloating) and - not dtypes.issubdtype(aval.dtype, onp.floating)): + if not (dtypes.issubdtype(aval.dtype, np.complexfloating) and + not dtypes.issubdtype(aval.dtype, np.floating)): msg = ("jacfwd with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") raise TypeError(msg) - elif not dtypes.issubdtype(aval.dtype, onp.floating): + elif not dtypes.issubdtype(aval.dtype, np.floating): msg = ("jacfwd requires real-valued inputs (input dtype that is " f"a sub-dtype of np.floating), but got {aval.dtype.name}. " "For holomorphic differentiation, pass holomorphic=True. 
" @@ -591,7 +591,7 @@ def _check_input_dtype_jacfwd(holomorphic, x): def _check_output_dtype_jacfwd(holomorphic, x): aval = core.get_aval(x) if holomorphic: - if not dtypes.issubdtype(aval.dtype, onp.complexfloating): + if not dtypes.issubdtype(aval.dtype, np.complexfloating): msg = ("jacfwd with holomorphic=True requires outputs with complex dtype, " f"but got {aval.dtype.name}.") raise TypeError(msg) @@ -613,13 +613,13 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, ``fun`` using reverse-mode automatic differentiation. >>> import jax - >>> import jax.numpy as np + >>> import jax.numpy as jnp >>> >>> def f(x): - ... return jax.numpy.asarray( - ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])]) + ... return jnp.asarray( + ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... - >>> print(jax.jacrev(f)(np.array([1., 2., 3.]))) + >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] @@ -711,23 +711,23 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0, def _std_basis(pytree): leaves, _ = tree_flatten(pytree) - ndim = sum(map(onp.size, leaves)) + ndim = sum(map(np.size, leaves)) # TODO(mattjj): use a symbolic identity matrix here dtype = dtypes.result_type(*leaves) - flat_basis = onp.eye(ndim, dtype=dtype) + flat_basis = np.eye(ndim, dtype=dtype) return _unravel_array_into_pytree(pytree, 1, flat_basis) def _unravel_array_into_pytree(pytree, axis, arr): leaves, treedef = tree_flatten(pytree) axis = axis % arr.ndim - shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves] - parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis) - reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)] + shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves] + parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis) + reshaped_parts = [np.reshape(x, shape) for x, shape in zip(parts, shapes)] return tree_unflatten(treedef, reshaped_parts) def _split(x, indices, axis): - if isinstance(x, onp.ndarray): - return onp.split(x, indices, axis) + if isinstance(x, np.ndarray): + return np.split(x, indices, axis) else: return x.split(indices, axis) @@ -771,9 +771,9 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable: For example, we can implement a matrix-matrix product using a vector dot product: - >>> import jax.numpy as np + >>> import jax.numpy as jnp >>> - >>> vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> [] + >>> vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> [] >>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis) >>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis) @@ -788,21 +788,21 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable: axes of the container elements to map over: >>> A, B, C, D = 2, 3, 4, 5 - >>> x = np.ones((A, B)) - >>> y = np.ones((B, C)) - >>> z = np.ones((C, D)) + >>> x = jnp.ones((A, B)) + >>> y = jnp.ones((B, C)) + >>> z = jnp.ones((C, D)) >>> def foo(tree_arg): ... x, (y, z) = tree_arg - ... return np.dot(x, np.dot(y, z)) + ... return jnp.dot(x, jnp.dot(y, z)) >>> tree = (x, (y, z)) >>> print(foo(tree)) [[12. 12. 12. 12. 12.] [12. 12. 12. 12. 
12.]]

   >>> from jax import vmap
   >>> K = 6 # batch size
-  >>> x = np.ones((K, A, B)) # batch axis in different locations
-  >>> y = np.ones((B, K, C))
-  >>> z = np.ones((C, D, K))
+  >>> x = jnp.ones((K, A, B)) # batch axis in different locations
+  >>> y = jnp.ones((B, K, C))
+  >>> z = jnp.ones((C, D, K))
   >>> tree = (x, (y, z))
   >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
   >>> print(vfoo(tree).shape)
@@ -811,7 +811,7 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   Here's another example using container types in ``in_axes``, this time a
   dictionary, to specify the elements of the container to map over:

-  >>> dct = {'a': 0., 'b': np.arange(5.)}
+  >>> dct = {'a': 0., 'b': jnp.arange(5.)}
   >>> x = 1.
   >>> def foo(dct, x):
   ...   return dct['a'] + dct['b'] + x
@@ -824,13 +824,13 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   element mapped and the second unmapped. Only for unmapped results can we
   specify ``out_axes`` to be ``None`` (to keep it unmapped).

-  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(np.arange(2.), 4.))
+  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
   (DeviceArray([4., 5.], dtype=float32), 8.0)

   If the ``out_axes`` is specified for an unmapped result, the result is
   broadcast across the mapped axis:

-  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(np.arange(2.), 4.))
+  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
   (DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32))

   If the ``out_axes`` is specified for a mapped result, the result is
@@ -884,7 +884,7 @@ def _get_axis_size(name: str, i:int, shape: Tuple[int, ...], axis: int):
       f"but axis to be mapped {axis}") from e

 def _mapped_axis_size(tree, vals, dims, name):
-  mapped_axis_sizes = {_get_axis_size(name, i, onp.shape(x), d)
+  mapped_axis_sizes = {_get_axis_size(name, i, np.shape(x), d)
                        for i, (x, d) in enumerate(zip(vals, dims))
                        if d is not None}
   try:
@@ -1005,18 +1005,18 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   For example, assuming 8 XLA devices are available, :py:func:`pmap` can be
   used as a map along a leading array axis:

-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
-  >>> out = pmap(lambda x: x ** 2)(np.arange(8)) # doctest: +SKIP
+  >>> out = pmap(lambda x: x ** 2)(jnp.arange(8)) # doctest: +SKIP
   >>> print(out) # doctest: +SKIP
   [0, 1, 4, 9, 16, 25, 36, 49]

   When the leading dimension is smaller than the number of available devices JAX
   will simply run on a subset of devices:

-  >>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
-  >>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
-  >>> out = pmap(np.dot)(x, y) # doctest: +SKIP
+  >>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
+  >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
+  >>> out = pmap(jnp.dot)(x, y) # doctest: +SKIP
   >>> print(out) # doctest: +SKIP
   [[[ 4. 9.]
     [ 12. 29.]]
    [[ 244. 345.]
     [ 348. 493.]]
    [[ 1412. 1737.]
     [ 1740. 2141.]]]

   If your leading dimension is larger than the number of available devices you
   will get an error:

-  >>> pmap(lambda x: x ** 2)(np.arange(9)) # doctest: +SKIP
+  >>> pmap(lambda x: x ** 2)(jnp.arange(9)) # doctest: +SKIP
   ValueError: ... requires 9 replicas, but only 8 XLA devices are available

   As with :py:func:`vmap`, using ``None`` in ``in_axes`` indicates that an
   argument doesn't have an extra axis and should be broadcasted, rather than
   mapped, across the replicas:

-  >>> x, y = np.arange(2.), 4.
+  >>> x, y = jnp.arange(2.), 4.
   >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y) # doctest: +SKIP
   >>> print(out) # doctest: +SKIP
   ([4., 5.], [8., 8.])
@@ -1048,7 +1048,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   collective operations. For example:

   >>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
-  >>> out = pmap(f, axis_name='i')(np.arange(4.)) # doctest: +SKIP
+  >>> out = pmap(f, axis_name='i')(jnp.arange(4.)) # doctest: +SKIP
   >>> print(out) # doctest: +SKIP
   [ 0. 0.16666667 0.33333334 0.5 ]
   >>> print(out.sum()) # doctest: +SKIP
   1.0
@@ -1073,7 +1073,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
   ...   return row_normed, col_normed, doubly_normed
   >>>
-  >>> x = np.arange(8.).reshape((4, 2))
+  >>> x = jnp.arange(8.).reshape((4, 2))
   >>> row_normed, col_normed, doubly_normed = normalize(x) # doctest: +SKIP
   >>> print(row_normed.sum(0)) # doctest: +SKIP
   [ 1. 1.]
@@ -1087,7 +1087,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   runs on two hosts with 4 XLA devices each:

   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4,8)
+  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4,8)
   >>> out = pmap(f, axis_name='i')(data) # doctest: +SKIP
   >>> print(out) # doctest: +SKIP
   [28 29 30 31] # on host 0
   [32 33 34 35] # on host 1

   Each host passes in a different length-4 array, corresponding to its 4 local
   devices, and the psum operates over all 8 values. Conceptually, the two
   length-4 arrays can be thought of as a sharded length-8 array (in this example
-  equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
   given name 'i'. The pmap call on each host then returns the corresponding
   length-4 output shard.
@@ -1114,9 +1114,9 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   ... def f2(x):
   ...   return jax.lax.psum(x ** 2, axis_name='i')
   >>>
-  >>> print(f1(np.arange(6.))) # doctest: +SKIP
+  >>> print(f1(jnp.arange(6.))) # doctest: +SKIP
   [0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
-  >>> print(f2(np.array([2., 3.]))) # doctest: +SKIP
+  >>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
   [ 13. 13.]
   """
   # axis_size is an optional integer representing the global axis size.
@@ -1301,7 +1301,7 @@ def wrapped_fun(args, logical_env): if in_tree != in_shapes_tree: raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.") logical_env = {unique_ids[name] : val for name, val in logical_env.items()} - in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat)) + in_shapes = map(masking.finalize_spec, in_specs, map(np.shape, args_flat)) padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat]) f = lu.wrap_init(fun) flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree) @@ -1314,7 +1314,7 @@ def padded_spec(shape_spec): return tuple(dim if dim is masking._monomorphic_dim else masking.eval_poly(dim, padded_env) for dim in shape_spec) masking.check_shapes(map(padded_spec, out_specs), out_spec_tree, - map(onp.shape, outs), out_tree, "Padded output") + map(np.shape, outs), out_tree, "Padded output") return tree_unflatten(out_tree, outs) return wrapped_fun @@ -1326,7 +1326,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable): out_specs, out_spec_tree = tree_flatten(out_shape) out_specs = map(masking.parse_spec, out_specs) flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes) + avals = map(partial(ShapedArray, dtype=np.float32), in_shapes) out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)] masking.check_shapes(map(tuple, out_specs), out_spec_tree, map(tuple, out_shapes), out_tree_thunk()) @@ -1439,9 +1439,9 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]: Here's a more complete example of using :py:func:`linearize`: >>> import jax - >>> import jax.numpy as np + >>> import jax.numpy as jnp >>> - >>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.) + >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (DeviceArray(3.26819, dtype=float32), DeviceArray(-5.00753, dtype=float32)) @@ -1483,7 +1483,7 @@ def fun(*tangents): def _check_inexact_input_vjp(x): aval = core.get_aval(x) - if not dtypes.issubdtype(aval.dtype, onp.inexact): + if not dtypes.issubdtype(aval.dtype, np.inexact): msg = ("Primal inputs to reverse-mode differentiation must be of float " "or complex type, got type {}") raise TypeError(msg.format(aval.dtype.name)) @@ -1703,9 +1703,9 @@ class ShapeDtypeStruct(object): __slots__ = ["shape", "dtype"] def __init__(self, shape, dtype): self.shape = shape - self.dtype = onp.dtype(dtype) + self.dtype = np.dtype(dtype) - size = property(lambda self: onp.prod(self.shape)) + size = property(lambda self: np.prod(self.shape)) ndim = property(lambda self: len(self.shape)) def __len__(self): @@ -1773,16 +1773,16 @@ def __init__(self, shape, dtype): For example: >>> import jax - >>> import jax.numpy as np + >>> import jax.numpy as jnp >>> - >>> f = lambda A, x: np.tanh(np.dot(A, x)) + >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) >>> class MyArgArray(object): ... def __init__(self, shape, dtype): ... self.shape = shape ... self.dtype = dtype ... 
- >>> A = MyArgArray((2000, 3000), np.float32) - >>> x = MyArgArray((3000, 1000), np.float32) + >>> A = MyArgArray((2000, 3000), jnp.float32) + >>> x = MyArgArray((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000) @@ -1790,7 +1790,7 @@ def __init__(self, shape, dtype): float32 """ def abstractify(x): - return ShapedArray(onp.shape(x), dtypes.result_type(x)) + return ShapedArray(np.shape(x), dtypes.result_type(x)) args_flat, in_tree = tree_flatten((args, kwargs)) wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, diff --git a/jax/core.py b/jax/core.py index b665b8a39c31..49d9dd80e07d 100644 --- a/jax/core.py +++ b/jax/core.py @@ -26,7 +26,7 @@ Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple, Type, Union, cast) -import numpy as onp +import numpy as np from . import dtypes from .config import FLAGS @@ -846,7 +846,7 @@ class UnshapedArray(AbstractValue): array_abstraction_level = 2 def __init__(self, dtype, weak_type=False): - self.dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)) + self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype)) self.weak_type = weak_type def __eq__(self, other): @@ -858,7 +858,7 @@ def __ne__(self, other): def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype - # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use + # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.dtype, self.weak_type)) @@ -925,7 +925,7 @@ def __eq__(self, other): def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype - # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use + # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type)) @@ -968,16 +968,16 @@ class ConcreteArray(ShapedArray): array_abstraction_level = 0 def __init__(self, val, weak_type=False): - super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val), + super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val), weak_type=weak_type) # Note: canonicalized self.dtype doesn't necessarily match self.val self.val = val - assert self.dtype != onp.dtype('O') + assert self.dtype != np.dtype('O') def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and onp.all(self.val == other.val)) + and np.all(self.val == other.val)) def __hash__(self): return id(self.val) diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py index d4c6ba0a3a5f..f5dce9d1586e 100644 --- a/jax/experimental/loops.py +++ b/jax/experimental/loops.py @@ -21,7 +21,7 @@ By default, loops and control-flow in JAX are executed and inlined during tracing. For example, in the following code the `for` loop is unrolled during JAX tracing:: - arr = onp.zeros(5) + arr = np.zeros(5) for i in range(arr.shape[0]): arr[i] += 2. if i % 2 == 0: @@ -32,7 +32,7 @@ conditionals as functions, and the array updates using a functional style that returns an updated array, e.g.:: - arr = onp.zeros(5) + arr = np.zeros(5) def loop_body(i, acc_arr): arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.) 
return lax.cond(i % 2 == 0, diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 6befde6945cd..55e90c30c379 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as onp +import numpy as np from typing import Any, Callable, Dict, Optional, Tuple, Union import jax @@ -103,7 +103,7 @@ def aval(self): return aval elif type(aval) is ShapedArray: assert 0 <= self.batch_dim < aval.ndim - new_shape = tuple(onp.delete(aval.shape, self.batch_dim)) + new_shape = tuple(np.delete(aval.shape, self.batch_dim)) return ShapedArray(new_shape, aval.dtype) else: raise TypeError(aval) @@ -236,7 +236,7 @@ def broadcast_batcher(prim, args, dims, **params): either an int indicating the batch dimension, or else `not_mapped` indicating no batching. """ - shapes = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)} + shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)} if len(shapes) == 1: # if there's only agreeing batch dims and scalars, just call the primitive d = next(d for d in dims if d is not not_mapped) @@ -245,16 +245,16 @@ def broadcast_batcher(prim, args, dims, **params): else: size, = {shape[d] for shape, d in shapes if d is not not_mapped} args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)] - ndim = max(onp.ndim(x) for x in args) # special-case scalar broadcasting + ndim = max(np.ndim(x) for x in args) # special-case scalar broadcasting args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)] out = prim.bind(*args, **params) return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) def _handle_scalar_broadcasting(nd, x, d): - if d is not_mapped or nd == onp.ndim(x): + if d is not_mapped or nd == np.ndim(x): return x else: - return x.reshape(x.shape + (1,) * (nd - onp.ndim(x))) + return x.reshape(x.shape + (1,) * (nd - np.ndim(x))) def defreducer(prim): primitive_batchers[prim] = partial(reducer_batcher, prim) @@ -262,8 +262,8 @@ def defreducer(prim): def reducer_batcher(prim, batched_args, batch_dims, axes, **params): operand, = batched_args bdim, = batch_dims - axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1))) - bdim_out = int(list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim)) + axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1))) + bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim)) if 'input_shape' in params: params = dict(params, input_shape=operand.shape) return prim.bind(operand, axes=axes, **params), bdim_out @@ -303,10 +303,10 @@ def broadcast(x, sz, axis): if core.get_aval(x) is core.abstract_unit: return core.unit if axis is last: - axis = onp.ndim(x) - shape = list(onp.shape(x)) + axis = np.ndim(x) + shape = list(np.shape(x)) shape.insert(axis, sz) - broadcast_dims = tuple(onp.delete(onp.arange(len(shape)), axis)) + broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) return jax.lax.broadcast_in_dim(x, shape, broadcast_dims) def moveaxis(x, src, dst): @@ -315,7 +315,7 @@ def moveaxis(x, src, dst): if src == dst: return x src, dst = src % x.ndim, dst % x.ndim - perm = [i for i in range(onp.ndim(x)) if i != src] + perm = [i for i in range(np.ndim(x)) if i != src] perm.insert(dst, src) return x.transpose(perm) diff --git a/jax/interpreters/masking.py b/jax/interpreters/masking.py index 011bfa2bbb6a..2bdf3dc7b34c 100644 --- a/jax/interpreters/masking.py +++ 
b/jax/interpreters/masking.py @@ -20,7 +20,7 @@ import string from typing import Callable, Dict, Sequence, Union -import numpy as onp +import numpy as np from .. import abstract_arrays from .. import core, dtypes @@ -317,7 +317,7 @@ def parse_spec(spec=''): def _parse_dim(spec): if '+' in spec: - return onp.sum(map(_parse_dim, spec.split('+'))) + return np.sum(map(_parse_dim, spec.split('+'))) elif '*' in spec: return prod(map(_parse_dim, spec.split('*'))) elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit(): @@ -383,10 +383,10 @@ def full_lower(self): class MaskTrace(Trace): def pure(self, val): - return MaskTracer(self, val, onp.shape(val)) + return MaskTracer(self, val, np.shape(val)) def lift(self, val): - return MaskTracer(self, val, onp.shape(val)) + return MaskTracer(self, val, np.shape(val)) def sublift(self, val): return MaskTracer(self, val.val, val.polymorphic_shape) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index cfb5e96746ca..c383189ddd3d 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -20,7 +20,7 @@ Set, Tuple, Type, Union, cast) from weakref import ref -import numpy as onp +import numpy as np from .. import core from .. import linear_util as lu @@ -128,7 +128,7 @@ def instantiate_const(self, tracer) -> Tracer: if const is None: return tracer else: - if type(const) in core.literalable_types and onp.shape(const) == (): + if type(const) in core.literalable_types and np.shape(const) == (): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) @@ -138,7 +138,7 @@ def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer': if const is None: return tracer else: - aval = raise_to_shaped(get_aval(const), onp.isscalar(const)) + aval = raise_to_shaped(get_aval(const), np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index b7f430f22410..6b12c3ff6e7c 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -37,7 +37,7 @@ Type, Union) from absl import logging -import numpy as onp +import numpy as np from ..config import flags from .. 
import core @@ -465,7 +465,7 @@ def _axis_index_bind(*, axis_name): nreps = dynamic_axis_env.nreps trace = frame.pmap_trace - out_aval = ShapedArray((), onp.int32) + out_aval = ShapedArray((), np.int32) out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, dict(nreps=nreps, sizes=sizes, @@ -476,19 +476,19 @@ def _axis_index_bind(*, axis_name): if not frame.soft_trace: return out_tracer else: - val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size) + val_out = out_tracer * frame.soft_size + np.arange(frame.soft_size) return SplitAxisTracer(frame.soft_trace, axis_name, val_out) def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name): - div = xb.constant(c, onp.array(nreps // prod(sizes), dtype=onp.uint32)) - mod = xb.constant(c, onp.array(sizes[-1], dtype=onp.uint32)) + div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32)) + mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32)) unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) - return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32)) + return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) axis_index_p = core.Primitive('axis_index') axis_index_p.def_custom_bind(_axis_index_bind) axis_index_p.def_abstract_eval( - lambda *args, **params: ShapedArray((), onp.int32)) + lambda *args, **params: ShapedArray((), np.int32)) xla.translations[axis_index_p] = _axis_index_translation_rule @@ -587,7 +587,7 @@ def block_until_ready(self): def _value(self): if self._npy_value is None: self.copy_to_host_async() - npy_value = onp.empty(self.aval.shape, self.aval.dtype) + npy_value = np.empty(self.aval.shape, self.aval.dtype) for i in self.one_replica_buffer_indices: npy_value[self.indices[i]] = self.device_buffers[i].to_py() self._npy_value = npy_value @@ -633,7 +633,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices): shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path def _sharded_device_array_constant_handler(c, val, canonicalize_types=True): - return xb.constant(c, onp.asarray(val), canonicalize_types=canonicalize_types) + return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types) xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler) core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray @@ -838,7 +838,7 @@ def dynamic_fun(dummy, *args): # provided 1D list of devices). device_assignment = tree_map(lambda d: d.id, devices) # Convert to 2D in case it's 1D and we have > 1 partitions. - device_assignment = onp.array(device_assignment).reshape( + device_assignment = np.array(device_assignment).reshape( (num_global_replicas, num_partitions)) compile_options = xb.get_compile_options( num_replicas=num_global_replicas, @@ -933,7 +933,7 @@ def get_num_partitions(*partitions): if len(partition_specs) == 0: # Everything is specified as replicated (all Nones). 
return None - num_partitions_set = set(onp.prod(spec) for spec in partition_specs) + num_partitions_set = set(np.prod(spec) for spec in partition_specs) if len(num_partitions_set) > 1: raise ValueError( f"All partition specs must use the same number of total partitions, " @@ -1157,7 +1157,7 @@ def _xla_shard(c, aval, axis_env, x): return x elif isinstance(aval, ShapedArray): dims = list(c.get_shape(x).dimensions()) - zero = xb.constant(c, onp.zeros((), dtype=onp.uint32)) + zero = xb.constant(c, np.zeros((), dtype=np.uint32)) idxs = [_unravel_index(c, axis_env)] + [zero] * (len(dims) - 1) return xops.Reshape(xops.DynamicSlice(x, idxs, [1] + dims[1:]), dims[1:]) else: @@ -1169,16 +1169,16 @@ def _xla_unshard(c, aval, axis_env, x, backend): return x elif isinstance(aval, ShapedArray): # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU - convert_bool = (onp.issubdtype(aval.dtype, onp.bool_) + convert_bool = (np.issubdtype(aval.dtype, np.bool_) and xb.get_backend(backend).platform in ('cpu', 'gpu')) if convert_bool: - x = xops.ConvertElementType(x, xb.dtype_to_etype(onp.float32)) + x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32)) xla_shape = c.get_shape(x) dims = list(xla_shape.dimensions()) - padded = xops.Broadcast(xb.constant(c, onp.array(0, xla_shape.numpy_dtype())), + padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())), [axis_env.sizes[-1]] + dims) - zero = xb.constant(c, onp.zeros((), dtype=onp.uint32)) + zero = xb.constant(c, np.zeros((), dtype=np.uint32)) idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims) padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs) replica_groups_protos = xc.make_replica_groups( @@ -1187,15 +1187,15 @@ def _xla_unshard(c, aval, axis_env, x, backend): # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU if convert_bool: - nonzero = xops.Ne(out, xb.constant(c, onp.array(0, dtype=onp.float32))) - out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(onp.bool_)) + nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32))) + out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_)) return out else: raise TypeError((aval, c.get_shape(x))) def _unravel_index(c, axis_env): - div = xb.constant(c, onp.array(axis_env.nreps // prod(axis_env.sizes), onp.uint32)) - mod = xb.constant(c, onp.array(axis_env.sizes[-1], onp.uint32)) + div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32)) + mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32)) return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) @@ -1278,7 +1278,7 @@ def process_primitive(self, primitive, tracers, params): if primitive is axis_index_p: dummy, = vals_in hard_idx = primitive.bind(dummy, **params) - val_out = hard_idx * params['soft_size'] + onp.arange(params['soft_size']) + val_out = hard_idx * params['soft_size'] + np.arange(params['soft_size']) return SplitAxisTracer(self, params['axis_name'], val_out) elif all(axis_name is not_mapped for axis_name in names_in): return primitive.bind(*vals_in, **params) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd2738b23beb..64ec8ebcda44 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -20,7 +20,7 @@ from warnings import warn from absl import logging -import numpy as onp +import numpy as np from ..config import flags, bool_env from .. 
import core @@ -74,11 +74,11 @@ def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() # unit representation -def _make_unit(c): return xb.constant(c, onp.zeros((), dtype=onp.dtype('bool'))) -def _make_abstract_unit(_): return xc.Shape.array_shape(onp.dtype('bool'), ()) +def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool'))) +def _make_abstract_unit(_): return xc.Shape.array_shape(np.dtype('bool'), ()) def _device_put_unit(_, device): backend = xb.get_device_backend(device) - return backend.buffer_from_pyval(onp.zeros((), dtype=onp.dtype('bool')), + return backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')), device) def _make_array_shape(a): return xc.Shape.array_shape(a.dtype, a.shape) @@ -143,10 +143,10 @@ def canonicalize_dtype(x): raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") def _canonicalize_ndarray_dtype(x): - return onp.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x))) + return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x))) def _canonicalize_python_scalar_dtype(typ, x): - return onp.asarray( + return np.asarray( x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ])) canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity} @@ -342,8 +342,8 @@ def check_nans(prim, bufs): def _check_nans(name, xla_shape, buf): assert not xla_shape.is_tuple() - if dtypes.issubdtype(xla_shape.element_type(), onp.inexact): - if onp.any(onp.isnan(buf.to_py())): + if dtypes.issubdtype(xla_shape.element_type(), np.inexact): + if np.any(np.isnan(buf.to_py())): raise FloatingPointError(f"invalid value (nan) encountered in {name}") ### compiling jaxprs @@ -477,10 +477,10 @@ def _axis_groups(nrep, mesh_spec, mesh_axes): trailing_size, ragged = divmod(nrep, prod(mesh_spec)) assert not ragged full_spec = list(mesh_spec) + [trailing_size] - iota = onp.arange(prod(full_spec)).reshape(full_spec) - groups = onp.reshape( - onp.moveaxis(iota, mesh_axes, onp.arange(len(mesh_axes))), - (prod(onp.take(full_spec, mesh_axes)), -1)) + iota = np.arange(prod(full_spec)).reshape(full_spec) + groups = np.reshape( + np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))), + (prod(np.take(full_spec, mesh_axes)), -1)) return tuple(unsafe_map(tuple, groups.T)) def jaxpr_replicas(jaxpr): @@ -862,7 +862,7 @@ def _xla_call_translation_rule(c, axis_env, def zeros_like_translation_rule(c, x): shape = c.get_shape(x) assert not shape.is_tuple() - zero = xb.constant(c, onp.array(0, shape.element_type())) + zero = xb.constant(c, np.array(0, shape.element_type())) return xops.Broadcast(zero, shape.dimensions()) translations[ad_util.zeros_like_p] = zeros_like_translation_rule @@ -1018,7 +1018,7 @@ def ndim(self): def copy(self): """Returns an ndarray (backed by host memory, not device memory).""" - return onp.asarray(self) + return np.asarray(self) def copy_to_host_async(self): """Requests a copy of the buffer to the host.""" @@ -1042,10 +1042,10 @@ def delete(self): self._npy_value = None def __repr__(self): - line_width = onp.get_printoptions()['linewidth'] + line_width = np.get_printoptions()['linewidth'] prefix = '{}('.format(self.__class__.__name__) - s = onp.array2string(self._value, prefix=prefix, suffix=',', - separator=', ', max_line_width=line_width) + s = np.array2string(self._value, prefix=prefix, suffix=',', + separator=', ', max_line_width=line_width) dtype_str = 'dtype={})'.format(self.dtype.name) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' @@ -1054,13 +1054,13 @@ def __repr__(self): 
return "{}{},{}{}".format(prefix, s, sep, dtype_str) def item(self): - if dtypes.issubdtype(self.dtype, onp.complexfloating): + if dtypes.issubdtype(self.dtype, np.complexfloating): return complex(self) - elif dtypes.issubdtype(self.dtype, onp.floating): + elif dtypes.issubdtype(self.dtype, np.floating): return float(self) - elif dtypes.issubdtype(self.dtype, onp.integer): + elif dtypes.issubdtype(self.dtype, np.integer): return int(self) - elif dtypes.issubdtype(self.dtype, onp.bool_): + elif dtypes.issubdtype(self.dtype, np.bool_): return bool(self) else: raise TypeError(self.dtype) @@ -1091,7 +1091,7 @@ def __format__(self, format_spec): return format(self._value, format_spec) def __array__(self, dtype=None, context=None): - return onp.asarray(self._value, dtype=dtype) + return np.asarray(self._value, dtype=dtype) @property def __cuda_array_interface__(self): @@ -1251,10 +1251,10 @@ def _remat_translation_rule(c, axis_env, in_nodes, Conditional.""" del device, concrete # Unused. # Fake condition which always selects True branch. - rng = xops.RngUniform(xb.constant(c, onp.array(0, dtype=onp.float32)), - xb.constant(c, onp.array(1, dtype=onp.float32)), + rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)), + xb.constant(c, np.array(1, dtype=np.float32)), xc.Shape.array_shape(xc.PrimitiveType.F32, [])) - pred = xops.Lt(rng, xb.constant(c, onp.array(2, dtype=onp.float32))) + pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32))) true_op = xops.Tuple(c, in_nodes) remat_subc = xb.make_computation_builder("remat_call_subcomputation") @@ -1272,7 +1272,7 @@ def _remat_translation_rule(c, axis_env, in_nodes, def zeros(xla_shape): shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype() - zero = xb.constant(dummy_subc, onp.array(0, dtype=dtype)) + zero = xb.constant(dummy_subc, np.array(0, dtype=dtype)) return xops.Broadcast(zero, shape) out_nodes = [zeros(s) for s in out_node_shapes] dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes)) diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 9aa5b1108c3f..ddec40138de1 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -20,7 +20,7 @@ from typing import (Any, Callable, List, NamedTuple, Optional, Sequence, Union, Tuple) import warnings -import numpy as onp +import numpy as np from .. import core from .. import ad_util @@ -61,9 +61,9 @@ def _try_broadcast_shapes(shapes): # Replace 1 with 0 to avoid inconclusive comparisons for polymorphic dims: - out_shape = onp.max(onp.where(shapes == 1, 0, shapes), axis=0) - out_shape = onp.where(onp.all(shapes == 1, axis=0), 1, out_shape) - if not onp.all((shapes == out_shape) | (shapes == 1)): + out_shape = np.max(np.where(shapes == 1, 0, shapes), axis=0) + out_shape = np.where(np.all(shapes == 1, axis=0), 1, out_shape) + if not np.all((shapes == out_shape) | (shapes == 1)): return None return canonicalize_shape(out_shape) @@ -73,7 +73,7 @@ def broadcast_shapes(*shapes): if len(shapes) == 1: return shapes[0] ndim = _max(len(shape) for shape in shapes) - shapes = onp.array([(1,) * (ndim - len(shape)) + shape for shape in shapes]) + shapes = np.array([(1,) * (ndim - len(shape)) + shape for shape in shapes]) result_shape = _try_broadcast_shapes(shapes) if result_shape is None: raise ValueError("Incompatible shapes for broadcasting: {}" @@ -385,14 +385,14 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array: # type. If we passed a Python scalar directly to the bind call below, it is # cast to the default type as part of the calling convention. 
if type(operand) in dtypes.python_scalar_dtypes: - operand = onp.asarray(operand, new_dtype) + operand = np.asarray(operand, new_dtype) old_dtype = dtypes.canonicalize_dtype(_dtype(operand)) if old_dtype == new_dtype: return operand - if (dtypes.issubdtype(old_dtype, onp.complexfloating) and - not dtypes.issubdtype(new_dtype, onp.complexfloating)): + if (dtypes.issubdtype(old_dtype, np.complexfloating) and + not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" - warnings.warn(msg, onp.ComplexWarning, stacklevel=2) + warnings.warn(msg, np.ComplexWarning, stacklevel=2) return convert_element_type_p.bind( operand, new_dtype=new_dtype, old_dtype=old_dtype) @@ -546,10 +546,10 @@ def conv_general_dilated( rhs_dilation = (1,) * (rhs.ndim - 2) if isinstance(padding, str): lhs_perm, rhs_perm, _ = dnums - rhs_shape = onp.take(rhs.shape, rhs_perm)[2:] + rhs_shape = np.take(rhs.shape, rhs_perm)[2:] effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)] padding = padtype_to_pads( - onp.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, + np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, window_strides, padding) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), @@ -614,15 +614,15 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers, contract_dims_seq, batch_dims_seq = dimension_numbers contract_dims = tuple(map(lambda x: tuple(x), contract_dims_seq)) batch_dims = tuple(map(lambda x: tuple(x), batch_dims_seq)) - if not dtypes.issubdtype(lhs.dtype, onp.inexact): + if not dtypes.issubdtype(lhs.dtype, np.inexact): # TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a # sum of products instead. lhs_contract_dims, rhs_contract_dims = contract_dims lhs_batch_dims, rhs_batch_dims = batch_dims lhs_noncontract_dims = tuple(sorted( - set(range(onp.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) + set(range(np.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) rhs_noncontract_dims = tuple(sorted( - set(range(onp.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) + set(range(np.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) lhs = transpose(lhs, lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims) rhs = transpose(rhs, @@ -638,8 +638,8 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers, out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) + len(rhs_noncontract_dims)) - op_product = bitwise_and if lhs.dtype == onp.bool_ else mul - op_sum = bitwise_or if lhs.dtype == onp.bool_ else add + op_product = bitwise_and if lhs.dtype == np.bool_ else mul + op_sum = bitwise_or if lhs.dtype == np.bool_ else add return reduce(op_product(lhs, rhs), _zero(lhs), op_sum, tuple(range(out_ndim, out_ndim + len(lhs_contract_dims)))) @@ -662,8 +662,8 @@ def broadcast(operand: Array, sizes: Sequence[int]) -> Array: Returns: An array containing the result. 
""" - dims = tuple(range(len(sizes), len(sizes) + onp.ndim(operand))) - return broadcast_in_dim(operand, tuple(sizes) + onp.shape(operand), dims) + dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) + return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims) def broadcast_in_dim(operand: Array, shape: Shape, broadcast_dimensions: Sequence[int]) -> Array: @@ -673,7 +673,7 @@ def broadcast_in_dim(operand: Array, shape: Shape, """ shape = _broadcast_in_dim_shape_rule( operand, shape=shape, broadcast_dimensions=broadcast_dimensions) - if onp.ndim(operand) == len(shape) and not len(broadcast_dimensions): + if np.ndim(operand) == len(shape) and not len(broadcast_dimensions): return operand return broadcast_in_dim_p.bind( operand, shape=tuple(shape), @@ -695,9 +695,9 @@ def reshape(operand: Array, new_sizes: Shape, """ new_sizes = canonicalize_shape(new_sizes) # TODO new_sizes = tuple(new_sizes) - same_shape = onp.shape(operand) == new_sizes - same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand))) - if onp.shape(operand) and same_shape and same_dims: + same_shape = np.shape(operand) == new_sizes + same_dims = dimensions is None or tuple(dimensions) == tuple(range(np.ndim(operand))) + if np.shape(operand) and same_shape and same_dims: return operand else: return reshape_p.bind( @@ -733,8 +733,8 @@ def slice(operand: Array, start_indices: Sequence[int], `_ operator. """ - if (onp.all(onp.equal(start_indices, 0)) - and onp.all(onp.equal(limit_indices, operand.shape)) + if (np.all(np.equal(start_indices, 0)) + and np.all(np.equal(limit_indices, operand.shape)) and strides is None): return operand else: @@ -1005,7 +1005,7 @@ def scatter(operand: Array, scatter_indices:Array, updates: Array, def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: indices = concatenate([expand_dims(i, (1,)) for i in idxs], 1) - indices = indices % onp.array([src.shape[ax] for ax in axes]) + indices = indices % np.array([src.shape[ax] for ax in axes]) slice_sizes = list(src.shape) for ax in axes: slice_sizes[ax] = 1 @@ -1066,34 +1066,34 @@ def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]: dtype = _dtype(x) if (type(aval) is ConcreteArray) and aval.shape == (): if monoid_op is add: - return onp.equal(aval.val, 0) and _reduce_sum + return np.equal(aval.val, 0) and _reduce_sum if monoid_op is mul: - return onp.equal(aval.val, 1) and _reduce_prod - elif monoid_op is bitwise_or and dtype == onp.bool_: - return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_or - elif monoid_op is bitwise_and and dtype == onp.bool_: - return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_and + return np.equal(aval.val, 1) and _reduce_prod + elif monoid_op is bitwise_or and dtype == np.bool_: + return np.equal(aval.val, _get_max_identity(dtype)) and _reduce_or + elif monoid_op is bitwise_and and dtype == np.bool_: + return np.equal(aval.val, _get_min_identity(dtype)) and _reduce_and elif monoid_op is max: - return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_max + return np.equal(aval.val, _get_max_identity(dtype)) and _reduce_max elif monoid_op is min: - return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_min + return np.equal(aval.val, _get_min_identity(dtype)) and _reduce_min return None def _get_max_identity(dtype: DType) -> Array: - if dtypes.issubdtype(dtype, onp.inexact): - return onp.array(-onp.inf, dtype) - elif dtypes.issubdtype(dtype, onp.integer): - return 
onp.array(dtypes.iinfo(dtype).min, dtype) - elif dtypes.issubdtype(dtype, onp.bool_): - return onp.array(False, onp.bool_) + if dtypes.issubdtype(dtype, np.inexact): + return np.array(-np.inf, dtype) + elif dtypes.issubdtype(dtype, np.integer): + return np.array(dtypes.iinfo(dtype).min, dtype) + elif dtypes.issubdtype(dtype, np.bool_): + return np.array(False, np.bool_) def _get_min_identity(dtype: DType) -> Array: - if dtypes.issubdtype(dtype, onp.inexact): - return onp.array(onp.inf, dtype) - elif dtypes.issubdtype(dtype, onp.integer): - return onp.array(dtypes.iinfo(dtype).max, dtype) - elif dtypes.issubdtype(dtype, onp.bool_): - return onp.array(True, onp.bool_) + if dtypes.issubdtype(dtype, np.inexact): + return np.array(np.inf, dtype) + elif dtypes.issubdtype(dtype, np.integer): + return np.array(dtypes.iinfo(dtype).max, dtype) + elif dtypes.issubdtype(dtype, np.bool_): + return np.array(True, np.bool_) def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array: return reduce_sum_p.bind(operand, axes=tuple(axes)) @@ -1301,9 +1301,9 @@ def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Arra will be cast to `dtype`. """ shape = canonicalize_shape(shape) - if onp.shape(fill_value): + if np.shape(fill_value): msg = "full must be called with scalar fill_value, got fill_value.shape {}." - raise TypeError(msg.format(onp.shape(fill_value))) + raise TypeError(msg.format(np.shape(fill_value))) dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) # TODO(mattjj): remove device_put when dtype conversion produces DeviceArray fill_value = xla.device_put_p.bind(convert_element_type(fill_value, dtype)) @@ -1348,7 +1348,7 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array: shape = tuple(map(int, shape)) axes = tuple(map(int, axes)) dtype = dtypes.canonicalize_dtype(dtype) - base_shape = tuple(onp.take(shape, axes)) + base_shape = tuple(np.take(shape, axes)) lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes) aval = ShapedArray(shape, dtype) return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) @@ -1460,7 +1460,7 @@ def _conv_transpose_padding(k, s, padding): if s > k - 1: pad_a = k - 1 else: - pad_a = int(onp.ceil(pad_len / 2)) + pad_a = int(np.ceil(pad_len / 2)) elif padding == 'VALID': pad_len = k + s - 2 + _max(k - s, 0) pad_a = k - 1 @@ -1473,7 +1473,7 @@ def _conv_transpose_padding(k, s, padding): def _flip_axes(x, axes): """Flip ndarray 'x' along each axis specified in axes tuple.""" for axis in axes: - x = onp.flip(x, axis) + x = np.flip(x, axis) return x @@ -1529,7 +1529,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], else: raise ValueError('No 4+ dimensional dimension_number defaults.') dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) - k_shape = onp.take(rhs.shape, dn.rhs_spec) + k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. 
pads: Union[str, Sequence[Tuple[int, int]]] @@ -1543,8 +1543,8 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], pads = padding if transpose_kernel: # flip spatial dims and swap input / output channel axes - rhs = _flip_axes(rhs, onp.array(dn.rhs_spec)[2:]) - rhs = onp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn, precision=precision) @@ -1563,7 +1563,7 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None, An ndarray with the same shape as `x` with its entries set equal to `fill_value`, similar to the output of np.full. """ - fill_shape = onp.shape(x) if shape is None else canonicalize_shape(shape) + fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape) fill_value = tie_in(x, fill_value) return full(fill_shape, fill_value, dtype or _dtype(x)) @@ -1688,9 +1688,9 @@ def _upcast_fp16_for_computation(f): @functools.wraps(f) def f_wrapped(x): dtype = _dtype(x) - if dtype == onp.float16 or dtype == dtypes.bfloat16: + if dtype == np.float16 or dtype == dtypes.bfloat16: return convert_element_type( - f(convert_element_type(x, onp.float32)), dtype) + f(convert_element_type(x, np.float32)), dtype) return f(x) return f_wrapped @@ -1714,7 +1714,7 @@ def acos(x: Array) -> Array: ne(x, _const(x, -1.0)), mul(_const(x, 2), atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))), - full_like(x, onp.pi)) + full_like(x, np.pi)) def atan(x: Array) -> Array: r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.""" @@ -1773,7 +1773,7 @@ def zeros_like_array(x): _input_dtype = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype) _fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) -_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype +_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None): prim = Primitive(name) @@ -1805,7 +1805,7 @@ def standard_translate(name, c, *args, **kwargs): def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes): msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.' - typename = str(onp.dtype(aval.dtype).name) + typename = str(np.dtype(aval.dtype).name) accepted_typenames = (t.__name__ for t in accepted_dtypes) raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames))) return result_dtype(aval.dtype) @@ -1828,7 +1828,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs): if not any(dtypes.issubdtype(aval_dtype, t) for t in types): msg = ('{} does not accept dtype {} at position {}. 
' 'Accepted dtypes at position {} are subtypes of {}.') - typename = str(onp.dtype(aval_dtype).name) + typename = str(np.dtype(aval_dtype).name) typenames = ', '.join(t.__name__ for t in types) raise TypeError(msg.format(name, typename, i, i, typenames)) _check_same_dtypes(name, False, *aval_dtypes) @@ -1836,7 +1836,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs): def _broadcasting_shape_rule(name, *avals): - shapes = onp.array([aval.shape for aval in avals if aval.shape]) + shapes = np.array([aval.shape for aval in avals if aval.shape]) if not shapes.size: return () if len({len(shape) for shape in shapes}) != 1: @@ -1889,32 +1889,32 @@ def _brcast(x, *others): # Requires shape info during jvp tracing, which isn't strictly necessary. # We don't need full numpy broadcasting, but otherwise the logic is the same # so we reuse the broadcast_shapes function after filtering out scalars. - shapes = tuple(filter(None, map(onp.shape, (x,) + others))) + shapes = tuple(filter(None, map(np.shape, (x,) + others))) shape = shapes and broadcast_shapes(*shapes) - if onp.shape(x) != shape: + if np.shape(x) != shape: return _brcast_to(x, shape) else: return x def _brcast_to(x, shape): - x_shape = onp.shape(x) + x_shape = np.shape(x) assert x_shape != shape if x_shape: assert len(x_shape) == len(shape) - broadcast_dimensions, = onp.where(onp.equal(x_shape, shape)) - squeezed_dimensions, = onp.where(onp.not_equal(x_shape, shape)) + broadcast_dimensions, = np.where(np.equal(x_shape, shape)) + squeezed_dimensions, = np.where(np.not_equal(x_shape, shape)) squeezed = squeeze(x, squeezed_dimensions) return broadcast_in_dim(squeezed, shape, broadcast_dimensions) else: return broadcast(x, shape) -_float = {onp.floating} -_complex = {onp.complexfloating} -_complex_elem_types = {onp.float32, onp.float64} -_int = {onp.integer} -_bool = {onp.bool_} +_float = {np.floating} +_complex = {np.complexfloating} +_complex_elem_types = {np.float32, np.float64} +_int = {np.integer} +_bool = {np.bool_} _num = _int | _float | _complex _any = _int | _float | _complex | _bool @@ -1926,11 +1926,11 @@ def _brcast_to(x, shape): def _sign_translation_rule(c, x): shape = c.get_shape(x) dtype = shape.numpy_dtype() - if dtypes.issubdtype(dtype, onp.unsignedinteger): - zero = xb.constant(c, onp.array(0, dtype=dtype)) + if dtypes.issubdtype(dtype, np.unsignedinteger): + zero = xb.constant(c, np.array(0, dtype=dtype)) dims = c.get_shape(x).dimensions() return xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, dims), - xops.Broadcast(xb.constant(c, onp.array(1, dtype=dtype)), + xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)), dims)) return xops.Sign(x) @@ -1950,7 +1950,7 @@ def _sign_translation_rule(c, x): round_p = standard_unop(_float, 'round') ad.defjvp_zero(round_p) -is_finite_p = unop(_fixed_dtype(onp.bool_), _float, 'is_finite') +is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite') ad.defjvp_zero(is_finite_p) exp_p = standard_unop(_float | _complex, 'exp') @@ -2070,24 +2070,24 @@ def _bessel_i1e_jvp(g, y, x): ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp) erf_p = standard_unop(_float, 'erf') -ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / onp.sqrt(onp.pi)), +ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)), mul(g, exp(neg(square(x)))))) erfc_p = standard_unop(_float, 'erfc') -ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / onp.sqrt(onp.pi)), +ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. 
/ np.sqrt(np.pi)), mul(neg(g), exp(neg(square(x)))))) erf_inv_p = standard_unop(_float, 'erf_inv') -ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, onp.sqrt(onp.pi) / 2.), +ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), mul(g, exp(square(ans))))) real_p = unop(_complex_basetype, _complex, 'real') -ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), _dtype(t)))]) +ad.deflinear(real_p, lambda t: [complex(t, np.zeros((), _dtype(t)))]) imag_p = unop(_complex_basetype, _complex, 'imag') ad.defjvp(imag_p, lambda g, _: real(mul(_const(g, -1j), g))) -_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype +_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types], 'complex') ad.deflinear(complex_p, lambda t: [real(t), imag(neg(t))]) @@ -2096,7 +2096,7 @@ def _bessel_i1e_jvp(g, y, x): def _conj_transpose_rule(t, x, *, input_dtype): assert ad.is_undefined_primal(x) - if dtypes.issubdtype(input_dtype, onp.complexfloating): + if dtypes.issubdtype(input_dtype, np.complexfloating): return [conj(t)] else: return [real(t)] @@ -2139,7 +2139,7 @@ def _pow_jvp_rhs(g, ans, x, y): def _integer_pow_dtype_rule(x, *, y): dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x) - if y < 0 and dtypes.issubdtype(dtype, onp.integer): + if y < 0 and dtypes.issubdtype(dtype, np.integer): raise TypeError("Integers cannot be raised to negative powers, got " f"integer_pow({x}, {y})") return dtype @@ -2147,7 +2147,7 @@ def _integer_pow_dtype_rule(x, *, y): def _integer_pow_translation_rule(c, x, *, y): if y == 0: shape = c.get_shape(x) - return xb.constant(c, onp.array(1, dtype=shape.numpy_dtype())) + return xb.constant(c, np.array(1, dtype=shape.numpy_dtype())) is_reciprocal = y < 0 if is_reciprocal: y = -y @@ -2252,7 +2252,7 @@ def _broadcasting_select(c, which, x, y): def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): dtype = c.get_shape(x).numpy_dtype() - if dtypes.issubdtype(dtype, onp.complexfloating): + if dtypes.issubdtype(dtype, np.complexfloating): rx = xops.Real(x) ry = xops.Real(y) return _broadcasting_select( @@ -2283,22 +2283,22 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical') ad.defjvp_zero(shift_right_logical_p) -eq_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'eq') +eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq') ad.defjvp_zero(eq_p) -ne_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'ne') +ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne') ad.defjvp_zero(ne_p) -ge_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'ge') +ge_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ge') ad.defjvp_zero(ge_p) -gt_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'gt') +gt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'gt') ad.defjvp_zero(gt_p) -le_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'le') +le_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'le') ad.defjvp_zero(le_p) -lt_p = naryop(_fixed_dtype(onp.bool_), [_any, _any], 'lt') +lt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'lt') ad.defjvp_zero(lt_p) @@ -2309,8 +2309,8 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, old_dtype): return new_dtype def _convert_element_type_translation_rule(c, operand, *, new_dtype, old_dtype): - if (dtypes.issubdtype(old_dtype, onp.complexfloating) and - 
not dtypes.issubdtype(new_dtype, onp.complexfloating)): + if (dtypes.issubdtype(old_dtype, np.complexfloating) and + not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = xops.Real(operand) new_etype = xla_client.dtype_to_etype(new_dtype) return xops.ConvertElementType(operand, new_element_type=new_etype) @@ -2395,11 +2395,11 @@ def _conv_general_dilated_shape_rule( raise ValueError(msg.format(batch_group_count, feature_group_count)) lhs_perm, rhs_perm, out_perm = dimension_numbers - lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation) - rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation) + lhs_trans = _dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation) + rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation) out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding, batch_group_count) - return tuple(onp.take(out_trans, onp.argsort(out_perm))) + return tuple(np.take(out_trans, np.argsort(out_perm))) def _conv_general_dilated_dtype_rule( lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, @@ -2457,8 +2457,8 @@ def _conv_general_dilated_transpose_lhs( feature_group_count = batch_group_count trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec) padding = _conv_general_vjp_lhs_padding( - onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims), - window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation, + np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims), + window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation, rhs_dilation) revd_weights = rev(rhs, rhs_sdims) out = conv_general_dilated( @@ -2477,7 +2477,7 @@ def _conv_general_dilated_transpose_rhs( dimension_numbers: ConvDimensionNumbers, feature_group_count: int, batch_group_count: int, lhs_shape, rhs_shape, precision): assert type(dimension_numbers) is ConvDimensionNumbers - if onp.size(g) == 0: + if np.size(g) == 0: # Avoids forming degenerate convolutions where the RHS has spatial size 0. return ad_util.Zero lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers) @@ -2491,8 +2491,8 @@ def _conv_general_dilated_transpose_rhs( feature_group_count = 1 trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans) padding = _conv_general_vjp_rhs_padding( - onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims), - window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation, + np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims), + window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation, rhs_dilation) return conv_general_dilated( lhs, g, window_strides=rhs_dilation, padding=padding, @@ -2514,7 +2514,7 @@ def _conv_general_dilated_translation_rule( x, y, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision_config=precision_config) - if expand_complex_convolutions and onp.issubdtype(dtype, onp.complexfloating): + if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating): # We use a trick for complex multiplication due to Gauss which uses three # multiplications and five additions; instead of the naive method of four # multiplications and two additions. 
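Editor's aside (not part of the patch): the "trick ... due to Gauss" in the
comment above computes the complex product (a + b*i) * (c + d*i) with three
real multiplications and five additions, instead of the naive four and two.
A minimal NumPy sketch for reference; the function name is ours, not JAX's:

import numpy as np

def gauss_complex_mul(a, b, c, d):
  # Three multiplications (k1, k2, k3) and five additions/subtractions.
  k1 = c * (a + b)
  k2 = a * (d - c)
  k3 = b * (c + d)
  return k1 - k3, k1 + k2  # == (a*c - b*d, a*d + b*c)

z = (1 + 2j) * (3 + 4j)
assert np.allclose(gauss_complex_mul(1., 2., 3., 4.), (z.real, z.imag))

In the translation rule, the three multiplications become three real
convolutions of the operands' real and imaginary parts.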
@@ -2621,7 +2621,7 @@ def _masked(padded_value, logical_shape, dimensions, value=0): if len(dimensions) == 0: return padded_value - masks = [broadcasted_iota(onp.int32, padded_value.shape, d) < logical_shape[d] + masks = [broadcasted_iota(np.int32, padded_value.shape, d) < logical_shape[d] for d in dimensions] mask_intersection = masks[0] for mask in masks[1:]: @@ -2636,8 +2636,8 @@ def _conv_general_dilated_masking_rule( logical_lhs_shape, logical_rhs_shape = logical_shapes o, i, *window_dimensions = dimension_numbers.rhs_spec - assert (onp.all(onp.take(rhs.shape, window_dimensions) - == onp.take(logical_rhs_shape, window_dimensions))), \ + assert (np.all(np.take(rhs.shape, window_dimensions) + == np.take(logical_rhs_shape, window_dimensions))), \ "Conv filter masking not yet implemented." n, c, *padded_dimensions = dimension_numbers.lhs_spec @@ -2675,7 +2675,7 @@ def _conv_general_dilated_masking_rule( def _reshape_axis_into(src, dst, x): perm = [i for i in range(x.ndim) if i != src] perm.insert(dst, src) - new_shape = list(onp.delete(x.shape, src)) + new_shape = list(np.delete(x.shape, src)) new_shape[dst] *= x.shape[src] return reshape(x, new_shape, perm) @@ -2696,14 +2696,14 @@ def _precision_config(precision): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - if not all(onp.all(onp.greater_equal(d, 0)) and onp.all(onp.less(d, lhs.ndim)) + if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): msg = ("dot_general requires lhs dimension numbers to be nonnegative and " "less than the number of axes of the lhs value, got " f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} " f"for lhs of rank {lhs.ndim}") raise TypeError(msg) - if not all(onp.all(onp.greater_equal(d, 0)) and onp.all(onp.less(d, rhs.ndim)) + if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, rhs.ndim)) for d in (rhs_contracting, rhs_batch)): msg = ("dot_general requires rhs dimension numbers to be nonnegative and " "less than the number of axes of the rhs value, got " @@ -2714,13 +2714,13 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision): msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch " "dimensions, got lhs_batch {} and rhs_batch {}.") raise TypeError(msg.format(lhs_batch, rhs_batch)) - if not onp.all(onp.equal(lhs_batch, rhs_batch)): + if not np.all(np.equal(lhs_batch, rhs_batch)): msg = ("dot_general requires same lhs and rhs batch dimension numbers, " "got {} and {}.") raise TypeError(msg.format(lhs_batch, rhs_batch)) - lhs_batch_shape = onp.take(lhs.shape, lhs_batch) - rhs_batch_shape = onp.take(rhs.shape, rhs_batch) - if not onp.all(onp.equal(lhs_batch_shape, rhs_batch_shape)): + lhs_batch_shape = np.take(lhs.shape, lhs_batch) + rhs_batch_shape = np.take(rhs.shape, rhs_batch) + if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)): msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " "to have the same shape, got {} and {}.") raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape)) @@ -2732,18 +2732,18 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision): msg = ("dot_general requires rhs batch dimensions to precede contracting " "and non-contracting dimensions, got rhs_batch {}.") raise TypeError(msg.format(rhs_batch)) - lhs_contracting_shape = onp.take(lhs.shape, lhs_contracting) - rhs_contracting_shape = onp.take(rhs.shape, 
rhs_contracting) - if not onp.all(onp.equal(lhs_contracting_shape, rhs_contracting_shape)): + lhs_contracting_shape = np.take(lhs.shape, lhs_contracting) + rhs_contracting_shape = np.take(rhs.shape, rhs_contracting) + if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)): msg = ("dot_general requires contracting dimensions to have the same " "shape, got {} and {}.") raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape)) - batch_shape = tuple(onp.take(lhs.shape, lhs_batch)) + batch_shape = tuple(np.take(lhs.shape, lhs_batch)) lhs_contract_or_batch = tuple(lhs_contracting) + tuple(lhs_batch) - lhs_tensored_shape = tuple(onp.delete(lhs.shape, lhs_contract_or_batch)) + lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch)) rhs_contract_or_batch = tuple(rhs_contracting) + tuple(rhs_batch) - rhs_tensored_shape = tuple(onp.delete(rhs.shape, rhs_contract_or_batch)) + rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch)) return batch_shape + lhs_tensored_shape + rhs_tensored_shape @@ -2762,8 +2762,8 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision, else: ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) - x_contract_sorted_by_y = list(onp.take(x_contract, onp.argsort(y_contract))) - out_axes = onp.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) return transpose(dot_general(g, y, dims, precision=precision), tuple(out_axes)) @@ -2791,18 +2791,18 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs = batching.moveaxis(lhs, lbd, 0) if rbd != 0: rhs = batching.moveaxis(rhs, rbd, 0) - lhs_batch = (0,) + tuple(onp.add(1, lhs_batch)) - rhs_batch = (0,) + tuple(onp.add(1, rhs_batch)) - lhs_contract = tuple(onp.add(1, lhs_contract)) - rhs_contract = tuple(onp.add(1, rhs_contract)) + lhs_batch = (0,) + tuple(np.add(1, lhs_batch)) + rhs_batch = (0,) + tuple(np.add(1, rhs_batch)) + lhs_contract = tuple(np.add(1, lhs_contract)) + rhs_contract = tuple(np.add(1, rhs_contract)) result_batch_dim = 0 else: # adding a tensor product dimension if lbd is not None: - if lhs_batch == () or lbd > onp.max(lhs_batch): + if lhs_batch == () or lbd > np.max(lhs_batch): # can avoid transposes - bump_lhs_contract = onp.greater_equal(lhs_contract, lbd) - lhs_contract = tuple(onp.add(lhs_contract, bump_lhs_contract)) + bump_lhs_contract = np.greater_equal(lhs_contract, lbd) + lhs_contract = tuple(np.add(lhs_contract, bump_lhs_contract)) result_batch_dim = lbd - len(lhs_contract) + sum(bump_lhs_contract) else: # move the new dimension to the end of lhs to avoid changing batch dims @@ -2810,10 +2810,10 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, # lhs tensor product dims in result come after batch dims result_batch_dim = lhs.ndim - len(lhs_contract) - 1 else: - if rhs_batch == () or rbd > onp.max(rhs_batch): + if rhs_batch == () or rbd > np.max(rhs_batch): # can avoid transposes - bump_rhs_contract = onp.greater_equal(rhs_contract, rbd) - rhs_contract = tuple(onp.add(rhs_contract, bump_rhs_contract)) + bump_rhs_contract = np.greater_equal(rhs_contract, rbd) + rhs_contract = tuple(np.add(rhs_contract, bump_rhs_contract)) result_batch_dim = (rbd + (lhs.ndim - len(lhs_contract) - len(lhs_batch)) - (len(rhs_contract) - sum(bump_rhs_contract))) else: @@ -2869,8 +2869,8 @@ 
def _broadcast_batch_rule(batched_args, batch_dims, *, sizes): batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions): - if type(operand) is xla.DeviceArray and onp.all( - onp.equal(operand.shape, onp.take(shape, broadcast_dimensions))): + if type(operand) is xla.DeviceArray and np.all( + np.equal(operand.shape, np.take(shape, broadcast_dimensions))): shape = _broadcast_in_dim_shape_rule( operand, shape=shape, broadcast_dimensions=broadcast_dimensions) aval = ShapedArray(shape, _dtype(operand)) @@ -2884,7 +2884,7 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): _check_shapelike('broadcast_in_dim', 'shape', shape) _check_shapelike('broadcast_in_dim', 'broadcast_dimensions', broadcast_dimensions) - operand_ndim = onp.ndim(operand) + operand_ndim = np.ndim(operand) if operand_ndim != len(broadcast_dimensions): msg = ('broadcast_in_dim broadcast_dimensions must have length equal to ' 'operand ndim; got broadcast_dimensions {} for operand ndim {}.') @@ -2913,7 +2913,7 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): return shape def _broadcast_in_dim_transpose_rule(t, *, shape, broadcast_dimensions): - axes = tuple(onp.delete(range(len(shape)), broadcast_dimensions)) + axes = tuple(np.delete(range(len(shape)), broadcast_dimensions)) return [_reduce_sum(t, axes)] def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape, @@ -2922,7 +2922,7 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape, bdim, = batch_dims new_operand = batching.moveaxis(operand, bdim, 0) new_shape = (operand.shape[bdim],) + shape - new_broadcast_dimensions = (0,) + tuple(onp.add(1, broadcast_dimensions)) + new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions)) return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0 @@ -2970,11 +2970,11 @@ def _concatenate_shape_rule(*operands, **kwargs): if len(set(operand.ndim for operand in operands)) != 1: msg = "Cannot concatenate arrays with different ranks, got {}." raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands))) - shapes = onp.array([operand.shape for operand in operands]) + shapes = np.array([operand.shape for operand in operands]) if not 0 <= dimension < shapes.shape[1]: msg = "concatenate dimension out of bounds: dimension {} for shapes {}." 
raise TypeError(msg.format(dimension, ", ".join(map(str, shapes)))) - if not onp.all(onp.delete(shapes[0] == shapes, dimension, axis=1)): + if not np.all(np.delete(shapes[0] == shapes, dimension, axis=1)): msg = ("Cannot concatenate arrays with shapes that differ in dimensions " "other than the one being concatenated: dimension {} for shapes {}.") raise TypeError(msg.format(dimension, ", ".join(map(str, shapes)))) @@ -2997,10 +2997,10 @@ def _concatenate_transpose_rule(t, *operands, dimension): if type(t) is ad_util.Zero: return ad_util.Zero else: - limit_points = onp.cumsum([shape[dimension] for shape in operand_shapes]) - starts = onp.zeros((len(operands), t.ndim), dtype=int) + limit_points = np.cumsum([shape[dimension] for shape in operand_shapes]) + starts = np.zeros((len(operands), t.ndim), dtype=int) starts[1:, dimension] = limit_points[:-1] - limits = onp.tile(t.shape, (len(operands), 1)) + limits = np.tile(t.shape, (len(operands), 1)) limits[:, dimension] = limit_points return [slice(t, start, limit) if ad.is_undefined_primal(o) else None @@ -3034,9 +3034,9 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config): def _pad_shape_rule(operand, padding_value, *, padding_config): lo, hi, interior = zip(*padding_config) - out_shape = onp.add( - onp.add(onp.add(lo, hi), operand.shape), - onp.maximum(0, onp.multiply(interior, onp.subtract(operand.shape, 1)))) + out_shape = np.add( + np.add(np.add(lo, hi), operand.shape), + np.maximum(0, np.multiply(interior, np.subtract(operand.shape, 1)))) return tuple(out_shape) def _pad_transpose(t, operand, padding_value, *, padding_config): @@ -3047,10 +3047,10 @@ def _pad_transpose(t, operand, padding_value, *, padding_config): total = lambda x: _reduce_sum(x, list(range(t.ndim))) def t_op(): - unpad_config = safe_zip(onp.negative(lo), onp.negative(hi), - onp.zeros_like(interior)) - unpadded = pad(t, onp.array(0., t.dtype), unpad_config) - return slice(unpadded, onp.zeros_like(lo), unpadded.shape, onp.add(interior, 1)) + unpad_config = safe_zip(np.negative(lo), np.negative(hi), + np.zeros_like(interior)) + unpadded = pad(t, np.array(0., t.dtype), unpad_config) + return slice(unpadded, np.zeros_like(lo), unpadded.shape, np.add(interior, 1)) t_operand = t_op() if ad.is_undefined_primal(operand) else None t_padv = sub(total(t), total(t_operand)) if ad.is_undefined_primal(padding_value) else None @@ -3103,7 +3103,7 @@ def _pad_masking_rule(padded_vals, logical_shapes, padding_config): def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" - ndim = onp.ndim(array) + ndim = np.ndim(array) dimensions = tuple(sorted(_canonicalize_axis(i, ndim) for i in dimensions)) if not dimensions: return array @@ -3113,7 +3113,7 @@ def _squeeze_dtype_rule(operand, *, dimensions): return operand.dtype def _squeeze_shape_rule(operand, *, dimensions): - return _compute_squeeze_shape(onp.shape(operand), dimensions) + return _compute_squeeze_shape(np.shape(operand), dimensions) def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) @@ -3139,7 +3139,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): operand, = batched_args bdim, = batch_dims operand = batching.moveaxis(operand, bdim, 0) - dimensions = tuple(onp.add(1, dimensions)) + dimensions = tuple(np.add(1, dimensions)) return squeeze(operand, dimensions=dimensions), 0 squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, @@ -3150,9 +3150,9 @@ def _squeeze_batch_rule(batched_args, 
batch_dims, *, dimensions):

 def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
   """Insert any number of size 1 dimensions into an array."""
-  ndim_out = onp.ndim(array) + len(dimensions)
+  ndim_out = np.ndim(array) + len(dimensions)
   dims_set = frozenset(_canonicalize_axis(i, ndim_out) for i in dimensions)
-  result_shape = list(onp.shape(array))
+  result_shape = list(np.shape(array))
   for i in sorted(dims_set):
     result_shape.insert(i, 1)
   broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
@@ -3161,7 +3161,7 @@ def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:

 # We have a nonstandard reshape impl so that we can be lazy about data movement.
 def _reshape_impl(operand, *, new_sizes, dimensions):
-  old_sizes = onp.shape(operand)
+  old_sizes = np.shape(operand)
   if type(operand) is xla.DeviceArray and dimensions is None:
     bcast_dims = _is_singleton_reshape(old_sizes, new_sizes)
     if bcast_dims is not None:
@@ -3251,17 +3251,17 @@ def _is_axis_split(s1, s2):
   return _is_axis_merge(s2, s1)

 def _reshape_shape_rule(operand, *, new_sizes, dimensions):
-  if not onp.all(onp.greater_equal(new_sizes, 0)):
+  if not np.all(np.greater_equal(new_sizes, 0)):
     msg = 'reshape new_sizes must all be nonnegative, got {}.'
     raise TypeError(msg.format(new_sizes))
-  if prod(onp.shape(operand)) != prod(new_sizes):
+  if prod(np.shape(operand)) != prod(new_sizes):
     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-    raise TypeError(msg.format(new_sizes, onp.shape(operand)))
+    raise TypeError(msg.format(new_sizes, np.shape(operand)))
   if dimensions is not None:
-    if set(dimensions) != set(range(onp.ndim(operand))):
+    if set(dimensions) != set(range(np.ndim(operand))):
       msg = ('reshape dimensions must be a permutation of operand dimensions, '
              'got dimensions {} for shape {}.')
-      raise TypeError(msg.format(dimensions, onp.shape(operand)))
+      raise TypeError(msg.format(dimensions, np.shape(operand)))
   return tuple(new_sizes)

 def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
@@ -3278,15 +3278,15 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
   if dimensions is None:
     return [reshape(t, operand.aval.shape)]
   else:
-    return [transpose(reshape(t, onp.take(operand.aval.shape, dimensions)),
-                      onp.argsort(dimensions))]
+    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)),
+                      np.argsort(dimensions))]

 def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
   operand, = batched_args
   bdim, = batch_dims
   operand = batching.moveaxis(operand, bdim, 0)
   if dimensions is not None:
-    dimensions = (0,) + tuple(onp.add(1, dimensions))
+    dimensions = (0,) + tuple(np.add(1, dimensions))
   return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0

 reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
@@ -3327,14 +3327,14 @@ def _transpose_impl(operand, *, permutation):
   return xla.apply_primitive(transpose_p, operand, permutation=permutation)

 def _transpose_shape_rule(operand, *, permutation):
-  if not isinstance(permutation, (tuple, list, onp.ndarray)):
+  if not isinstance(permutation, (tuple, list, np.ndarray)):
     msg = "transpose permutation must be a tuple/list/ndarray, got {}."
raise TypeError(msg.format(type(permutation))) if tuple(sorted(permutation)) != tuple(range(operand.ndim)): msg = ("transpose permutation isn't a permutation of operand dimensions, " "got permutation {} for operand shape {}.") raise TypeError(msg.format(permutation, operand.shape)) - return tuple(onp.take(operand.shape, permutation)) + return tuple(np.take(operand.shape, permutation)) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args @@ -3349,7 +3349,7 @@ def _transpose_masking_rule(padded_vals, logical_shapes, permutation): 'transpose') transpose_p.def_impl(_transpose_impl) ad.deflinear(transpose_p, - lambda t, permutation: [transpose(t, onp.argsort(permutation))]) + lambda t, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule masking.masking_rules[transpose_p] = _transpose_masking_rule @@ -3366,7 +3366,7 @@ def _select_shape_rule(pred, on_true, on_false): def _select_dtype_rule(pred, on_true, on_false): _check_same_dtypes("select", False, on_true.dtype, on_false.dtype) - if not dtypes.issubdtype(pred.dtype, onp.bool_): + if not dtypes.issubdtype(pred.dtype, np.bool_): msg = "select pred must be boolean type, got {}." raise TypeError(msg.format(pred.dtype)) return on_true.dtype @@ -3389,31 +3389,31 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): # avoid transposes and some broadcasts in special cases if pred_bdim == ot_bdim == of_bdim: - if onp.shape(pred) == onp.shape(on_true): + if np.shape(pred) == np.shape(on_true): return select(pred, on_true, on_false), pred_bdim else: # vmapped function had a scalar pred with nonscalar args - assert onp.ndim(pred) == 1 + assert np.ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim]) return select(pred, on_true, on_false), pred_bdim - elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: + elif np.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: if ot_bdim == of_bdim: return select(pred, on_true, on_false), ot_bdim - elif onp.shape(on_true) == onp.shape(on_false): + elif np.shape(on_true) == np.shape(on_false): on_false = batching.moveaxis(on_false, of_bdim, ot_bdim) return select(pred, on_true, on_false), ot_bdim - pred = batching.bdim_at_front(pred, pred_bdim, size) if onp.shape(pred) else pred - if not onp.shape(on_true) == onp.shape(on_false) == (): + pred = batching.bdim_at_front(pred, pred_bdim, size) if np.shape(pred) else pred + if not np.shape(on_true) == np.shape(on_false) == (): on_true = batching.bdim_at_front(on_true, ot_bdim, size) on_false = batching.bdim_at_front(on_false, of_bdim, size) - assert onp.shape(on_true) == onp.shape(on_false) - if 0 < onp.ndim(pred) < onp.ndim(on_true): + assert np.shape(on_true) == np.shape(on_false) + if 0 < np.ndim(pred) < np.ndim(on_true): # vmapped function had a scalar pred with nonscalar args - assert onp.ndim(pred) == 1 + assert np.ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [0]) - if onp.ndim(pred) > onp.ndim(on_true): - assert onp.ndim(on_true) == 0 + if np.ndim(pred) > np.ndim(on_true): + assert np.ndim(on_true) == 0 on_true = broadcast(on_true, pred.shape) on_false = broadcast(on_false, pred.shape) return select(pred, on_true, on_false), 0 @@ -3421,8 +3421,8 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): def _select_masking_rule(padded_vals, logical_shapes): pred_shape, true_shape, false_shape = [ masking.padded_shape_as_value(val.shape) for val in padded_vals] - 
assert onp.array_equal(pred_shape, true_shape)
-  assert onp.array_equal(pred_shape, false_shape)
+  assert np.array_equal(pred_shape, true_shape)
+  assert np.array_equal(pred_shape, false_shape)
   return select(*padded_vals)

 select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
@@ -3448,33 +3448,33 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
     raise TypeError(msg.format(start_indices, limit_indices))
   if (not masking.is_polymorphic(limit_indices) and
       not masking.is_polymorphic(operand.shape) and
-      not onp.all(onp.less_equal(limit_indices, operand.shape))):
+      not np.all(np.less_equal(limit_indices, operand.shape))):
     msg = ("slice limit_indices must be less than or equal to operand shape, "
            "got limit_indices {} for operand shape {}.")
     raise TypeError(msg.format(limit_indices, operand.shape))
-  if not onp.all(onp.greater_equal(start_indices, 0)):
+  if not np.all(np.greater_equal(start_indices, 0)):
     msg = ("slice start_indices must be greater than or equal to zero, "
            "got start_indices of {}.")
     raise TypeError(msg.format(start_indices))
   if (not masking.is_polymorphic(limit_indices) and
-      not onp.all(onp.greater_equal(limit_indices, start_indices))):
+      not np.all(np.greater_equal(limit_indices, start_indices))):
     msg = ("slice limit_indices must be greater than or equal to start_indices,"
            " got start_indices {} and limit_indices {}.")
     raise TypeError(msg.format(start_indices, limit_indices))
   if strides is None:
-    strides = onp.ones(operand.ndim, onp.int32)
+    strides = np.ones(operand.ndim, np.int32)
   else:
     _check_shapelike("slice", "strides", strides)
     if len(strides) != operand.ndim:
       msg = ("slice strides must have length equal to the number of dimensions "
              "of the operand, got strides {} for operand shape {}.")
       raise TypeError(msg.format(strides, operand.shape))
-    if not onp.all(onp.greater(strides, 0)):
+    if not np.all(np.greater(strides, 0)):
       msg = "slice strides must be positive, got {}"
       raise TypeError(msg.format(strides))

-  result_shape = onp.floor_divide(
-    onp.add(onp.subtract(limit_indices, start_indices), strides) - 1, strides)
+  result_shape = np.floor_divide(
+    np.add(np.subtract(limit_indices, start_indices), strides) - 1, strides)
   return tuple(result_shape)

 def _slice_translation_rule(c, operand, *, start_indices, limit_indices,
@@ -3485,14 +3485,14 @@ def _slice_translation_rule(c, operand, *, start_indices, limit_indices,

 def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   assert ad.is_undefined_primal(operand)
   operand_shape = operand.aval.shape
-  if strides is None or onp.all(onp.equal(strides, 1)):
-    pads = zip(start_indices, onp.subtract(operand_shape, limit_indices),
+  if strides is None or np.all(np.equal(strides, 1)):
+    pads = zip(start_indices, np.subtract(operand_shape, limit_indices),
               (0,) * len(start_indices))
   else:
-    real_limits = onp.add(onp.add(start_indices, 1),
-                          onp.multiply(onp.subtract(t.shape, 1), strides))
-    pads = safe_zip(start_indices, onp.subtract(operand_shape, real_limits),
-                    onp.subtract(strides, 1))
+    real_limits = np.add(np.add(start_indices, 1),
+                         np.multiply(np.subtract(t.shape, 1), strides))
+    pads = safe_zip(start_indices, np.subtract(operand_shape, real_limits),
+                    np.subtract(strides, 1))
   result = pad(t, _const(t, 0), pads)
   assert result.shape == operand_shape
   return [result]

@@ -3541,11 +3541,11 @@ def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
     msg = ("dynamic_slice slice_sizes must have the same length as "
            "start_indices, got start_indices 
length {} and slice_sizes {}.") raise TypeError(msg.format(len(start_indices), slice_sizes)) - if not onp.all(onp.less_equal(slice_sizes, operand.shape)): + if not np.all(np.less_equal(slice_sizes, operand.shape)): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) - if not onp.all(onp.greater_equal(slice_sizes, 0)): + if not np.all(np.greater_equal(slice_sizes, 0)): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) @@ -3553,7 +3553,7 @@ def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes): if any(i.dtype != start_indices[0].dtype or - not dtypes.issubdtype(i.dtype, onp.integer) for i in start_indices): + not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices): msg = ("index arguments to dynamic_slice must be integers of the same " "type, got: {}") raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices))) @@ -3595,7 +3595,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): operand, *start_indices = batched_args operand_bd, *start_idx_bds = batch_dims operand_shape = (operand.shape if operand_bd is batching.not_mapped - else tuple(onp.delete(operand.shape, operand_bd))) + else tuple(np.delete(operand.shape, operand_bd))) dims = tuple(range(len(operand_shape))) dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), start_index_map=dims) @@ -3622,7 +3622,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): msg = ("dynamic_update_slice start_indices must have length equal to the " "rank of operand, got indices {} for operand shape {}.") raise TypeError(msg.format(start_indices, operand.shape)) - if not onp.all(onp.less_equal(update.shape, operand.shape)): + if not np.all(np.less_equal(update.shape, operand.shape)): msg = ("dynamic_update_slice update shape must be smaller than operand " "shape, got update shape {} for operand shape {}.") raise TypeError(msg.format(update.shape, operand.shape)) @@ -3631,7 +3631,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): _check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype) if any(i.dtype != start_indices[0].dtype or - not dtypes.issubdtype(i.dtype, onp.integer) for i in start_indices): + not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices): msg = ("index arguments to dynamic_update_slice must be integers of the " "same type, got {}") raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices))) @@ -3674,7 +3674,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): operand, update, *start_idx = batched_args operand_bd, update_bd, *start_idx_bd = batch_dims update_shape = (update.shape if update_bd is batching.not_mapped - else tuple(onp.delete(update.shape, update_bd))) + else tuple(np.delete(update.shape, update_bd))) dims = tuple(range(len(update_shape))) dnums = ScatterDimensionNumbers(update_window_dims=dims, inserted_window_dims=(), @@ -3706,7 +3706,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): return proto def _gather_dtype_rule(operand, start_indices, **kwargs): - if not dtypes.issubdtype(start_indices.dtype, onp.integer): + if not dtypes.issubdtype(start_indices.dtype, 
np.integer): raise ValueError("start_indices must have an integer type") return dtypes.canonicalize_dtype(operand.dtype) @@ -3718,7 +3718,7 @@ def _gather_shape_rule(operand, start_indices, *, dimension_numbers, raise ValueError(msg) result_rank = len(dimension_numbers.offset_dims) + start_indices.ndim - 1 start_indices_shape = iter(start_indices.shape[:-1]) - slice_sizes = iter(onp.delete(slice_sizes, dimension_numbers.collapsed_slice_dims)) + slice_sizes = iter(np.delete(slice_sizes, dimension_numbers.collapsed_slice_dims)) return tuple(next(slice_sizes) if i in dimension_numbers.offset_dims else next(start_indices_shape) for i in range(result_rank)) @@ -3755,9 +3755,9 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, if operand_bdim is not None and start_indices_bdim is None: operand = batching.moveaxis(operand, operand_bdim, 0) slice_sizes = (operand.shape[0],) + slice_sizes - offset_dims = (0,) + tuple(onp.add(1, dimension_numbers.offset_dims)) - collapsed_slice_dims = tuple(onp.add(1, dimension_numbers.collapsed_slice_dims)) - start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map)) + offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) + collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, @@ -3767,7 +3767,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, elif operand_bdim is None and start_indices_bdim is not None: start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0) - offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims)) + offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, @@ -3790,9 +3790,9 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, start_indices = concatenate([counts, start_indices], len(count_shape) - 1) slice_sizes = (1,) + slice_sizes - collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims)) - offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims)) - start_index_map = (0,) + tuple(onp.add(1, dimension_numbers.start_index_map)) + collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) + start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, @@ -3822,7 +3822,7 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers): return proto def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs): - if not dtypes.issubdtype(scatter_indices.dtype, onp.integer): + if not dtypes.issubdtype(scatter_indices.dtype, np.integer): raise ValueError("scatter_indices must have an integer type") _check_same_dtypes("scatter", False, operand.dtype, updates.dtype) return dtypes.canonicalize_dtype(operand.dtype) @@ -3833,7 +3833,7 @@ def _scatter_shape_rule(operand, scatter_indices, updates, **kwargs): def _scatter_translation_rule(c, operand, scatter_indices, updates, update_jaxpr, update_consts, dimension_numbers): dtype = c.get_shape(operand).numpy_dtype() - init_value = xb.constant(c, onp.array(0, dtype)) + init_value = xb.constant(c, np.array(0, dtype)) update_computation = _reduction_computation( c, 
update_jaxpr, update_consts, init_value) indices_shape = c.get_shape(scatter_indices) @@ -3938,9 +3938,9 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, updates = batching.bdim_at_front(updates, updates_bdim, size) if scatter_indices_bdim is None: - inserted_window_dims = tuple(onp.add(1, dimension_numbers.inserted_window_dims)) - update_window_dims = (0,) + tuple(onp.add(1, dimension_numbers.update_window_dims)) - scatter_dims_to_operand_dims = tuple(onp.add(1, dimension_numbers.scatter_dims_to_operand_dims)) + inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) + update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims)) + scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, @@ -3958,9 +3958,9 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_indices = concatenate([counts, scatter_indices], len(count_shape) - 1) - update_window_dims = tuple(onp.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = (0,) + tuple(onp.add(1, dimension_numbers.inserted_window_dims)) - scatter_dims_to_operand_dims = (0,) + tuple(onp.add(1, dimension_numbers.scatter_dims_to_operand_dims)) + update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) + inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims)) + scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, @@ -4033,10 +4033,10 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, # tangents for the values in updates. initial_vals = gather( - operand, scatter_indices, gather_dnums, onp.array(slice_sizes)) + operand, scatter_indices, gather_dnums, np.array(slice_sizes)) target_vals = gather( - val_out, scatter_indices, gather_dnums, onp.array(slice_sizes)) + val_out, scatter_indices, gather_dnums, np.array(slice_sizes)) successful_updates = (updates == target_vals) retained_values = (initial_vals == target_vals) @@ -4048,7 +4048,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_dnums), scatter_indices, gather_dnums, - onp.array(slice_sizes)) + np.array(slice_sizes)) num_refs = gather( scatter_add(_zeros(operand), @@ -4057,7 +4057,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_dnums), scatter_indices, gather_dnums, - onp.array(slice_sizes)) + np.array(slice_sizes)) updates_normalizer = select(retained_values, 1.0 / (num_updates + 1), @@ -4075,7 +4075,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, # This can be simplified once scatter has transpose implemented target_tangents = gather( - g_operand, scatter_indices, gather_dnums, onp.array(slice_sizes)) + g_operand, scatter_indices, gather_dnums, np.array(slice_sizes)) tangent_updates = (target_tangents * operand_coef + g_updates * updates_coef) @@ -4139,9 +4139,9 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # We specify the dtype here in case `updates_shape` is an empty tuple, in # which case numpy defaults to float64. 
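   # (Editor's note, illustrative: np.array(()) yields dtype float64, whereas
   # np.array((), dtype=np.int32) keeps int32.)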
- ids_shape = onp.array(updates_shape, dtype=onp.int32) + ids_shape = np.array(updates_shape, dtype=np.int32) ids_shape[dnums.update_window_dims,] = 1 - num_ids = onp.prod(ids_shape) + num_ids = np.prod(ids_shape) update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape), _ones(updates)) @@ -4204,7 +4204,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, def _reduce_shape_rule(operand, init_value, *, computation, jaxpr, consts, dimensions): - return tuple(onp.delete(operand.shape, dimensions)) + return tuple(np.delete(operand.shape, dimensions)) def _reduce_translation_rule(c, operand, init_value, *, computation, jaxpr, consts, dimensions): @@ -4218,7 +4218,7 @@ def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, consts, if init_value_bdim is None: assert operand_bdim is not None new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions] - new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim))) + new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim))) return reduce(operand, init_value, computation, new_dimensions), new_operand_bdim else: raise NotImplementedError # loop and stack @@ -4239,7 +4239,7 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes, axes, input_shape=None): (padded_val,), (logical_shape,) = padded_vals, logical_shapes padded_shape = masking.padded_shape_as_value(padded_val.shape) - masks = [broadcasted_iota(onp.int32, padded_shape, i) < d + masks = [broadcasted_iota(np.int32, padded_shape, i) < d for i, d in enumerate(logical_shape) if i in axes] mask = _reduce(operator.and_, masks) masked_val = select(mask, padded_val, identity(padded_shape, padded_val.dtype)) @@ -4252,9 +4252,9 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes, def _reduce_number_dtype_rule(name, operand, *args, **kw): - if not dtypes.issubdtype(operand.dtype, onp.number): + if not dtypes.issubdtype(operand.dtype, np.number): raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes " - "of number.".format(name, onp.dtype(operand.dtype).name)) + "of number.".format(name, np.dtype(operand.dtype).name)) return dtypes.canonicalize_dtype(operand.dtype) def _reduce_sum_shape_rule(operand, *, axes): @@ -4263,14 +4263,14 @@ def _reduce_sum_shape_rule(operand, *, axes): def _reduce_sum_translation_rule(c, operand, *, axes): dtype = c.get_shape(operand).numpy_dtype() scalar = ShapedArray((), dtype) - return xops.Reduce(c, [operand], [xb.constant(c, onp.array(0, dtype))], + return xops.Reduce(c, [operand], [xb.constant(c, np.array(0, dtype))], xla.primitive_subcomputation(add_p, scalar, scalar), axes) def _reduce_sum_transpose_rule(cotangent, operand, *, axes): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape - broadcast_dimensions = tuple(onp.delete(onp.arange(len(input_shape)), axes)) + broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) assert result.shape == input_shape return [result] @@ -4281,28 +4281,28 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes): ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p) _masking_defreducer(reduce_sum_p, - lambda shape, dtype: onp.broadcast_to(onp.array(0, dtype), shape)) + lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape)) def _reduce_op_shape_rule(operand, *, axes, input_shape=None): del input_shape # Unused. 
if len(axes) != len(set(axes)): raise ValueError(f"duplicate value in 'axes' of reduction: {axes}") - return tuple(onp.delete(operand.shape, axes)) + return tuple(np.delete(operand.shape, axes)) def _reduce_prod_translation_rule(c, operand, *, axes): dtype = c.get_shape(operand).numpy_dtype() scalar = ShapedArray((), dtype) - return xops.Reduce(c, [operand], [xb.constant(c, onp.array(1, dtype))], + return xops.Reduce(c, [operand], [xb.constant(c, np.array(1, dtype))], xla.primitive_subcomputation(mul_p, scalar, scalar), axes) def _reduce_prod_jvp_rule(primals, tangents, *, axes): operand, = primals tangent, = tangents - input_shape = onp.array(operand.shape) + input_shape = np.array(operand.shape) - n = onp.prod(input_shape[list(axes)]) - non_axes = onp.delete(onp.arange(len(input_shape)), axes) + n = np.prod(input_shape[list(axes)]) + non_axes = np.delete(np.arange(len(input_shape)), axes) # Move the reduced axes to the front, and flatten them to 1D. permutation = axes + tuple(non_axes) @@ -4336,11 +4336,11 @@ def _reduce_prod_tree(x, axis=0): ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p) _masking_defreducer(reduce_prod_p, - lambda shape, dtype: onp.broadcast_to(onp.array(1, dtype), shape)) + lambda shape, dtype: np.broadcast_to(np.array(1, dtype), shape)) def _reduce_chooser_shape_rule(operand, *, axes): - return tuple(onp.delete(operand.shape, axes)) + return tuple(np.delete(operand.shape, axes)) def _reduce_chooser_translation_rule(prim, identity, c, operand, *, axes): dtype = c.get_shape(operand).numpy_dtype() @@ -4365,7 +4365,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p) _masking_defreducer(reduce_max_p, - lambda shape, dtype: onp.broadcast_to(onp.array(-onp.inf, dtype), shape)) + lambda shape, dtype: np.broadcast_to(np.array(-np.inf, dtype), shape)) _reduce_min_translation_rule = partial( @@ -4375,13 +4375,13 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p) _masking_defreducer(reduce_min_p, - lambda shape, dtype: onp.broadcast_to(onp.array(onp.inf, dtype), shape)) + lambda shape, dtype: np.broadcast_to(np.array(np.inf, dtype), shape)) def _argminmax_shape_rule(operand, *, axes, index_dtype): axis, = axes - return tuple(onp.delete(operand.shape, axis)) + return tuple(np.delete(operand.shape, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): return index_dtype @@ -4411,13 +4411,13 @@ def _argminmax_translation_rule(value_comparator, identity, out = xops.Reduce( c, [operand, iota], [xb.constant(c, identity(dtype)), - xb.constant(c, onp.array(0, index_dtype))], comparator, [axis]) + xb.constant(c, np.array(0, index_dtype))], comparator, [axis]) return xops.GetTupleElement(out, 1) def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype): axis, = axes idxs = tie_in(a, broadcasted_iota(index_dtype, a.shape, axis)) - maxval = onp.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype) + maxval = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype) maxval = broadcast(tie_in(a, maxval), a.shape) mask_idxs = select(eq(a, expand_dims(op(a, (axis,)), (axis,))), idxs, maxval) @@ -4454,26 +4454,26 @@ def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype): def _reduce_logical_shape_rule(operand, *, axes): - if operand.dtype != onp.bool_: + if operand.dtype != np.bool_: msg = "logical reduction requires operand dtype bool, 
got {}." raise TypeError(msg.format(operand.dtype)) - return tuple(onp.delete(operand.shape, axes)) + return tuple(np.delete(operand.shape, axes)) def _reduce_logical_translation_rule(prim, identity, c, operand, *, axes): - scalar = ShapedArray((), onp.bool_) - return xops.Reduce(c, [operand], [xb.constant(c, identity(onp.bool_))], + scalar = ShapedArray((), np.bool_) + return xops.Reduce(c, [operand], [xb.constant(c, identity(np.bool_))], xla.primitive_subcomputation(prim, scalar, scalar), axes) _reduce_or_translation_rule = partial(_reduce_logical_translation_rule, or_p, _get_max_identity) -reduce_or_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(onp.bool_), +reduce_or_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_), 'reduce_or', _reduce_or_translation_rule) batching.defreducer(reduce_or_p) _reduce_and_translation_rule = partial(_reduce_logical_translation_rule, and_p, _get_min_identity) -reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(onp.bool_), +reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_), 'reduce_and', _reduce_and_translation_rule) batching.defreducer(reduce_and_p) @@ -4520,9 +4520,9 @@ def reduce_window(x, window_dimensions, window_strides, padding): def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides, padding): - if not dtypes.issubdtype(operand.dtype, onp.number): + if not dtypes.issubdtype(operand.dtype, np.number): msg = "operand to reduce_window_sum must have a number dtype, got {}" - raise TypeError(msg.format(onp.dtype(operand.dtype).name)) + raise TypeError(msg.format(np.dtype(operand.dtype).name)) return _common_reduce_window_shape_rule(operand, window_dimensions, window_strides, padding) @@ -4531,7 +4531,7 @@ def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions, dtype = c.get_shape(operand).numpy_dtype() scalar = ShapedArray((), dtype) return xops.ReduceWindowWithGeneralPadding( - operand, xb.constant(c, onp.array(0, dtype)), + operand, xb.constant(c, np.array(0, dtype)), xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions, window_strides, (), (), padding) @@ -4606,9 +4606,9 @@ def _common_reduce_window_shape_rule(operand, window_dimensions, def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, padding): - operand_padded = onp.add(operand_shape, onp.add(*zip(*padding))) - t = onp.floor_divide( - onp.subtract(operand_padded, window_dimensions), window_strides) + 1 + operand_padded = np.add(operand_shape, np.add(*zip(*padding))) + t = np.floor_divide( + np.subtract(operand_padded, window_dimensions), window_strides) + 1 return tuple(t) _reduce_window_max_translation_rule = partial( @@ -4670,7 +4670,7 @@ def _select_and_scatter_add_translation( scalar = ShapedArray((), dtype) select = xla.primitive_subcomputation(select_prim, scalar, scalar) scatter = xla.primitive_subcomputation(add_p, scalar, scalar) - zero = xb.constant(c, onp.array(0, dtype)) + zero = xb.constant(c, np.array(0, dtype)) return xops.SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, zero, scatter) @@ -4751,15 +4751,15 @@ def _select_and_gather_add_shape_rule( _UINT_DTYPES = { - 16: onp.uint16, - 32: onp.uint32, - 64: onp.uint64, + 16: np.uint16, + 32: np.uint32, + 64: np.uint64, } _INT_DTYPES = { - 16: onp.int16, - 32: onp.int32, - 64: onp.int64, + 16: np.int16, + 32: np.int32, + 64: np.int64, } def _select_and_gather_add_translation( @@ -4773,7 +4773,7 
@@ def _select_and_gather_add_translation( assert nbits <= max_bits double_word_reduction = nbits * 2 <= max_bits - const = lambda c, dtype, x: xb.constant(c, onp.array(x, dtype=dtype), + const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype), canonicalize_types=False) if double_word_reduction: @@ -4842,9 +4842,9 @@ def snd(t): def reducer(): c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer") x = xb.parameter(c, 0, - xla_client.Shape.array_shape(onp.dtype(double_word_dtype), ())) + xla_client.Shape.array_shape(np.dtype(double_word_dtype), ())) y = xb.parameter(c, 1, - xla_client.Shape.array_shape(onp.dtype(double_word_dtype), ())) + xla_client.Shape.array_shape(np.dtype(double_word_dtype), ())) assert select_prim is ge_p or select_prim is le_p which = xops.Ge if select_prim is ge_p else xops.Le xops.Select(which(fst(c, x), fst(c, y)), x, y) @@ -4852,7 +4852,7 @@ def reducer(): assert select_prim is ge_p or select_prim is le_p, select_prim - init = -onp.inf if select_prim is ge_p else onp.inf + init = -np.inf if select_prim is ge_p else np.inf out = xops.ReduceWindowWithGeneralPadding( pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)), reducer(), window_dimensions, window_strides, (), (), padding) @@ -4952,11 +4952,11 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit): def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any): - if onp.issubdtype(x.dtype, onp.integer): - if onp.isposinf(unit): - unit = onp.iinfo(x.dtype).max - elif onp.isneginf(unit): - unit = onp.iinfo(x.dtype).min + if np.issubdtype(x.dtype, np.integer): + if np.isposinf(unit): + unit = np.iinfo(x.dtype).max + elif np.isneginf(unit): + unit = np.iinfo(x.dtype).min n = x.shape[axis] if n == 0: return x @@ -4972,8 +4972,8 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any): _cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0) _cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1) -_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-onp.inf) -_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=onp.inf) +_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-np.inf) +_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=np.inf) def _cumred_shape_rule(x, *, axis: int): if axis < 0 or axis >= x.ndim: @@ -5073,15 +5073,15 @@ def _float_to_int_for_sort(x): # Note that in order to avoid -x to overflow, we calculate # int32_max - x as unsigned, and then convert back to signed. 
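   # (Editor's note, illustrative: plain signed negation overflows for the
   # most negative value, e.g. int32 cannot represent -(-2**31) = 2**31; the
   # unsigned subtraction int32_max - x applies the same order-reversing flip
   # with well-defined wraparound instead.)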
if x.dtype == dtypes.bfloat16: - x = convert_element_type(x, onp.float32) - nbits = onp.finfo(x).bits + x = convert_element_type(x, np.float32) + nbits = np.finfo(x).bits signed_dtype = _INT_DTYPES[nbits] unsigned_dtype = _UINT_DTYPES[nbits] signed = bitcast_convert_type(x, signed_dtype) unsigned = bitcast_convert_type(x, unsigned_dtype) flipped = bitcast_convert_type( - sub(unsigned_dtype(onp.iinfo(signed_dtype).max), unsigned), signed_dtype) + sub(unsigned_dtype(np.iinfo(signed_dtype).max), unsigned), signed_dtype) return select(lt(signed, _zero(signed)), flipped, signed) # Default comparator that sorts the operands lexicographically on the @@ -5098,10 +5098,10 @@ def _sort_lt_comparator(*operands, num_keys=1): x_keys, y_keys = [], [] for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]): assert x.dtype == y.dtype, (x.dtype, y.dtype) - if onp.issubdtype(x.dtype, onp.complexfloating): + if np.issubdtype(x.dtype, np.complexfloating): x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))]) y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))]) - elif onp.issubdtype(x.dtype, onp.floating): + elif np.issubdtype(x.dtype, np.floating): x_keys.append(_float_to_int_for_sort(x)) y_keys.append(_float_to_int_for_sort(y)) else: @@ -5131,7 +5131,7 @@ def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys): shape = primals[0].shape iotas = [] for dim, size in enumerate(shape): - dtype = onp.int32 if size < onp.iinfo(onp.int32).max else onp.int64 + dtype = np.int32 if size < np.iinfo(np.int32).max else np.int64 iotas.append(broadcasted_iota(dtype, shape, dim)) primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension, is_stable=is_stable, num_keys=num_keys) @@ -5146,7 +5146,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys new_args = [] for arg, bdim in zip(batched_args, batch_dims): if bdim is None: - dims = onp.delete(onp.arange(prototype_arg.ndim), new_bdim) + dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims)) else: new_args.append(batching.moveaxis(arg, bdim, new_bdim)) @@ -5177,7 +5177,7 @@ def _top_k_abstract_eval(operand, *, k): raise ValueError(msg.format(k, shape)) shape[-1] = k return (ShapedArray(shape, operand.dtype), - ShapedArray(shape, onp.dtype(onp.int32))) + ShapedArray(shape, np.dtype(np.int32))) def _top_k_jvp(primals, tangents, *, k): operand, = primals @@ -5210,7 +5210,7 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k): operand, = batched_args bdim, = batch_dims if bdim == operand.ndim-1: - perm = onp.arange(operand.ndim) + perm = np.arange(operand.ndim) perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1] top_k_v, top_k_i = top_k(transpose(operand, perm), k=k) return (transpose(top_k_v, perm), @@ -5412,20 +5412,20 @@ def _rng_uniform_translation_rule(c, a, b, *, shape): ### util -_ndim = onp.ndim +_ndim = np.ndim def _dilate_shape(shape, dilation): """Utility function for computing the shape resulting from a dilation.""" - if not onp.all(onp.greater(dilation, 0)): + if not np.all(np.greater(dilation, 0)): msg = "All dilations must be positive, got {}." 
raise TypeError(msg.format(dilation)) dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation) - return onp.where(shape == 0, 0, - onp.multiply(dilation, onp.subtract(shape, 1)) + 1) + return np.where(shape == 0, 0, + np.multiply(dilation, np.subtract(shape, 1)) + 1) def _ceil_divide(x1, x2): - return -onp.floor_divide(onp.negative(x1), x2) + return -np.floor_divide(np.negative(x1), x2) def padtype_to_pads(in_shape, window_shape, window_strides, padding): """Convert padding string to list of pairs of pad values.""" @@ -5441,7 +5441,7 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): if padding == PaddingType.SAME: out_shape = _ceil_divide(in_shape, window_strides) - pad_sizes = onp.maximum(0, (out_shape - 1) * window_strides + + pad_sizes = np.maximum(0, (out_shape - 1) * window_strides + window_shape - in_shape) return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] elif padding == PaddingType.VALID: @@ -5455,11 +5455,11 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes): """Check that dtypes agree, possibly ignoring float precision.""" # the `ignore_fp_precision` flag exists because the XLA shape inference logic # allows mixed floating point precision, but the HLO verifier often rejects it - types = list(map(onp.dtype, ttypes)) # canonicalize + types = list(map(np.dtype, ttypes)) # canonicalize if ignore_fp_precision: types = [ - onp.floating if dtypes.issubdtype(dtype, onp.floating) - else onp.complexfloating if dtypes.issubdtype(dtype, onp.complexfloating) + np.floating if dtypes.issubdtype(dtype, np.floating) + else np.complexfloating if dtypes.issubdtype(dtype, np.complexfloating) else dtype for dtype in types] if len({dtypes.canonicalize_dtype(t) for t in types}) != 1: if ignore_fp_precision: @@ -5482,7 +5482,7 @@ def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides): msg = "Arguments to {} must agree on input feature size, got {} and {}." raise TypeError(msg.format(name, lhs_shape[1], rhs_shape[1])) _check_shapelike(name, "window_strides", window_strides) - if not onp.all(onp.greater(window_strides, 0)): + if not np.all(np.greater(window_strides, 0)): msg = "All elements of window_strides must be positive, got {}." raise TypeError(msg.format(window_strides)) if len(window_strides) != len(lhs_shape) - 2: @@ -5499,11 +5499,11 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): msg = "Wrong number of explicit pads for convolution: expected {}, got {}." 
raise TypeError(msg.format(len(lhs_shape) - 2, len(pads))) - lhs_padded = onp.add(lhs_shape[2:], onp.sum(onp.array(pads).reshape(-1, 2), + lhs_padded = np.add(lhs_shape[2:], np.sum(np.array(pads).reshape(-1, 2), axis=1)) - out_space = onp.floor_divide( - onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1 - out_space = onp.maximum(0, out_space) + out_space = np.floor_divide( + np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1 + out_space = np.maximum(0, out_space) assert lhs_shape[0] % batch_group_count == 0 out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) return tuple(out_shape + tuple(out_space)) @@ -5512,39 +5512,39 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) - lhs_trans = onp.take(lhs_shape, lhs_perm) - rhs_trans = onp.take(rhs_shape, rhs_perm) + lhs_trans = np.take(lhs_shape, lhs_perm) + rhs_trans = np.take(rhs_shape, rhs_perm) out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding) - return tuple(onp.take(out_trans, onp.argsort(out_perm))) + return tuple(np.take(out_trans, np.argsort(out_perm))) def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) - lhs_trans = onp.take(lhs_shape, lhs_perm) - rhs_trans = onp.take(rhs_shape, rhs_perm) + lhs_trans = np.take(lhs_shape, lhs_perm) + rhs_trans = np.take(rhs_shape, rhs_perm) if isinstance(padding, str): padding = [_conv_transpose_padding(k, s, padding) for k,s in zip(rhs_trans[2:], window_strides)] - padding = list(map(onp.sum, padding)) + padding = list(map(np.sum, padding)) unpad_out_space = [(i-1) * s - k + 2 for i, k, s in zip(lhs_trans[2:], rhs_trans[2:], window_strides)] - out_space = onp.sum([unpad_out_space, padding], axis=0).tolist() + out_space = np.sum([unpad_out_space, padding], axis=0).tolist() out_trans = tuple((lhs_trans[0], rhs_trans[0]) + tuple(out_space)) - return tuple(onp.take(out_trans, onp.argsort(out_perm))) + return tuple(np.take(out_trans, np.argsort(out_perm))) def _check_shapelike(fun_name, arg_name, obj): """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints).""" - if not isinstance(obj, (tuple, list, onp.ndarray)): + if not isinstance(obj, (tuple, list, np.ndarray)): msg = "{} {} must be of type tuple/list/ndarray, got {}." raise TypeError(msg.format(fun_name, arg_name, type(obj))) # bool(obj) for an ndarray raises an error, so we check len if not len(obj): # pylint: disable=g-explicit-length-test return - obj_arr = onp.array(obj) + obj_arr = np.array(obj) if obj_arr.ndim != 1: msg = "{} {} must be rank 1, got {}." 
raise TypeError(msg.format(fun_name, arg_name, obj_arr.ndim)) @@ -5573,8 +5573,8 @@ def _dynamic_slice_indices(operand, start_indices): add(start_indices, _const(start_indices, operand.shape)), start_indices) else: - return [onp.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_)) - if isinstance(i, (int, onp.integer)) + return [np.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_)) + if isinstance(i, (int, np.integer)) else select(lt(i, _const(i, 0)), add(i, _const(i, d)), i) for i, d in zip(start_indices, operand.shape)] @@ -5583,7 +5583,7 @@ def _dynamic_slice_indices(operand, start_indices): def _const(example, val): if dtypes.is_python_scalar(example): return dtypes.scalar_type_of(example)(val) - return onp.array(val, _dtype(example)) + return np.array(val, _dtype(example)) _zeros: Callable = partial(full_like, fill_value=0) _zero: Callable = partial(full_like, shape=(), fill_value=0) @@ -5596,7 +5596,7 @@ def _const(example, val): _dtype: Callable = dtypes.result_type def _iscomplex(x) -> bool: - return dtypes.issubdtype(_dtype(x), onp.complexfloating) + return dtypes.issubdtype(_dtype(x), np.complexfloating) def ranges_like(*xs): @@ -5716,8 +5716,8 @@ def _conv_general_vjp_lhs_padding( lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation) rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation) out_dilated_shape = _dilate_shape(out_shape, window_strides) - pad_before = onp.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1 - pad_after = (onp.add(lhs_dilated_shape, rhs_dilated_shape) - 1 + pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1 + pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1 - out_dilated_shape - pad_before) return safe_zip(pad_before, pad_after) @@ -5753,11 +5753,11 @@ def _abstractify(x): def _check_user_dtype_supported(dtype, fun_name=None): - onp_dtype = onp.dtype(dtype) - if onp_dtype.kind not in "biufc" and onp_dtype.type != dtypes.bfloat16: + np_dtype = np.dtype(dtype) + if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" raise TypeError(msg) - if dtype is not None and onp_dtype != dtypes.canonicalize_dtype(dtype): + if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype): msg = ("Explicitly requested dtype {} {} is not available, " "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 577f7d867875..7913981077c4 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -24,7 +24,7 @@ import operator from typing import Callable, Sequence -import numpy as onp +import numpy as np import jax from jax import core @@ -273,7 +273,7 @@ def while_loop(cond_fun, body_fun, init_val): if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." raise TypeError(msg.format(cond_tree)) - if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), onp.bool_): + if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_): msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals)) @@ -313,9 +313,9 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args, cond_jaxpr.literals), extend_name_stack(name_stack, 'cond'), *(x + z)) if batched: - scalar = ShapedArray((), onp.bool_) + scalar = ShapedArray((), np.bool_) or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar) - pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_, + pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_, list(range(cond_jaxpr.out_avals[0].ndim))) body_c = xb.make_computation_builder("body_computation") @@ -560,10 +560,10 @@ def switch(index, branches, operand): branches: Sequence of functions (A -> B) to be applied based on `index`. operand: Operand (A) input to whichever branch is applied. """ - if len(onp.shape(index)) != 0: + if len(np.shape(index)) != 0: raise TypeError( f"Branch index must be scalar, " - f"got {index} of shape {onp.shape(index)}.") + f"got {index} of shape {np.shape(index)}.") try: index_dtype = dtypes.result_type(index) @@ -582,9 +582,9 @@ def switch(index, branches, operand): elif len(branches) == 1: return branches[0](operand) - index = lax.convert_element_type(index, onp.int32) - lo = onp.array(0, onp.int32) - hi = onp.array(len(branches) - 1, onp.int32) + index = lax.convert_element_type(index, np.int32) + lo = np.array(0, np.int32) + hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) if (jax.api._jit_is_disabled() and @@ -651,9 +651,9 @@ def cond(pred, true_fun, false_fun, operand): return _cond(*args, **kwargs) def _cond(pred, true_fun: Callable, false_fun: Callable, operand): - if len(onp.shape(pred)) != 0: + if len(np.shape(pred)) != 0: raise TypeError( - f"Pred must be a scalar, got {pred} of shape {onp.shape(pred)}.") + f"Pred must be a scalar, got {pred} of shape {np.shape(pred)}.") try: pred_dtype = dtypes.result_type(pred) @@ -686,7 +686,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand): out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) - index = lax.convert_element_type(pred, onp.int32) + index = lax.convert_element_type(pred, np.int32) linear = (False,) * (len(consts) + len(ops)) out = cond_p.bind( @@ -742,7 +742,7 @@ def _select_tree(indices, branch_vals): if len(branch_vals) == 1: return branch_vals[0] mid = len(branch_vals) // 2 - mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices))) + mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices))) return lax.select(lax.lt(indices, mid), _select_tree(indices, branch_vals[:mid]), _select_tree(indices - mid, branch_vals[mid:])) @@ -752,7 +752,7 @@ def _cond_index_bcast_and_select_tree(indices, branch_vals): return branch_vals[0] else: bcast_indices = lax.broadcast_in_dim( - indices, onp.shape(branch_vals[0]), list(range(onp.ndim(indices)))) + indices, np.shape(branch_vals[0]), list(range(np.ndim(indices)))) return _select_tree(bcast_indices, branch_vals) def _cond_batching_rule(args, dims, branches, linear): @@ -1066,7 +1066,7 @@ def _cond_typecheck(*avals, branches, linear): index_aval, *op_avals = avals core.typecheck_assert( - index_aval.dtype == onp.int32, + index_aval.dtype == np.int32, f'cond called with index of type {index_aval.dtype} instead of int32') core.typecheck_assert( all(_map(core.typecompat, jaxpr0.in_avals, op_avals)), @@ -1859,8 +1859,8 @@ def _flatten(args): def _check_shapes(func_name, expected_name, actual, expected, tree): - actual_shapes = _map(onp.shape, actual) - 
expected_shapes = _map(onp.shape, expected) + actual_shapes = _map(np.shape, actual) + expected_shapes = _map(np.shape, expected) if actual_shapes != expected_shapes: raise ValueError('{}() output shapes must match {}, got {} and {}' .format(func_name, expected_name, @@ -2102,18 +2102,18 @@ def _linear_solve_batching_rule(args, dims, **kwargs): def _interleave(a, b): """Given two Tensors of static shape, interleave them along the first axis.""" # TODO(mattjj) - import jax.numpy as np + import jax.numpy as jnp # [a b c ...] [d e f ...] -> [a d b e c f ...] half_num_elems = b.shape[0] if a.shape[0] > b.shape[0]: - return np.concatenate( - [np.reshape(np.stack([a[: -1], b], axis=1), - (2 * half_num_elems,) + a.shape[1:]), + return jnp.concatenate( + [jnp.reshape(jnp.stack([a[: -1], b], axis=1), + (2 * half_num_elems,) + a.shape[1:]), a[-1:]], axis=0) else: - return np.reshape(np.stack([a, b], axis=1), - (2 * half_num_elems,) + a.shape[1:]) + return jnp.reshape(jnp.stack([a, b], axis=1), + (2 * half_num_elems,) + a.shape[1:]) def associative_scan(fn, elems): """Perform a scan with an associative binary operation, in parallel. diff --git a/jax/lax/lax_fft.py b/jax/lax/lax_fft.py index 27fcc2b9ba93..70f93aada3b2 100644 --- a/jax/lax/lax_fft.py +++ b/jax/lax/lax_fft.py @@ -15,7 +15,7 @@ from functools import partial -import numpy as onp +import numpy as np from jax.abstract_arrays import ShapedArray from jax.api import jit, vjp @@ -36,24 +36,24 @@ ] def _promote_to_complex(arg): - dtype = dtypes.result_type(arg, onp.complex64) + dtype = dtypes.result_type(arg, np.complex64) # XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier. # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer. - if lib.version <= (0, 1, 47) and dtype == onp.complex128: - dtype = onp.complex64 + if lib.version <= (0, 1, 47) and dtype == np.complex128: + dtype = np.complex64 return lax.convert_element_type(arg, dtype) def _promote_to_real(arg): - dtype = dtypes.result_type(arg, onp.float32) + dtype = dtypes.result_type(arg, np.float32) # XLA's FFT op only supports F32. # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer. 
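# Worked example of the promotion above (illustrative annotation, not part of
# this change): a float16 argument promotes via dtypes.result_type(arg,
# np.float32) to float32, while a float64 argument stays float64 under x64
# mode -- exactly the case the jaxlib version guard below downcasts.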
- if lib.version <= (0, 1, 47) and dtype == onp.float64: - dtype = onp.float32 + if lib.version <= (0, 1, 47) and dtype == np.float64: + dtype = np.float32 return lax.convert_element_type(arg, dtype) def fft(x, fft_type, fft_lengths): if fft_type == xla_client.FftType.RFFT: - if onp.iscomplexobj(x): + if np.iscomplexobj(x): raise ValueError("only real valued inputs supported for rfft") x = _promote_to_real(x) else: @@ -67,8 +67,8 @@ def fft(x, fft_type, fft_lengths): def fft_impl(x, fft_type, fft_lengths): return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths) -_complex_dtype = lambda dtype: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype -_real_dtype = lambda dtype: onp.zeros((), dtype).real.dtype +_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype +_real_dtype = lambda dtype: np.zeros((), dtype).real.dtype _is_even = lambda x: x % 2 == 0 def fft_abstract_eval(x, fft_type, fft_lengths): diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 9aadf6573bc4..328416b4f9f3 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -17,7 +17,7 @@ import collections -import numpy as onp +import numpy as np from jax import core from jax import ad_util @@ -72,8 +72,8 @@ def psum(x, axis_name, *, axis_index_groups=None): """ _validate_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) - leaves = [lax.convert_element_type(l, onp.int32) - if dtypes.dtype(l) == onp.bool_ else l for l in leaves] + leaves = [lax.convert_element_type(l, np.int32) + if dtypes.dtype(l) == np.bool_ else l for l in leaves] out_flat = psum_p.bind(*leaves, axis_name=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -327,7 +327,7 @@ def _psum_translation_rule(c, *args, replica_groups=None, platform=None): out = [None] * len(args) replica_groups_protos = xc.make_replica_groups(replica_groups) for dtype, (indices, dtype_args) in sorted(args_by_type.items()): - is_complex = dtypes.issubdtype(dtype, onp.complexfloating) + is_complex = dtypes.issubdtype(dtype, np.complexfloating) n = len(dtype_args) if is_complex: dtype_args = ([xops.Real(x) for x in dtype_args] + @@ -355,7 +355,7 @@ def _translate(val): psum = partial(_allreduce_translation_rule, lax.add_p, c, replica_groups=replica_groups) dtype = c.get_shape(val).numpy_dtype() - if dtypes.issubdtype(dtype, onp.complexfloating): + if dtypes.issubdtype(dtype, np.complexfloating): return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val))) else: return psum(val) @@ -507,14 +507,14 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params): if xdim is None: if x.shape: if x.shape[ydim] == 1: - x = x.reshape(onp.delete(x.shape, ydim)) + x = x.reshape(np.delete(x.shape, ydim)) else: x = _drop(x, ydim, name) return prim.bind(x, y, **params), ydim elif ydim is None: if y.shape: if y.shape[xdim] == 1: - y = y.reshape(onp.delete(y.shape, xdim)) + y = y.reshape(np.delete(y.shape, xdim)) else: y = _drop(y, xdim, name) return prim.bind(x, y, **params), xdim @@ -525,11 +525,11 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params): y_tosplit = xdim - int(ydim <= xdim) if y.shape[y_tosplit] == 1: y = _allgather(y, ydim, size, name) - y = y.reshape(onp.delete(y.shape, xdim)) + y = y.reshape(np.delete(y.shape, xdim)) return prim.bind(x, y, **params), ydim elif x.shape[x_tosplit] == 1: x = _allgather(x, xdim, size, name) - x = x.reshape(onp.delete(x.shape, ydim)) + x = 
x.reshape(np.delete(x.shape, ydim)) return prim.bind(x, y, **params), ydim else: x = all_to_all(x, name, x_tosplit, xdim) @@ -565,7 +565,7 @@ def _reducer_papply(prim, collective, name, size, vals, papply_axes, axes, **kwa if not axes or papply_axis in axes: return collective(result, axis_name=name), None else: - new_papply_axis = papply_axis - onp.sum(onp.less(other_axes, papply_axis)) + new_papply_axis = papply_axis - np.sum(np.less(other_axes, papply_axis)) return result, new_papply_axis def _defreducer(prim, collective_prim): @@ -754,13 +754,13 @@ def cases(x, y, xdim, ydim, xc, yc, xb, yb): def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions): operand, = vals axis, = axes - old_sizes = tuple(onp.insert(operand.shape, axis, size)) + old_sizes = tuple(np.insert(operand.shape, axis, size)) def filter_ones(xs): return filter(lambda x: x != 1, xs) def find_new_axis(old_axis, old_sizes, new_sizes): - left = onp.prod(old_sizes[:old_axis]) + left = np.prod(old_sizes[:old_axis]) size = old_sizes[old_axis] prod = 1 for i, cur_sz in enumerate(new_sizes): @@ -829,7 +829,7 @@ def _conv_general_dilated_papply_rule( lhs_dim, rhs_dim = dims lhs_spec_batch_dim = dimension_numbers.lhs_spec[0] if rhs_dim is None and lhs_dim == lhs_spec_batch_dim: - lhs = lax.reshape(lhs, tuple(onp.insert(lhs.shape, lhs_dim, 1))) + lhs = lax.reshape(lhs, tuple(np.insert(lhs.shape, lhs_dim, 1))) out = lax.conv_general_dilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision) @@ -848,8 +848,8 @@ def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape, raise ValueError( "broadcast_in_dim changes hidden dimension size: {} to {}".format( shape[dim], shape[out_dim])) - sub_bdims = tuple(onp.delete(broadcast_dimensions, dim)) - sub_shape = tuple(onp.delete(shape, out_dim)) + sub_bdims = tuple(np.delete(broadcast_dimensions, dim)) + sub_shape = tuple(np.delete(shape, out_dim)) return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim @@ -906,8 +906,8 @@ def _gather_papply_rule( start_index_map=dimension_numbers.start_index_map) out = lax.gather(operand, start_indices, dimension_numbers=dnums, slice_sizes=slice_sizes) - out_dim = start_indices_dim + onp.sum( - onp.less_equal(offset_dims, start_indices_dim)) + out_dim = start_indices_dim + np.sum( + np.less_equal(offset_dims, start_indices_dim)) return out, out_dim else: raise NotImplementedError diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 8aa22be8093d..778afed8e966 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -30,7 +30,7 @@ from ..config import flags from .. import util from .. import dtypes -import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy +import numpy as np import threading try: @@ -86,7 +86,7 @@ def get_compile_options(num_replicas, num_partitions, device_assignment=None, 2, 'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s', num_replicas, num_partitions, device_assignment) - device_assignment = onp.array(device_assignment) + device_assignment = np.array(device_assignment) # Allow 1D device assignment if num_partitions is 1. 
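# Illustrative annotation (assumed example, not part of this change): a 1D
# device_assignment such as [0, 1] with num_replicas=2 and num_partitions=1
# is shorthand for the 2D (replica, partition) assignment [[0], [1]] that the
# XLA compile options otherwise expect, hence the special case below.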
if (device_assignment.ndim == 1) and (num_partitions == 1): @@ -288,9 +288,8 @@ def supported_numpy_dtypes(): # TODO(mattjj,frostig): try to remove this function def normalize_to_xla_dtypes(val): """Normalize dtypes in a value.""" - if hasattr(val, '__array__') or onp.isscalar(val): - return onp.asarray(val, - dtype=dtypes.canonicalize_dtype(dtypes.result_type(val))) + if hasattr(val, '__array__') or np.isscalar(val): + return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val))) elif isinstance(val, (tuple, list)): return tuple(normalize_to_xla_dtypes(x) for x in val) raise TypeError('Can\'t convert to XLA: {}'.format(val)) @@ -361,7 +360,7 @@ def _sharding_to_proto(sharding: SpatialSharding): else: proto.type = xla_client.OpSharding.Type.OTHER proto.tile_assignment_dimensions = list(sharding) - proto.tile_assignment_devices = list(range(onp.product(sharding))) + proto.tile_assignment_devices = list(range(np.product(sharding))) return proto def set_sharding(builder, op, sharding: SpatialSharding): @@ -395,7 +394,7 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True): special handling of arrays with any strides of size zero: for those, it generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose to avoid staging in large literals that might arise from np.zeros or np.ones - or the output of lax.broadcast (which uses onp.broadcast_to which in turn + or the output of lax.broadcast (which uses np.broadcast_to which in turn uses size-zero strides). Args: @@ -407,28 +406,28 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True): staged into the XLA Computation. """ # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose - if onp.any(onp.equal(0, val.strides)) and val.size > 0: - zero_stride_axes, = onp.where(onp.equal(0, val.strides)) - other_axes, = onp.where(onp.not_equal(0, val.strides)) + if np.any(np.equal(0, val.strides)) and val.size > 0: + zero_stride_axes, = np.where(np.equal(0, val.strides)) + other_axes, = np.where(np.not_equal(0, val.strides)) collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) for ax in range(val.ndim))] xla_val = xops.Broadcast( _numpy_array_constant(c, collapsed_val, canonicalize_types), - onp.take(val.shape, zero_stride_axes)) - permutation = onp.argsort(tuple(zero_stride_axes) + tuple(other_axes)) + np.take(val.shape, zero_stride_axes)) + permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes)) return xops.Transpose(xla_val, permutation) else: return _numpy_array_constant(c, val, canonicalize_types) -register_constant_handler(onp.ndarray, _ndarray_constant_handler) +register_constant_handler(np.ndarray, _ndarray_constant_handler) def _scalar_constant_handler(c, val, canonicalize_types=True): return _numpy_array_constant(c, val, canonicalize_types) -for scalar_type in [onp.int8, onp.int16, onp.int32, onp.int64, - onp.uint8, onp.uint16, onp.uint32, onp.uint64, - onp.float16, onp.float32, onp.float64, onp.float128, - onp.bool_, onp.longlong]: +for scalar_type in [np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, np.float128, + np.bool_, np.longlong]: register_constant_handler(scalar_type, _scalar_constant_handler) def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True): diff --git a/jax/util.py b/jax/util.py index 0a1bb7a5631e..9ecd9a3d5c86 100644 --- a/jax/util.py +++ b/jax/util.py @@ -17,7 +17,7 @@ import itertools as it import types -import numpy as onp +import numpy 
as np def safe_zip(*args): @@ -233,7 +233,7 @@ def get_module_functions(module): continue attr = getattr(module, key) if isinstance( - attr, (types.BuiltinFunctionType, types.FunctionType, onp.ufunc)): + attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)): module_fns[key] = attr return module_fns
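
A minimal usage sketch of `get_module_functions` with the renamed alias
(illustrative only; this doctest is an assumption, not part of the patch):

>>> import numpy as np
>>> from jax.util import get_module_functions
>>> fns = get_module_functions(np)
>>> isinstance(fns["sin"], np.ufunc)  # ufuncs are collected
True
>>> "ndarray" in fns                  # classes are not
False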