Skip to content

Commit

Permalink
split BUILD file, move up license files
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 18, 2018
1 parent 0dfa736 commit 9ae0f3a
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 118 deletions.
File renamed without changes.
File renamed without changes.
30 changes: 30 additions & 0 deletions examples/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
py_binary(
name = "interactive",
srcs = ["interactive.py"],
deps = ["//jax:libjax"],
)

py_library(
name = "datasets",
srcs = ["datasets.py"],
)

py_binary(
name = "mnist_classifier",
srcs = ["mnist_classifier.py"],
deps = [
":datasets",
"//jax:libjax",
],
)

py_binary(
name = "mnist_vae",
srcs = ["mnist_vae.py"],
deps = [
":datasets",
":minmax",
":stax",
"//jax:libjax",
],
)
118 changes: 0 additions & 118 deletions jax/BUILD
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
# JAX is Autograd and XLA
package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2

exports_files(["LICENSE"])

load(":build_defs.bzl", "jax_test")

package_group(name = "jax")

py_library(
name = "libjax",
srcs = glob(
Expand All @@ -29,128 +19,20 @@ py_library(
],
)

jax_test(
name = "core_test",
srcs = ["tests/core_test.py"],
shard_count = {
"cpu": 5,
},
)

jax_test(
name = "lax_test",
srcs = ["tests/lax_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
},
)

jax_test(
name = "lax_numpy_test",
srcs = ["tests/lax_numpy_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
},
)

jax_test(
name = "lax_numpy_indexing_test",
srcs = ["tests/lax_numpy_indexing_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
},
)

jax_test(
name = "lax_scipy_test",
srcs = ["tests/lax_scipy_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
},
)

jax_test(
name = "random_test",
srcs = ["tests/random_test.py"],
)

jax_test(
name = "api_test",
srcs = ["tests/api_test.py"],
)

jax_test(
name = "batching_test",
srcs = ["tests/batching_test.py"],
)

py_binary(
name = "interactive",
srcs = ["examples/interactive.py"],
deps = [":libjax"],
)

py_library(
name = "stax",
srcs = ["experimental/stax.py"],
deps = [":libjax"],
)

jax_test(
name = "stax_test",
srcs = ["tests/stax_test.py"],
deps = [":stax"],
)

py_library(
name = "minmax",
srcs = ["experimental/minmax.py"],
deps = [":libjax"],
)

jax_test(
name = "minmax_test",
srcs = ["tests/minmax_test.py"],
deps = [":minmax"],
)

py_library(
name = "lapax",
srcs = ["experimental/lapax.py"],
deps = [":libjax"],
)

jax_test(
name = "lapax_test",
srcs = ["tests/lapax_test.py"],
deps = [":lapax"],
)

py_library(
name = "datasets",
srcs = ["examples/datasets.py"],
)

py_binary(
name = "mnist_classifier",
srcs = ["examples/mnist_classifier.py"],
deps = [
":datasets",
":libjax",
],
)

py_binary(
name = "mnist_vae",
srcs = ["examples/mnist_vae.py"],
deps = [
":datasets",
":libjax",
":minmax",
":stax",
],
)
78 changes: 78 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
load(":build_defs.bzl", "jax_test")

jax_test(
name = "core_test",
srcs = ["tests/core_test.py"],
shard_count = {
"cpu": 5,
},
)

jax_test(
name = "lax_test",
srcs = ["tests/lax_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
},
)

jax_test(
name = "lax_numpy_test",
srcs = ["tests/lax_numpy_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
},
)

jax_test(
name = "lax_numpy_indexing_test",
srcs = ["tests/lax_numpy_indexing_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
},
)

jax_test(
name = "lax_scipy_test",
srcs = ["tests/lax_scipy_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
},
)

jax_test(
name = "random_test",
srcs = ["tests/random_test.py"],
)

jax_test(
name = "api_test",
srcs = ["tests/api_test.py"],
)

jax_test(
name = "batching_test",
srcs = ["tests/batching_test.py"],
)

jax_test(
name = "stax_test",
srcs = ["tests/stax_test.py"],
deps = [":stax"],
)

jax_test(
name = "minmax_test",
srcs = ["tests/minmax_test.py"],
deps = [":minmax"],
)

jax_test(
name = "lapax_test",
srcs = ["tests/lapax_test.py"],
deps = [":lapax"],
)
File renamed without changes.

0 comments on commit 9ae0f3a

Please sign in to comment.