Skip to content

Commit

Permalink
Simplify nanmean with logical not
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Sep 21, 2019
1 parent 17e5783 commit 12de814
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,26 @@ def nan_reduction(a, axis=None, out=None, keepdims=False, **kwargs):
nansum = _make_nan_reduction(onp.nansum, sum, 0, nan_if_all_nan=False)
nanprod = _make_nan_reduction(onp.nanprod, prod, 1, nan_if_all_nan=False)

<<<<<<< HEAD
=======
@_wraps(onp.nanmean)
def nanmean(a, axis=None, dtype=None, out=None, keepdims=False):
if out is not None:
raise ValueError("nanmean does not support the `out` argument.")
# Check and Count all non-NaN values
nan_mask = logical_not(isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims)
normalizer = lax.convert_element_type(normalizer, dtype)
#Perform mean calculation
if dtype is None:
if (onp.issubdtype(lax._dtype(a), onp.bool_) or
onp.issubdtype(lax._dtype(a), onp.integer)):
dtype = xla_bridge.canonicalize_dtype(onp.float64)
else:
dtype = lax._dtype(a)
td = true_divide(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer)
return lax.convert_element_type(td, dtype)
>>>>>>> 7c6f468... Simplify nanmean with logical not

def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
squash_nan=False):
Expand Down

0 comments on commit 12de814

Please sign in to comment.