Skip to content

Commit

Permalink
Merge pull request jax-ml#1982 from noble-ai/rfftfreq
Browse files Browse the repository at this point in the history
added rfftfreq, tests, and documentation link.
  • Loading branch information
mattjj authored Jan 11, 2020
2 parents 34ede6b + 8b6f660 commit acbd267
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ jax.numpy.fft
fft2
ifft2
fftfreq
rfftfreq

jax.numpy.linalg
----------------
Expand Down
21 changes: 21 additions & 0 deletions jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@ def fftfreq(n, d=1.0):
return k / (d * n)


@_wraps(onp.fft.rfftfreq)
def rfftfreq(n, d=1.0):
if isinstance(n, list) or isinstance(n, tuple):
raise ValueError(
"The n argument of jax.np.fft.rfftfreq only takes an int. "
"Got n = %s." % list(n))

elif isinstance(d, list) or isinstance(d, tuple):
raise ValueError(
"The d argument of jax.np.fft.rfftfreq only takes a single value. "
"Got d = %s." % list(d))

if n % 2 == 0:
k = np.arange(0, n // 2 + 1)

else:
k = np.arange(0, (n - 1) // 2 + 1)

return k / (d * n)


@_wraps(onp.fft.fftn)
def fftn(a, s=None, axes=None, norm=None):
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
Expand Down
47 changes: 46 additions & 1 deletion tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,52 @@ def testFftfreqErrors(self, n):
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\]".format(name),
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
ValueError,
"The d argument of jax.np.fft.{} only takes a single value. "
"Got d = \\[0, 1, 2\\].".format(name),
lambda: func(n=10, d=n)
)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_size={}_d={}".format(
jtu.format_shape_dtype_string([size], dtype), d),
"dtype": dtype, "size": size, "rng_factory": rng_factory, "d": d}
for rng_factory in [jtu.rand_default]
for dtype in all_dtypes
for size in [9, 10, 101, 102]
for d in [0.1, 2.]))
def testRfftfreq(self, size, d, dtype, rng_factory):
rng = rng_factory()
args_maker = lambda: (rng([size], dtype),)
np_op = np.fft.rfftfreq
onp_op = onp.fft.rfftfreq
np_fn = lambda a: np_op(size, d=d)
onp_fn = lambda a: onp_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(np_fn, args_maker(), order=1, atol=tol, rtol=tol)
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}".format(n),
"n": n}
for n in [[0, 1, 2]]))
def testRfftfreqErrors(self, n):
name = 'rfftfreq'
func = np.fft.rfftfreq
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
Expand Down

0 comments on commit acbd267

Please sign in to comment.