Skip to content

Commit

Permalink
Merge pull request jax-ml#13035 from jakevdp:jnp-put
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 485075125
  • Loading branch information
jax authors committed Oct 31, 2022
2 parents 71edfc7 + 8bde3a0 commit f3ddd56
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ namespace; they are listed below.
pad
percentile
piecewise
place
poly
polyadd
polyder
Expand All @@ -318,6 +319,7 @@ namespace; they are listed below.
product
promote_types
ptp
put
quantile
r_
rad2deg
Expand Down
26 changes: 26 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4885,6 +4885,32 @@ def wrapped(*args, **kwargs):
return wrapped


@_wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def place(*args, **kwargs):
raise NotImplementedError(
"jax.numpy.place is not implemented because JAX arrays cannot be modified in-place. "
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")


@_wraps(np.put, lax_description="""
Numpy function :func:`numpy.put` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.put`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def put(*args, **kwargs):
raise NotImplementedError(
"jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")


### add method and operator overloads to arraylike classes

# We add operator overloads to DeviceArray and ShapedArray. These method and
Expand Down
18 changes: 12 additions & 6 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,18 @@ def wrap(op):
parameters = _parse_parameters(parsed.sections['Parameters'])
if extra_params:
parameters.update(_parse_extra_params(extra_params))
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip() for p, desc in parameters.items()
if (code is None or p in code.co_varnames) and p not in skip_params)
)
parameters = {p: desc for p, desc in parameters.items()
if (code is None or p in code.co_varnames)
and p not in skip_params}
if parameters:
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip()
for p, desc in parameters.items())
)
else:
del parsed.sections['Parameters']

docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,10 @@
percentile as percentile,
pi as pi,
piecewise as piecewise,
place as place,
printoptions as printoptions,
promote_types as promote_types,
put as put,
quantile as quantile,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
Expand Down

0 comments on commit f3ddd56

Please sign in to comment.