Skip to content

Commit

Permalink
Merge CUDA and ROCm kernel code in jaxlib.
Browse the repository at this point in the history
The kernel code for CUDA and ROCm is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
  • Loading branch information
hawkinsp authored and jax authors committed Oct 25, 2022
1 parent 621f066 commit a852710
Show file tree
Hide file tree
Showing 48 changed files with 1,920 additions and 4,854 deletions.
28 changes: 14 additions & 14 deletions build/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,25 @@ def prepare_wheel(sources_path):
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)

cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
os.makedirs(libdevice_dir)
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
rocm_dir = os.path.join(jaxlib_dir, "rocm")
if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"):
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
os.makedirs(rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir)
if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"):
copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir)
if exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"):
copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)


mlir_dir = os.path.join(jaxlib_dir, "mlir")
Expand Down
103 changes: 65 additions & 38 deletions jaxlib/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,31 @@ licenses(["notice"])

package(default_visibility = ["//:__subpackages__"])

cc_library(
name = "cuda_vendor",
hdrs = [
"//jaxlib/gpu:vendor.h",
],
defines = ["JAX_GPU_CUDA=1"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
)

cc_library(
name = "cuda_gpu_kernel_helpers",
srcs = ["cuda_gpu_kernel_helpers.cc"],
hdrs = ["cuda_gpu_kernel_helpers.h"],
srcs = [
"//jaxlib/gpu:gpu_kernel_helpers.cc",
],
hdrs = [
"//jaxlib/gpu:gpu_kernel_helpers.h",
],
copts = [
"-fexceptions",
],
features = ["-use_header_modules"],
deps = [
":cuda_vendor",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/memory",
Expand All @@ -47,10 +63,11 @@ cc_library(

cc_library(
name = "cublas_kernels",
srcs = ["cublas_kernels.cc"],
hdrs = ["cublas_kernels.h"],
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
Expand All @@ -71,31 +88,32 @@ cc_library(
)

pybind_extension(
name = "_cublas",
srcs = ["cublas.cc"],
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cublas",
module_name = "_blas",
deps = [
":cublas_kernels",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)

cc_library(
name = "cusolver_kernels",
srcs = ["cusolver_kernels.cc"],
hdrs = ["cusolver_kernels.h"],
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
Expand All @@ -108,16 +126,17 @@ cc_library(
)

pybind_extension(
name = "_cusolver",
srcs = ["cusolver.cc"],
name = "_solver",
srcs = ["//jaxlib/gpu:solver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cusolver",
module_name = "_solver",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusolver_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
Expand All @@ -131,10 +150,11 @@ pybind_extension(

cc_library(
name = "cusparse_kernels",
srcs = ["cusparse_kernels.cc"],
hdrs = ["cusparse_kernels.h"],
srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
Expand All @@ -148,16 +168,17 @@ cc_library(
)

pybind_extension(
name = "_cusparse",
srcs = ["cusparse.cc"],
name = "_sparse",
srcs = ["//jaxlib/gpu:sparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cusparse",
module_name = "_sparse",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusparse_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
Expand All @@ -179,12 +200,13 @@ pybind_extension(
cc_library(
name = "cuda_lu_pivot_kernels",
srcs = [
"cuda_lu_pivot_kernels.cc",
"//jaxlib/gpu:lu_pivot_kernels.cc",
],
hdrs = ["cuda_lu_pivot_kernels.h"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
Expand All @@ -194,30 +216,32 @@ cc_library(
cuda_library(
name = "cuda_lu_pivot_kernels_impl",
srcs = [
"cuda_lu_pivot_kernels.cu.cc",
"//jaxlib/gpu:lu_pivot_kernels.cu.cc",
],
hdrs = ["cuda_lu_pivot_kernels.h"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)

pybind_extension(
name = "_cuda_linalg",
srcs = ["cuda_linalg.cc"],
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cuda_linalg",
module_name = "_linalg",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels",
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@local_config_cuda//cuda:cuda_headers",
Expand All @@ -228,12 +252,13 @@ pybind_extension(
cc_library(
name = "cuda_prng_kernels",
srcs = [
"cuda_prng_kernels.cc",
"//jaxlib/gpu:prng_kernels.cc",
],
hdrs = ["cuda_prng_kernels.h"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
Expand All @@ -243,26 +268,27 @@ cc_library(
cuda_library(
name = "cuda_prng_kernels_impl",
srcs = [
"cuda_prng_kernels.cu.cc",
"//jaxlib/gpu:prng_kernels.cu.cc",
],
hdrs = ["cuda_prng_kernels.h"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)

pybind_extension(
name = "_cuda_prng",
srcs = ["cuda_prng.cc"],
name = "_prng",
srcs = ["//jaxlib/gpu:prng.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cuda_prng",
module_name = "_prng",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels",
Expand All @@ -275,12 +301,13 @@ pybind_extension(

cc_library(
name = "cuda_gpu_kernels",
srcs = ["cuda_gpu_kernels.cc"],
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
visibility = ["//visibility:public"],
deps = [
":cublas_kernels",
":cuda_lu_pivot_kernels",
":cuda_prng_kernels",
":cuda_vendor",
":cusolver_kernels",
":cusparse_kernels",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
Expand All @@ -291,10 +318,10 @@ cc_library(
py_library(
name = "cuda_gpu_support",
deps = [
":_cublas",
":_cuda_linalg",
":_cuda_prng",
":_cusolver",
":_cusparse",
":_blas",
":_linalg",
":_prng",
":_solver",
":_sparse",
],
)
Loading

0 comments on commit a852710

Please sign in to comment.