Skip to content

Commit

Permalink
Add some types to jax.random and jnp.ndarray.
Browse files Browse the repository at this point in the history
  • Loading branch information
aslanides committed Apr 12, 2020
1 parent f610867 commit c06fe56
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ docs/notebooks/.ipynb_checkpoints/
docs/_autosummary
.idea
.vscode
venv/
jax.iml
24 changes: 14 additions & 10 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@
import re
import string
import types
from typing import Callable
from typing import Sequence
import warnings

import numpy as onp
import opt_einsum

from jax import jit, device_put
from .. import core
from .. import dtypes
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from ..config import flags
from ..interpreters.xla import DeviceArray
from .. import lax
from ..util import partial, get_module_functions, unzip2, prod as _prod, subvals
from ..lib import pytree
from ..lib import xla_client
from jax import core
from jax import dtypes
from jax.abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from jax.config import flags
from jax.interpreters.xla import DeviceArray
from jax import lax
from jax.util import partial, get_module_functions, unzip2, prod as _prod, subvals
from jax.lib import pytree
from jax.lib import xla_client

FLAGS = flags.FLAGS
flags.DEFINE_enum(
Expand Down Expand Up @@ -100,6 +100,10 @@ def __instancecheck__(self, instance):
return isinstance(instance, _arraylike_types)

class ndarray(onp.ndarray, metaclass=_ArrayMeta):
dtype: onp.dtype
shape: Sequence[int]
size: int

def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
Expand Down
52 changes: 37 additions & 15 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@


from functools import partial
import itertools
from typing import Optional, Sequence, Union

import numpy as onp

from . import lax
from . import numpy as np
from . import tree_util
from . import dtypes
from .api import jit, vmap
from .numpy.lax_numpy import _constant_like, asarray, stack
from .numpy.lax_numpy import _constant_like, asarray
from jax.lib import xla_bridge
from jax.lib import cuda_prng
from jax import core
Expand All @@ -40,12 +39,11 @@
from jax.scipy.special import logit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.util import prod


def PRNGKey(seed):
def PRNGKey(seed: int) -> np.ndarray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
Args:
Expand All @@ -68,7 +66,7 @@ def PRNGKey(seed):
k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF))
return lax.concatenate([k1, k2], 0)

def _is_prng_key(key):
def _is_prng_key(key: np.ndarray) -> bool:
try:
return key.shape == (2,) and key.dtype == onp.uint32
except AttributeError:
Expand Down Expand Up @@ -227,7 +225,7 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)


def split(key, num=2):
def split(key: np.ndarray, num: int = 2) -> np.ndarray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
Expand Down Expand Up @@ -302,7 +300,11 @@ def _check_shape(name, shape, *param_shapes):
raise ValueError(msg.format(name, shape_, shape))


def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
def uniform(key: np.ndarray,
shape: Sequence[int] = (),
dtype: onp.dtype = onp.float64,
minval: float = 0.,
maxval: float = 1.) -> np.ndarray:
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
Expand Down Expand Up @@ -350,7 +352,11 @@ def _uniform(key, shape, dtype, minval, maxval):
lax.reshape(floats * (maxval - minval) + minval, shape))


def randint(key, shape, minval, maxval, dtype=onp.int64):
def randint(key: np.ndarray,
shape: Sequence[int],
minval: Union[int, np.ndarray],
maxval: Union[int, np.ndarray],
dtype: onp.dtype = onp.int64):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
Expand Down Expand Up @@ -411,7 +417,7 @@ def _randint(key, shape, minval, maxval, dtype):
return lax.add(minval, lax.convert_element_type(random_offset, dtype))


def shuffle(key, x, axis=0):
def shuffle(key: np.ndarray, x: np.ndarray, axis: int = 0) -> np.ndarray:
"""Shuffle the elements of an array uniformly at random along an axis.
Args:
Expand Down Expand Up @@ -452,7 +458,9 @@ def _shuffle(key, x, axis):
return x


def normal(key, shape=(), dtype=onp.float64):
def normal(key: np.ndarray,
shape: Sequence[int] = (),
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample standard normal random values with given shape and float dtype.
Args:
Expand All @@ -478,7 +486,11 @@ def _normal(key, shape, dtype):
return onp.array(onp.sqrt(2), dtype) * lax.erf_inv(u)


def multivariate_normal(key, mean, cov, shape=None, dtype=onp.float64):
def multivariate_normal(key: np.ndarray,
mean: np.ndarray,
cov: np.ndarray,
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample multivariate normal random values with given mean and covariance.
Args:
Expand Down Expand Up @@ -528,7 +540,11 @@ def _multivariate_normal(key, mean, cov, shape, dtype):
return mean + np.tensordot(normal_samples, chol_factor, [-1, 1])


def truncated_normal(key, lower, upper, shape=None, dtype=onp.float64):
def truncated_normal(key: np.ndarray,
lower: Union[float, np.ndarray],
upper: Union[float, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample truncated standard normal random values with given shape and dtype.
Args:
Expand Down Expand Up @@ -569,7 +585,9 @@ def _truncated_normal(key, lower, upper, shape, dtype):
return sqrt2 * lax.erf_inv(a + u * (b - a))


def bernoulli(key, p=onp.float32(0.5), shape=None):
def bernoulli(key: np.ndarray,
p: np.ndarray = onp.float32(0.5),
shape: Optional[Sequence[int]] = None) -> np.ndarray:
"""Sample Bernoulli random values with given shape and mean.
Args:
Expand Down Expand Up @@ -603,7 +621,11 @@ def _bernoulli(key, p, shape):
return uniform(key, shape, lax.dtype(p)) < p


def beta(key, a, b, shape=None, dtype=onp.float64):
def beta(key: np.ndarray,
a: Union[float, np.ndarray],
b: Union[float, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample Bernoulli random values with given shape and mean.
Args:
Expand Down

0 comments on commit c06fe56

Please sign in to comment.