Skip to content

Commit

Permalink
PR Response Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Sep 23, 2019
1 parent c312729 commit 3d21393
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
10 changes: 4 additions & 6 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,15 +1235,13 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False):
if (onp.issubdtype(lax._dtype(a), onp.bool_) or
onp.issubdtype(lax._dtype(a), onp.integer)):
return mean(a, axis, dtype, out, keepdims)
# Check and Count all non-NaN values
if dtype is None:
dtype = lax._dtype(a)
nan_mask = logical_not(isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims)
normalizer = lax.convert_element_type(normalizer, dtype)
if dtype is None:
dtype = lax._dtype(a)
#Perform mean calculation
td = true_divide(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer)
return lax.convert_element_type(td, dtype)
td = lax.divide(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer)
return td

def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
squash_nan=False):
Expand Down
16 changes: 8 additions & 8 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,14 +380,14 @@ def rand_some_nan():

def rand(shape, dtype):
"""The random sampler function."""
if not onp.issubdtype(dtype, onp.floating):
# only float types have inf
return base_rand(shape, dtype)

if onp.issubdtype(dtype, onp.complexfloating):
base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype
return rand(shape, base_dtype) + 1j * rand(shape, base_dtype)

if not onp.issubdtype(dtype, onp.floating):
# only float types have inf
return base_rand(shape, dtype)

dims = _dims_of_shape(shape)
nan_flips = rng.rand(*dims) < 0.1

Expand All @@ -405,14 +405,14 @@ def rand_some_inf_and_nan():

def rand(shape, dtype):
"""The random sampler function."""
if not onp.issubdtype(dtype, onp.floating):
# only float types have inf
return base_rand(shape, dtype)

if onp.issubdtype(dtype, onp.complexfloating):
base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype
return rand(shape, base_dtype) + 1j * rand(shape, base_dtype)

if not onp.issubdtype(dtype, onp.floating):
# only float types have inf
return base_rand(shape, dtype)

dims = _dims_of_shape(shape)
posinf_flips = rng.rand(*dims) < 0.1
neginf_flips = rng.rand(*dims) < 0.1
Expand Down

0 comments on commit 3d21393

Please sign in to comment.