Skip to content

Commit

Permalink
Fix jax2tf fft test failure.
Browse files Browse the repository at this point in the history
FFT no longer lowers to a custom call on CPU

PiperOrigin-RevId: 579865297
  • Loading branch information
hawkinsp authored and jax authors committed Nov 6, 2023
1 parent 462ef16 commit 41430ac
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ def func(x):
# An old lowering, with ducc_fft. We keep it for 6 months.
data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
# We have changed the lowering for fft since we saved this data.
self.run_one_test(func, data,
expect_current_custom_calls=["dynamic_ducc_fft"])
# FFT no longer lowers to a custom call.
self.run_one_test(func, data, expect_current_custom_calls=[])

# A newer lowering, with dynamic_ducc_fft.
data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
self.run_one_test(func, data)
# FFT no longer lowers to a custom call.
self.run_one_test(func, data, expect_current_custom_calls=[])

def cholesky_input(self, shape, dtype):
a = jtu.rand_default(self.rng())(shape, dtype)
Expand Down

0 comments on commit 41430ac

Please sign in to comment.