Skip to content

Commit

Permalink
Merge pull request numpy#25711 from mhvk/pocketfft-cpp
Browse files Browse the repository at this point in the history
ENH: support float and longdouble in FFT using C++ pocketfft version
  • Loading branch information
seberg authored Feb 9, 2024
2 parents f6f3e41 + b8c020f commit 0291a81
Show file tree
Hide file tree
Showing 11 changed files with 492 additions and 2,744 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "numpy/_core/src/highway"]
path = numpy/_core/src/highway
url = https://github.com/google/highway.git
[submodule "numpy/fft/pocketfft"]
path = numpy/fft/pocketfft
url = https://github.com/mreineck/pocketfft
15 changes: 11 additions & 4 deletions doc/release/upcoming_changes/25536.new_feature.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
``out`` for `numpy.fft`
-----------------------
The various FFT routines in `numpy.fft` have gained an ``out``
argument that can be used for in-place calculations.
`numpy.fft` support for different precisions and in-place calculations
----------------------------------------------------------------------

The various FFT routines in `numpy.fft` now do their calculations natively in
float, double, or long double precision, depending on the input precision,
instead of always calculating in double precision. Hence, the calculation will
now be less precise for single and more precise for long double precision.
The data type of the output array will now be adjusted accordingly.

Furthermore, all FFT routines have gained an ``out`` argument that can be used
for in-place calculations.
80 changes: 31 additions & 49 deletions numpy/fft/_pocketfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
import warnings

from numpy.lib.array_utils import normalize_axis_index
from numpy._core import asarray, empty, zeros, swapaxes, conjugate, take, sqrt
from numpy._core import (asarray, empty, zeros, swapaxes, result_type,
conjugate, take, sqrt, reciprocal)
from . import _pocketfft_umath as pfu
from numpy._core import overrides

Expand All @@ -47,12 +48,24 @@
# divided. This replaces the original, more intuitive 'fct` parameter to avoid
# divisions by zero (or alternatively additional checks) in the case of
# zero-length axes during its computation.
def _raw_fft(a, n, axis, is_real, is_forward, inv_norm, out=None):
axis = normalize_axis_index(axis, a.ndim)
if n is None:
n = a.shape[axis]
def _raw_fft(a, n, axis, is_real, is_forward, norm, out=None):
if n < 1:
raise ValueError(f"Invalid number of FFT data points ({n}) specified.")

# Calculate the normalization factor, passing in the array dtype to
# avoid precision loss in the possible sqrt or reciprocal.
if not is_forward:
norm = _swap_direction(norm)

fct = 1/inv_norm
if norm is None or norm == "backward":
fct = 1
elif norm == "ortho":
fct = reciprocal(sqrt(n, dtype=a.real.dtype))
elif norm == "forward":
fct = reciprocal(n, dtype=a.real.dtype)
else:
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
'"ortho" or "forward".')

n_out = n
if is_real:
Expand All @@ -64,47 +77,20 @@ def _raw_fft(a, n, axis, is_real, is_forward, inv_norm, out=None):
else:
ufunc = pfu.fft if is_forward else pfu.ifft

axis = normalize_axis_index(axis, a.ndim)

if out is None:
if is_real and not is_forward: # irfft, complex in, real output.
out_dtype = result_type(a.real.dtype, 1.0)
else: # Others, complex output.
out_dtype = result_type(a.dtype, 1j)
out = empty(a.shape[:axis] + (n_out,) + a.shape[axis+1:],
dtype=complex if is_forward or not is_real else float)
dtype=out_dtype)
elif ((shape := getattr(out, "shape", None)) is not None
and (len(shape) != a.ndim or shape[axis] != n_out)):
raise ValueError("output array has wrong shape.")
# Note: for backward compatibility, we want to accept longdouble as well,
# even though it is at reduced precision. To tell the promotor that we
# want to do that, we set the signature (to the only the ufunc has).
# Then, the default casting='same_kind' will take care of the rest.
# TODO: create separate float, double, and longdouble loops.
return ufunc(a, fct, axes=[(axis,), (), (axis,)], out=out,
signature=ufunc.types[0])


def _get_forward_norm(n, norm):
if n < 1:
raise ValueError(f"Invalid number of FFT data points ({n}) specified.")

if norm is None or norm == "backward":
return 1
elif norm == "ortho":
return sqrt(n)
elif norm == "forward":
return n
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
'"ortho" or "forward".')


def _get_backward_norm(n, norm):
if n < 1:
raise ValueError(f"Invalid number of FFT data points ({n}) specified.")

if norm is None or norm == "backward":
return n
elif norm == "ortho":
return sqrt(n)
elif norm == "forward":
return 1
raise ValueError(f'Invalid norm value {norm}; should be "backward", '
'"ortho" or "forward".')
return ufunc(a, fct, axes=[(axis,), (), (axis,)], out=out)


_SWAP_DIRECTION_MAP = {"backward": "forward", None: "forward",
Expand Down Expand Up @@ -220,8 +206,7 @@ def fft(a, n=None, axis=-1, norm=None, out=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
inv_norm = _get_forward_norm(n, norm)
output = _raw_fft(a, n, axis, False, True, inv_norm, out)
output = _raw_fft(a, n, axis, False, True, norm, out)
return output


Expand Down Expand Up @@ -327,8 +312,7 @@ def ifft(a, n=None, axis=-1, norm=None, out=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, False, False, inv_norm, out=out)
output = _raw_fft(a, n, axis, False, False, norm, out=out)
return output


Expand Down Expand Up @@ -426,8 +410,7 @@ def rfft(a, n=None, axis=-1, norm=None, out=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
inv_norm = _get_forward_norm(n, norm)
output = _raw_fft(a, n, axis, True, True, inv_norm, out=out)
output = _raw_fft(a, n, axis, True, True, norm, out=out)
return output


Expand Down Expand Up @@ -536,8 +519,7 @@ def irfft(a, n=None, axis=-1, norm=None, out=None):
a = asarray(a)
if n is None:
n = (a.shape[axis] - 1) * 2
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, True, False, inv_norm, out=out)
output = _raw_fft(a, n, axis, True, False, norm, out=out)
return output


Expand Down
Loading

0 comments on commit 0291a81

Please sign in to comment.