Skip to content

Commit

Permalink
Merge pull request jax-ml#15184 from jakevdp:move-median
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 519003606
  • Loading branch information
jax authors committed Mar 24, 2023
2 parents f981243 + 6f8885a commit e9bc7ee
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 310 deletions.
3 changes: 2 additions & 1 deletion jax/_src/lax/eigh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import jax
import jax._src.numpy.lax_numpy as jnp
import jax._src.numpy.linalg as jnp_linalg
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax import lax
from jax._src.lax import qdwh
Expand Down Expand Up @@ -360,7 +361,7 @@ def nearly_diagonal_case(agenda, blocks, eigenvectors):
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)

Expand Down
201 changes: 1 addition & 200 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int],
return array

stat_funcs: Dict[str, PadStatFunc] = {
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": median}
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median}

pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
pad_width_arr = np.array(pad_width)
Expand Down Expand Up @@ -4582,161 +4582,6 @@ def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -
return c


@util._wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False)

@util._wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True)

def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a, = util.promote_dtypes_inexact(a)
keepdim = []
if issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
if axis is None:
a = ravel(a)
axis = 0
elif isinstance(axis, tuple):
keepdim = list(shape(a))
nd = ndim(a)
axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis')
for ax in axis:
keepdim[ax] = 1

keep = set(range(nd)) - set(axis)
# prepare permutation
dimensions = list(range(nd))
for i, s in enumerate(sorted(keep)):
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
axis = _canonicalize_axis(-1, ndim(a))
else:
axis = _canonicalize_axis(axis, ndim(a))

q_shape = shape(q)
q_ndim = ndim(q)
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {shape(q)}")

a_shape = shape(a)

if squash_nans:
a = where(ufuncs.isnan(a), nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = reductions.sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype,
keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
a = where(reductions.any(ufuncs.isnan(a), axis=axis, keepdims=True), nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q))
q = lax.mul(q, n - 1)
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

low = lax.clamp(_lax_const(low, 0), low, n - 1)
high = lax.clamp(_lax_const(high, 0), high, n - 1)
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)

slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))

if interpolation == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
result = low_value
elif interpolation == "higher":
result = high_value
elif interpolation == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [shape(q)[0], *keepdim]
result = reshape(result, keepdim)
return lax.convert_element_type(result, a.dtype)


@partial(vectorize, excluded={0, 2, 3})
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
op = _sort_le_comparator if side == 'left' else _sort_lt_comparator
Expand Down Expand Up @@ -4859,50 +4704,6 @@ def _const(v):
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)


@util._wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("percentile", a, q)
q, = util.promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)

@util._wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
keepdims=keepdims)

@util._wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util.check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')

@util._wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util.check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')


@util._wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
Expand Down
Loading

0 comments on commit e9bc7ee

Please sign in to comment.