Skip to content

Commit

Permalink
Add np.polysub (jax-ml#3319)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush-1506 authored Jun 5, 2020
1 parent 841f21f commit 29740de
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ Not every function in NumPy is implemented; contributions are welcome!
percentile
polyadd
polymul
polysub
polyval
power
positive
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number, numpy_version,
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
pi, polyadd, polymul, polyval, positive, power, prod, product, promote_types, ptp, quantile,
pi, polyadd, polymul, polysub, polyval, positive, power, prod, product, promote_types, ptp, quantile,
rad2deg, radians, ravel, real, reciprocal, remainder, repeat, reshape,
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,
save, savez, searchsorted, select, set_printoptions, shape, sign, signbit,
Expand Down
11 changes: 11 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2645,6 +2645,17 @@ def polymul(a1, a2, *, trim_leading_zeros=False):
val = convolve(a1, a2, mode='full')
return val

@_wraps(np.polysub)
def polysub(a, b):
a = asarray(a)
b = asarray(b)

if b.shape[0] <= a.shape[0]:
return a.at[-b.shape[0]:].add(-b)
else:
return -b.at[-a.shape[0]:].add(-a)


@_wraps(np.append)
def append(arr, values, axis=None):
if axis is None:
Expand Down
17 changes: 17 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,23 @@ def testPolyAdd(self, shape, dtype):
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "a_shape={} , b_shape={}".format(
jtu.format_shape_dtype_string(a_shape, dtype),
jtu.format_shape_dtype_string(b_shape, dtype)),
"dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape}
for dtype in default_dtypes
for a_shape in one_dim_array_shapes
for b_shape in one_dim_array_shapes))
def testPolySub(self, a_shape, b_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1, arg2: np.polysub(arg1, arg2)
jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2)
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
Expand Down

0 comments on commit 29740de

Please sign in to comment.