Skip to content

Commit

Permalink
Merge pull request jax-ml#10266 from jakevdp:ndarray-tile
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 442062659
  • Loading branch information
jax authors committed Apr 15, 2022
2 parents 6eec758 + be5c84d commit a303e4b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->


## jax 0.3.8 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
* Changes:
* The `DeviceArray.tile()` method is deprecated, because numpy arrays do not have a
`tile()` method. As a replacement for this, use {func}`jax.numpy.tile`
({jax-issue}`#10266`).

## jax 0.3.7 (April 15, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7).
Expand Down
16 changes: 14 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import builtins
import collections
from functools import partial
from functools import partial, wraps as functools_wraps
import operator
import types
from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union
Expand Down Expand Up @@ -4605,7 +4605,15 @@ def _operator_round(number, ndigits=None):
_diff_methods = ["choose", "conj", "conjugate", "copy", "cumprod", "cumsum",
"diagonal", "dot", "max", "mean", "min", "prod", "ptp",
"ravel", "repeat", "sort", "squeeze", "std", "sum",
"swapaxes", "take", "tile", "trace", "var"]
"swapaxes", "take", "trace", "var"]


def _deprecate_function(fun, msg):
@functools_wraps(fun)
def wrapped(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return fun(*args, **kwargs)
return wrapped

# These methods are mentioned explicitly by nondiff_methods, so we create
# _not_implemented implementations of them here rather than in __init__.py.
Expand Down Expand Up @@ -4936,6 +4944,8 @@ def _set_shaped_array_attributes(shaped_array):
# Forward methods and properties using core.{aval_method, aval_property}:
for method_name in _nondiff_methods + _diff_methods:
setattr(shaped_array, method_name, core.aval_method(globals()[method_name]))
# TODO(jakevdp): remove tile method after August 2022
setattr(shaped_array, "tile", core.aval_method(_deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead.")))
setattr(shaped_array, "reshape", core.aval_method(_reshape))
setattr(shaped_array, "transpose", core.aval_method(_transpose))
setattr(shaped_array, "flatten", core.aval_method(ravel))
Expand Down Expand Up @@ -4967,6 +4977,8 @@ def _set_device_array_base_attributes(device_array):
setattr(device_array, "__{}__".format(operator_name), function)
for method_name in _nondiff_methods + _diff_methods:
setattr(device_array, method_name, globals()[method_name])
# TODO(jakevdp): remove tile method after August 2022
setattr(device_array, "tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
setattr(device_array, "reshape", _reshape)
setattr(device_array, "transpose", _transpose)
setattr(device_array, "flatten", ravel)
Expand Down

0 comments on commit a303e4b

Please sign in to comment.