Skip to content

Commit

Permalink
Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimen…
Browse files Browse the repository at this point in the history
…tal.rnn module."

PiperOrigin-RevId: 490499003
  • Loading branch information
jax authors committed Nov 23, 2022
1 parent fe56a19 commit d1fbdbc
Show file tree
Hide file tree
Showing 15 changed files with 0 additions and 1,378 deletions.
7 changes: 0 additions & 7 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,3 @@ pytype_library(
":jax",
],
)

pytype_library(
name = "rnn",
srcs = ["experimental/rnn.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
3 changes: 0 additions & 3 deletions jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ def _parse_version(v: str) -> Tuple[int, ...]:
# branch on the Jax github.
xla_extension_version = getattr(xla_client, '_version', 0)

if xla_extension_version > 108:
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error

can_execute_with_token = (
xla_extension_version >= 89 and hasattr(
xla_client.LoadedExecutable # type: ignore
Expand Down
Loading

0 comments on commit d1fbdbc

Please sign in to comment.