Skip to content

Commit

Permalink
Remove CPU test variant.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669359594
  • Loading branch information
Google-ML-Automation authored and jax authors committed Aug 30, 2024
1 parent 164b884 commit 2f3990d
Showing 1 changed file with 0 additions and 14 deletions.
14 changes: 0 additions & 14 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ jax_test(
name = "api_test",
srcs = ["api_test.py"],
shard_count = 10,
tags = ["test_cpu_thunks"],
)

jax_test(
Expand Down Expand Up @@ -339,7 +338,6 @@ jax_test(
jax_test(
name = "infeed_test",
srcs = ["infeed_test.py"],
tags = ["test_cpu_thunks"],
deps = [
"//jax:experimental_host_callback",
],
Expand All @@ -349,7 +347,6 @@ jax_test(
name = "jax_jit_test",
srcs = ["jax_jit_test.py"],
main = "jax_jit_test.py",
tags = ["test_cpu_thunks"],
)

py_test(
Expand Down Expand Up @@ -440,7 +437,6 @@ jax_test(
"gpu": 30,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
)

jax_test(
Expand All @@ -451,7 +447,6 @@ jax_test(
"gpu": 20,
"tpu": 20,
},
tags = ["test_cpu_thunks"],
)

jax_test(
Expand All @@ -472,7 +467,6 @@ jax_test(
"gpu": 10,
"tpu": 10,
},
tags = ["test_cpu_thunks"],
)

jax_test(
Expand All @@ -483,13 +477,11 @@ jax_test(
"gpu": 10,
"tpu": 10,
},
tags = ["test_cpu_thunks"],
)

jax_test(
name = "lax_numpy_vectorize_test",
srcs = ["lax_numpy_vectorize_test.py"],
tags = ["test_cpu_thunks"],
)

jax_test(
Expand Down Expand Up @@ -554,7 +546,6 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = [
"//jax:internal_test_util",
"//jax:lax_reference",
Expand Down Expand Up @@ -584,7 +575,6 @@ jax_test(
"gpu": 40,
"tpu": 20,
},
tags = ["test_cpu_thunks"],
)

jax_test(
Expand All @@ -595,7 +585,6 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)

Expand All @@ -607,7 +596,6 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)

Expand Down Expand Up @@ -650,7 +638,6 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
)

jax_test(
Expand Down Expand Up @@ -1169,7 +1156,6 @@ py_test(
jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
tags = ["test_cpu_thunks"],
deps = [
"//jax:compilation_cache_internal",
"//jax:compiler",
Expand Down

0 comments on commit 2f3990d

Please sign in to comment.