Merge pull request jax-ml#1944 from clemisch/master
Implement numpy.gradient
mattjj authored Jan 9, 2020
2 parents ab25825 + 9ef9b38 commit 327dca8
Showing 3 changed files with 74 additions and 1 deletion.
11 changes: 11 additions & 0 deletions jax/lax/lax.py
@@ -1331,6 +1331,17 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
  limit_indices = list(operand.shape)
  strides = [1] * operand.ndim

  # translate `None`
  len_axis = operand.shape[axis]
  start_index = start_index if start_index is not None else 0
  limit_index = limit_index if limit_index is not None else len_axis

  # translate negative indices
  if start_index < 0:
    start_index = start_index + len_axis
  if limit_index < 0:
    limit_index = limit_index + len_axis

  axis = int(axis)
  start_indices[axis] = int(start_index)
  limit_indices[axis] = int(limit_index)
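
For illustration, a minimal sketch (not part of the diff) of what the new `None` / negative-index handling in `lax.slice_in_dim` allows, assuming a JAX build that includes this change; the results should match ordinary Python slicing:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

# `None` bounds now mean "from the start" / "to the end", and negative bounds
# count from the end of the axis, mirroring Python slicing semantics.
assert (lax.slice_in_dim(x, None, None) == x).all()      # x[:]
assert (lax.slice_in_dim(x, -3, None) == x[-3:]).all()   # x[-3:]
assert (lax.slice_in_dim(x, 1, -1) == x[1:-1]).all()     # x[1:-1]
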
43 changes: 43 additions & 0 deletions jax/numpy/lax_numpy.py
@@ -875,6 +875,49 @@ def diff(a, n=1, axis=-1,):
  return a


@partial(jit, static_argnums=1)
def _gradient(a, axis):
  def gradient_along_axis(a, axis):
    sliced = partial(lax.slice_in_dim, a, axis=axis)
    a_grad = concatenate((
      sliced(1, 2) - sliced(0, 1),
      (sliced(2, None) - sliced(0, -2)) * 0.5,
      sliced(-1, None) - sliced(-2, -1),
    ), axis)
    return a_grad

  if axis is None:
    axis = range(a.ndim)
  else:
    if isinstance(axis, int):
      axis = (axis,)
    if not isinstance(axis, tuple) and not isinstance(axis, list):
      raise ValueError("Give `axis` either as int or iterable")
    axis = [_canonicalize_axis(i, a.ndim) for i in axis]

  if min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
    raise ValueError(
        "Shape of array too small to calculate a numerical gradient")

  # TODO: use jax.lax loop tools if possible
  a_grad = [gradient_along_axis(a, ax) for ax in axis]

  if len(axis) == 1:
    a_grad = a_grad[0]

  return a_grad


@_wraps(onp.gradient)
def gradient(a, *args, **kwargs):
  axis = kwargs.pop("axis", None)
  if not len(args) == 0:
    raise ValueError("*args (sample distances) not implemented")
  if not len(kwargs) == 0:
    raise ValueError("Only `axis` keyword is implemented")
  return _gradient(a, axis)


@_wraps(onp.isrealobj)
def isrealobj(x):
  return not iscomplexobj(x)
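
For reference, the scheme built by `gradient_along_axis` matches `numpy.gradient` with unit spacing: a forward difference at the first point, second-order central differences (f[i+1] - f[i-1]) / 2 in the interior, and a backward difference at the last point. A small sketch (assuming a JAX build with this change) comparing the new wrapper against NumPy:

import numpy as onp
import jax.numpy as jnp

y = onp.array([1., 2., 4., 7., 11.], dtype=onp.float32)
print(onp.gradient(y))               # [1.  1.5 2.5 3.5 4. ]
print(jnp.gradient(jnp.asarray(y)))  # expected to match the NumPy result

# Note: `axis` is a static argument of the jitted helper (static_argnums=1),
# so it must be a concrete int or tuple; sample spacings (*args) are not
# supported by this version.

Passing a tuple of axes, e.g. axis=(0, 1) on a 2-D array, returns a list with one gradient array per axis, as in NumPy.
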
21 changes: 20 additions & 1 deletion tests/lax_numpy_test.py
@@ -2567,6 +2567,26 @@ def testPrecision(self):
        partial(lnp.inner, precision=HIGHEST),
        ones_1d, ones_1d)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": ("_shape={}_axis={}_dtype={}").format(shape, axis, dtype),
           "shape": shape,
           "axis": axis,
           "dtype": dtype, "rng_factory": rng_factory}
          for shape in [(10,), (10, 15), (10, 15, 20)]
          for _num_axes in range(len(shape))
          for axis in itertools.combinations(range(len(shape)), _num_axes)
          for dtype in inexact_dtypes
          for rng_factory in [jtu.rand_default]))
  def testGradient(self, shape, axis, dtype, rng_factory):
    rng = rng_factory()
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_fun = lambda y: lnp.gradient(y, axis=axis)
    onp_fun = lambda y: onp.gradient(y, axis=axis)
    self._CheckAgainstNumpy(
        onp_fun, lnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  def testZerosShapeErrors(self):
    # see https://github.com/google/jax/issues/1822
    self.assertRaisesRegex(
@@ -2579,7 +2599,6 @@ def testZerosShapeErrors(self):
"If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.",
lambda: api.jit(lnp.zeros)(2))


# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.

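
The parameterized test above sweeps shapes, axis combinations, and inexact dtypes through the JAX test harness. Outside that harness, a rough standalone equivalent (a sketch only, skipping the degenerate empty-axis case) could look like:

import itertools
import numpy as onp
import jax.numpy as jnp

rng = onp.random.RandomState(0)
for shape in [(10,), (10, 15), (10, 15, 20)]:
  for num_axes in range(1, len(shape) + 1):
    for axis in itertools.combinations(range(len(shape)), num_axes):
      x = rng.randn(*shape).astype(onp.float32)
      expected = onp.gradient(x, axis=axis)
      actual = jnp.gradient(jnp.asarray(x), axis=axis)
      if num_axes == 1:  # both return a bare array for a single axis
        expected, actual = [expected], [actual]
      for e, a in zip(expected, actual):
        assert onp.allclose(e, a, atol=1e-4), (shape, axis)
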
