Skip to content

Commit

Permalink
Add support for 64-bit FFTs. (jax-ml#3290)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp authored Jun 2, 2020
1 parent 3909875 commit a06b122
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "642f5a1bc191dfb96b2d7ed1cfb8f2a1515b5169b8de4381c75193cef8404b92",
strip_prefix = "tensorflow-b25fb1fe32094b60f5a53ad5f986ad65a9f05919",
sha256 = "99231c027ad22e1a82866d2e6bc60379d06d0a75793ac09b547282eb5b382d37",
strip_prefix = "tensorflow-37aaafb0c1baa7acd0607748326cc12faf556277",
urls = [
"https://github.com/tensorflow/tensorflow/archive/b25fb1fe32094b60f5a53ad5f986ad65a9f05919.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/37aaafb0c1baa7acd0607748326cc12faf556277.tar.gz",
],
)

Expand Down
13 changes: 8 additions & 5 deletions jax/lax/lax_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax.interpreters import xla
from jax.util import prod
from . import dtypes, lax
from .. import lib
from ..lib import xla_client
from ..interpreters import ad
from ..interpreters import batching
Expand All @@ -35,16 +36,18 @@
]

def _promote_to_complex(arg):
dtype = onp.result_type(arg, onp.complex64)
# XLA's FFT op only supports C64.
if dtype == onp.complex128:
dtype = dtypes.result_type(arg, onp.complex64)
# XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier.
# TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
if lib.version <= (0, 1, 47) and dtype == onp.complex128:
dtype = onp.complex64
return lax.convert_element_type(arg, dtype)

def _promote_to_real(arg):
dtype = onp.result_type(arg, onp.float64)
dtype = dtypes.result_type(arg, onp.float64)
# XLA's FFT op only supports F32.
if dtype == onp.float64:
# TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
if lib.version <= (0, 1, 47) and dtype == onp.float64:
dtype = onp.float32
return lax.convert_element_type(arg, dtype)

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.47"
__version__ = "0.1.48"
5 changes: 2 additions & 3 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
Expand All @@ -29,9 +30,7 @@


float_dtypes = [np.float32, np.float64]
# TODO(b/144573940): np.complex128 isn't supported by XLA, and the JAX
# implementation casts to complex64.
complex_dtypes = [np.complex64]
complex_dtypes = [np.complex64, np.complex128]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [np.int32, np.int64]
bool_dtypes = [np.bool_]
Expand Down

0 comments on commit a06b122

Please sign in to comment.