Skip to content

Commit

Permalink
Use pytree defined in tensorflow. (jax-ml#4087)
Browse files Browse the repository at this point in the history
It also adds some tests on the scalar C++ conversion.
  • Loading branch information
jblespiau authored Aug 18, 2020
1 parent fe69d3c commit 2ab6b42
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 929 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "f6f5e2c7900c4eb334c6070fe54fcf4b719ae334de691c89c816c9cfa7bee727",
strip_prefix = "tensorflow-9ce9e779f1510f688771090527fcd45de41691ac",
sha256 = "03dd1adcfe560a634a63bafe582632d1130376cecf5d91ab03c3b26da2d839c7",
strip_prefix = "tensorflow-a2c58558b6f98fd54a3b40097269bdf592195969",
urls = [
"https://github.com/tensorflow/tensorflow/archive/9ce9e779f1510f688771090527fcd45de41691ac.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/a2c58558b6f98fd54a3b40097269bdf592195969.tar.gz",
],
)

Expand Down
1 change: 0 additions & 1 deletion build/install_xla_in_source_tree.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ fi
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
# new location.
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
cp -f "$(rlocation __main__/jaxlib/cublas_kernels.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
Expand Down
6 changes: 5 additions & 1 deletion jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def _check_jaxlib_version():

from jaxlib import xla_client
from jaxlib import lapack
from jaxlib import pytree
if version < (0, 1, 53):
from jaxlib import pytree # pytype: disable=import-error
else:
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
from jaxlib import cusolver
try:
from jaxlib import cuda_prng
Expand Down
41 changes: 0 additions & 41 deletions jaxlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -89,47 +89,6 @@ py_library(
],
)

cc_library(
name = "pytree_lib",
srcs = ["pytree.cc"],
hdrs = ["pytree.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@pybind11",
],
)

pybind_extension(
name = "pytree",
srcs = ["pytree_extension.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "pytree",
deps = [
":pytree_lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@pybind11",
],
)

pybind_extension(
name = "cublas_kernels",
srcs = ["cublas.cc"],
Expand Down
Loading

0 comments on commit 2ab6b42

Please sign in to comment.