Cleanup: convert uses of import numpy as onp in library code (jax-m…
jakevdp authored Jul 14, 2020
1 parent 512ed18 commit a7c2cde
Showing 15 changed files with 527 additions and 528 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmark.py
@@ -20,7 +20,7 @@
from typing import Any, Optional, Union, Callable, List, Dict

from absl import flags
-import numpy as onp
+import numpy as np
from tabulate import tabulate

from jax.util import safe_zip
@@ -59,7 +59,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
  if iters is None:
    warmup = 1
  else:
-    warmup = onp.clip(1, iters // 10, 10)
+    warmup = np.clip(1, iters // 10, 10)
  for _ in range(warmup):
    f()

@@ -73,7 +73,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
    times.append(end - start)
    count += 1

-  times_arr = onp.array(times)
+  times_arr = np.array(times)
  print("---------Benchmark results for %s---------" % (name or f.__name__))
  print("mean=%f std=%f %%std=%f total=%f" %
        (times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum()))
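
For readers double-checking the warmup arithmetic above: np.clip(a, a_min, a_max) bounds its first argument, so the clipped value here is the constant 1. A quick NumPy-only check, with iteration counts assumed purely for illustration:

import numpy as np

# warmup count is 1 pushed up to at least iters // 10 and capped at 10
for iters in (5, 50, 500):
  print(iters, np.clip(1, iters // 10, 10))
# 5 1
# 50 5
# 500 10
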
136 changes: 68 additions & 68 deletions jax/api.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions jax/core.py
@@ -26,7 +26,7 @@
                    Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
                    Type, Union, cast)

-import numpy as onp
+import numpy as np

from . import dtypes
from .config import FLAGS
@@ -846,7 +846,7 @@ class UnshapedArray(AbstractValue):
  array_abstraction_level = 2

  def __init__(self, dtype, weak_type=False):
-    self.dtype = onp.dtype(dtypes.canonicalize_dtype(dtype))
+    self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
    self.weak_type = weak_type

  def __eq__(self, other):
@@ -858,7 +858,7 @@ def __ne__(self, other):

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
-    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
+    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.dtype, self.weak_type))

@@ -925,7 +925,7 @@ def __eq__(self, other):

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
-    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
+    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.shape, self.dtype, self.weak_type))

@@ -968,16 +968,16 @@ class ConcreteArray(ShapedArray):
  array_abstraction_level = 0

  def __init__(self, val, weak_type=False):
-    super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val),
+    super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
                                        weak_type=weak_type)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    self.val = val
-    assert self.dtype != onp.dtype('O')
+    assert self.dtype != np.dtype('O')

  def __eq__(self, other):
    return (type(self) is type(other) and self.dtype == other.dtype
            and self.shape == other.shape and self.weak_type == other.weak_type
-            and onp.all(self.val == other.val))
+            and np.all(self.val == other.val))

  def __hash__(self):
    return id(self.val)
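
The __hash__ comment above leans on NumPy interning its base dtype objects; a standalone check of that fact, and of the character-code alternative the comment mentions:

import numpy as np

# NumPy reuses base dtype objects, so `is` comparison holds across arrays:
assert np.zeros(3).dtype is np.zeros(4).dtype
# hash(self.dtype) is therefore cheap and consistent; the unique one-character
# type code would work as a hash key too:
assert np.dtype('float64').char == 'd'
print(hash(np.zeros(3).dtype) == hash(np.zeros(4).dtype))  # True
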
4 changes: 2 additions & 2 deletions jax/experimental/loops.py
@@ -21,7 +21,7 @@
By default, loops and control-flow in JAX are executed and inlined during tracing.
For example, in the following code the `for` loop is unrolled during JAX tracing::

-  arr = onp.zeros(5)
+  arr = np.zeros(5)
  for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
@@ -32,7 +32,7 @@
conditionals as functions, and the array updates using a functional style that
returns an updated array, e.g.::

-  arr = onp.zeros(5)
+  arr = np.zeros(5)
  def loop_body(i, acc_arr):
    arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
    return lax.cond(i % 2 == 0,
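
The docstring's example is truncated by the diff; a minimal sketch of how the full functional rewrite might look with lax.fori_loop, using today's .at[] update syntax in place of the era's ops.index_update (the branch bodies are assumptions, since the original lax.cond call is cut off):

import jax.numpy as jnp
from jax import lax

def functional_loop():
  arr = jnp.zeros(5)

  def loop_body(i, acc_arr):
    arr1 = acc_arr.at[i].add(2.)      # functional form of arr[i] += 2.
    return lax.cond(i % 2 == 0,
                    lambda a: a.at[i].add(2.),  # assumed even branch
                    lambda a: a,                # assumed odd branch
                    arr1)

  return lax.fori_loop(0, arr.shape[0], loop_body, arr)

print(functional_loop())  # [4. 2. 4. 2. 4.]
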
24 changes: 12 additions & 12 deletions jax/interpreters/batching.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import numpy as onp
+import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax
@@ -103,7 +103,7 @@ def aval(self):
      return aval
    elif type(aval) is ShapedArray:
      assert 0 <= self.batch_dim < aval.ndim
-      new_shape = tuple(onp.delete(aval.shape, self.batch_dim))
+      new_shape = tuple(np.delete(aval.shape, self.batch_dim))
      return ShapedArray(new_shape, aval.dtype)
    else:
      raise TypeError(aval)
@@ -236,7 +236,7 @@ def broadcast_batcher(prim, args, dims, **params):
    either an int indicating the batch dimension, or else `not_mapped`
    indicating no batching.
  """
-  shapes = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)}
+  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
@@ -245,25 +245,25 @@
  else:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
-    ndim = max(onp.ndim(x) for x in args)  # special-case scalar broadcasting
+    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)

def _handle_scalar_broadcasting(nd, x, d):
-  if d is not_mapped or nd == onp.ndim(x):
+  if d is not_mapped or nd == np.ndim(x):
    return x
  else:
-    return x.reshape(x.shape + (1,) * (nd - onp.ndim(x)))
+    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))

def defreducer(prim):
  primitive_batchers[prim] = partial(reducer_batcher, prim)

def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
-  axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
-  bdim_out = int(list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim))
+  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
+  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
@@ -303,10 +303,10 @@ def broadcast(x, sz, axis):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if axis is last:
-    axis = onp.ndim(x)
-  shape = list(onp.shape(x))
+    axis = np.ndim(x)
+  shape = list(np.shape(x))
  shape.insert(axis, sz)
-  broadcast_dims = tuple(onp.delete(onp.arange(len(shape)), axis))
+  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def moveaxis(x, src, dst):
@@ -315,7 +315,7 @@ def moveaxis(x, src, dst):
  if src == dst:
    return x
  src, dst = src % x.ndim, dst % x.ndim
-  perm = [i for i in range(onp.ndim(x)) if i != src]
+  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return x.transpose(perm)
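
The axis bookkeeping in reducer_batcher is the densest change here; a NumPy-only check of what the two rewritten lines compute, with example values assumed:

import numpy as np

# Batching a reduction over axes (0, 2) of a rank-3 operand: the batch
# dimension is inserted at position 1, making the operand rank 4.
axes, bdim, ndim = (0, 2), 1, 4

# Reduction axes at or after the batch dim shift right by one:
shifted = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
print(shifted)   # (0, 3), as numpy integers

# The output batch dim is bdim's position among the surviving axes:
bdim_out = int(list(np.delete(np.arange(ndim), shifted)).index(bdim))
print(bdim_out)  # 0
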
8 changes: 4 additions & 4 deletions jax/interpreters/masking.py
@@ -20,7 +20,7 @@
import string
from typing import Callable, Dict, Sequence, Union

-import numpy as onp
+import numpy as np

from .. import abstract_arrays
from .. import core, dtypes
@@ -317,7 +317,7 @@ def parse_spec(spec=''):

def _parse_dim(spec):
  if '+' in spec:
-    return onp.sum(map(_parse_dim, spec.split('+')))
+    return np.sum(map(_parse_dim, spec.split('+')))
  elif '*' in spec:
    return prod(map(_parse_dim, spec.split('*')))
  elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
@@ -383,10 +383,10 @@ def full_lower(self):

class MaskTrace(Trace):
  def pure(self, val):
-    return MaskTracer(self, val, onp.shape(val))
+    return MaskTracer(self, val, np.shape(val))

  def lift(self, val):
-    return MaskTracer(self, val, onp.shape(val))
+    return MaskTracer(self, val, np.shape(val))

  def sublift(self, val):
    return MaskTracer(self, val.val, val.polymorphic_shape)
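
A self-contained sketch of the spec grammar _parse_dim handles, restricted to the numeric cases visible above (the real parser also resolves symbolic dimension names, and its prod comes from JAX's utilities; the helper below is an assumption):

import operator
from functools import reduce

def prod(xs):
  return reduce(operator.mul, xs, 1)

def parse_dim(spec):
  # '+' binds loosest, then '*'; bare (possibly negative) digits are literals
  if '+' in spec:
    return sum(map(parse_dim, spec.split('+')))
  elif '*' in spec:
    return prod(map(parse_dim, spec.split('*')))
  elif spec.isdigit() or (spec.startswith('-') and spec[1:].isdigit()):
    return int(spec)
  raise ValueError(f"unsupported spec: {spec!r}")

print(parse_dim('2*3+1'))  # 7
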
6 changes: 3 additions & 3 deletions jax/interpreters/partial_eval.py
@@ -20,7 +20,7 @@
                    Set, Tuple, Type, Union, cast)
from weakref import ref

-import numpy as onp
+import numpy as np

from .. import core
from .. import linear_util as lu
@@ -128,7 +128,7 @@ def instantiate_const(self, tracer) -> Tracer:
    if const is None:
      return tracer
    else:
-      if type(const) in core.literalable_types and onp.shape(const) == ():
+      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)
@@ -138,7 +138,7 @@ def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    if const is None:
      return tracer
    else:
-      aval = raise_to_shaped(get_aval(const), onp.isscalar(const))
+      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
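
The two sites above use different scalar tests, and they are not interchangeable; a quick NumPy comparison of np.shape(x) == () against np.isscalar(x):

import numpy as np

for x in (3.0, np.float32(3.0), np.array(3.0), np.array([3.0])):
  print(f"{type(x).__name__:8}", np.shape(x) == (), np.isscalar(x))
# float    True True
# float32  True True
# ndarray  True False   <- 0-d array: shape () but not a scalar
# ndarray  False False
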
42 changes: 21 additions & 21 deletions jax/interpreters/pxla.py
@@ -37,7 +37,7 @@
                    Type, Union)

from absl import logging
-import numpy as onp
+import numpy as np

from ..config import flags
from .. import core
@@ -465,7 +465,7 @@ def _axis_index_bind(*, axis_name):
  nreps = dynamic_axis_env.nreps
  trace = frame.pmap_trace

-  out_aval = ShapedArray((), onp.int32)
+  out_aval = ShapedArray((), np.int32)
  out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
  eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                          dict(nreps=nreps, sizes=sizes,
@@ -476,19 +476,19 @@ def _axis_index_bind(*, axis_name):
  if not frame.soft_trace:
    return out_tracer
  else:
-    val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size)
+    val_out = out_tracer * frame.soft_size + np.arange(frame.soft_size)
    return SplitAxisTracer(frame.soft_trace, axis_name, val_out)

def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
-  div = xb.constant(c, onp.array(nreps // prod(sizes), dtype=onp.uint32))
-  mod = xb.constant(c, onp.array(sizes[-1], dtype=onp.uint32))
+  div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
+  mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
-  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32))
+  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

axis_index_p = core.Primitive('axis_index')
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(
-  lambda *args, **params: ShapedArray((), onp.int32))
+  lambda *args, **params: ShapedArray((), np.int32))
xla.translations[axis_index_p] = _axis_index_translation_rule
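
The translation rule above derives the axis index from the XLA replica id as (replica_id // div) % mod; a NumPy-only check of that arithmetic, with sizes assumed for illustration:

import numpy as np

nreps, sizes = 8, (2, 4)            # 8 replicas; outer axis 2, inner axis 4
div = nreps // int(np.prod(sizes))  # replicas per step of the innermost axis
mod = sizes[-1]                     # size of the innermost mapped axis
replica_ids = np.arange(nreps, dtype=np.uint32)
print((replica_ids // div) % mod)   # [0 1 2 3 0 1 2 3]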


@@ -587,7 +587,7 @@ def block_until_ready(self):
  def _value(self):
    if self._npy_value is None:
      self.copy_to_host_async()
-      npy_value = onp.empty(self.aval.shape, self.aval.dtype)
+      npy_value = np.empty(self.aval.shape, self.aval.dtype)
      for i in self.one_replica_buffer_indices:
        npy_value[self.indices[i]] = self.device_buffers[i].to_py()
      self._npy_value = npy_value
@@ -633,7 +633,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path

def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
-  return xb.constant(c, onp.asarray(val), canonicalize_types=canonicalize_types)
+  return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
@@ -838,7 +838,7 @@ def dynamic_fun(dummy, *args):
    # provided 1D list of devices).
    device_assignment = tree_map(lambda d: d.id, devices)
    # Convert to 2D in case it's 1D and we have > 1 partitions.
-    device_assignment = onp.array(device_assignment).reshape(
+    device_assignment = np.array(device_assignment).reshape(
        (num_global_replicas, num_partitions))
  compile_options = xb.get_compile_options(
      num_replicas=num_global_replicas,
@@ -933,7 +933,7 @@ def get_num_partitions(*partitions):
  if len(partition_specs) == 0:
    # Everything is specified as replicated (all Nones).
    return None
-  num_partitions_set = set(onp.prod(spec) for spec in partition_specs)
+  num_partitions_set = set(np.prod(spec) for spec in partition_specs)
  if len(num_partitions_set) > 1:
    raise ValueError(
        f"All partition specs must use the same number of total partitions, "
@@ -1157,7 +1157,7 @@ def _xla_shard(c, aval, axis_env, x):
    return x
  elif isinstance(aval, ShapedArray):
    dims = list(c.get_shape(x).dimensions())
-    zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
+    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index(c, axis_env)] + [zero] * (len(dims) - 1)
    return xops.Reshape(xops.DynamicSlice(x, idxs, [1] + dims[1:]), dims[1:])
  else:
@@ -1169,16 +1169,16 @@ def _xla_unshard(c, aval, axis_env, x, backend):
    return x
  elif isinstance(aval, ShapedArray):
    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
-    convert_bool = (onp.issubdtype(aval.dtype, onp.bool_)
+    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                    and xb.get_backend(backend).platform in ('cpu', 'gpu'))
    if convert_bool:
-      x = xops.ConvertElementType(x, xb.dtype_to_etype(onp.float32))
+      x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))

    xla_shape = c.get_shape(x)
    dims = list(xla_shape.dimensions())
-    padded = xops.Broadcast(xb.constant(c, onp.array(0, xla_shape.numpy_dtype())),
+    padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
                            [axis_env.sizes[-1]] + dims)
-    zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
+    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
    padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
    replica_groups_protos = xc.make_replica_groups(
Expand All @@ -1187,15 +1187,15 @@ def _xla_unshard(c, aval, axis_env, x, backend):

# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, onp.array(0, dtype=onp.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(onp.bool_))
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
return out
else:
raise TypeError((aval, c.get_shape(x)))

def _unravel_index(c, axis_env):
div = xb.constant(c, onp.array(axis_env.nreps // prod(axis_env.sizes), onp.uint32))
mod = xb.constant(c, onp.array(axis_env.sizes[-1], onp.uint32))
div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)


@@ -1278,7 +1278,7 @@ def process_primitive(self, primitive, tracers, params):
    if primitive is axis_index_p:
      dummy, = vals_in
      hard_idx = primitive.bind(dummy, **params)
-      val_out = hard_idx * params['soft_size'] + onp.arange(params['soft_size'])
+      val_out = hard_idx * params['soft_size'] + np.arange(params['soft_size'])
      return SplitAxisTracer(self, params['axis_name'], val_out)
    elif all(axis_name is not_mapped for axis_name in names_in):
      return primitive.bind(*vals_in, **params)
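
The soft-axis index construction, shared by _axis_index_bind and process_primitive above, combines a hardware replica index with per-replica lanes; a NumPy-only check with an assumed soft size:

import numpy as np

soft_size = 4              # lanes carried by each hardware replica
for hard_idx in range(2):  # two hardware replicas assumed
  print(hard_idx, hard_idx * soft_size + np.arange(soft_size))
# 0 [0 1 2 3]
# 1 [4 5 6 7]
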
