Skip to content

Commit 9fff9ae

Browse files
committed
Update
1 parent 0b70244 commit 9fff9ae

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

jax/_src/xla_bridge.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
XlaBackend = xla_client.Client
6464

65+
MIN_COMPUTE_CAPABILITY = 52
6566

6667
# TODO(phawkins): Remove jax_xla_backend.
6768
_XLA_BACKEND = config.DEFINE_string(
@@ -252,6 +253,19 @@ def make_cpu_client() -> xla_client.Client:
252253
)
253254

254255

256+
def _check_cuda_compute_capability(devices_to_check):
  """Warns about every listed CUDA device that is below the supported minimum.

  Args:
    devices_to_check: iterable of device indices. Each index is passed to
      `cuda_versions.cuda_compute_capability`.

  Returns:
    None. A `RuntimeWarning` is emitted for each device whose reported value
    is lower than `MIN_COMPUTE_CAPABILITY`.
  """
  for idx in devices_to_check:
    compute_cap = cuda_versions.cuda_compute_capability(idx)
    if not compute_cap < MIN_COMPUTE_CAPABILITY:
      # Device meets the minimum; nothing to report.
      continue
    # Both values are divided by 10 for display (e.g. 52 is shown as 5.2).
    message = (
        f"Device {idx} has CUDA compute capability {compute_cap/10} which is "
        "lower than the minimum supported compute capability "
        f"{MIN_COMPUTE_CAPABILITY/10}. See "
        "https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for "
        "more details"
    )
    warnings.warn(message, RuntimeWarning)
268+
255269
def _check_cuda_versions():
256270
assert cuda_versions is not None
257271

@@ -311,15 +325,16 @@ def make_gpu_client(
311325
if visible_devices != "all":
312326
allowed_devices = {int(x) for x in visible_devices.split(",")}
313327

314-
if platform_name == "cuda":
315-
_check_cuda_versions()
316-
317328
use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
318329
num_nodes = (
319330
_MOCK_NUM_GPUS.value
320331
if use_mock_gpu_client
321332
else distributed.global_state.num_processes
322333
)
334+
if platform_name == "cuda":
335+
_check_cuda_versions()
336+
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
337+
_check_cuda_compute_capability(devices_to_check)
323338

324339
return xla_client.make_gpu_client(
325340
distributed_client=distributed.global_state.client,

jaxlib/cuda/versions.cc

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ NB_MODULE(_versions, m) {
4545
m.def("cusolver_get_version", &CusolverGetVersion);
4646
m.def("cublas_get_version", &CublasGetVersion);
4747
m.def("cusparse_get_version", &CusparseGetVersion);
48+
m.def("cuda_compute_capability", &CudaComputeCapability);
49+
m.def("cuda_device_count", &CudaDeviceCount);
4850
}
4951

5052
} // namespace

jaxlib/cuda/versions_helpers.cc

+17
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,22 @@ size_t CudnnGetVersion() {
8484
}
8585
return version;
8686
}
87+
int CudaComputeCapability(int device) {
88+
int major, minor;
89+
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
90+
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
91+
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
92+
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
93+
return major * 10 + minor;
94+
}
95+
96+
// Returns the number of devices reported by cuDeviceGetCount. The driver is
// initialized with cuInit(0) first; an error status from either call is
// reported through JAX_THROW_IF_ERROR.
int CudaDeviceCount() {
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));

  int num_devices = 0;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&num_devices)));
  return num_devices;
}
103+
87104

88105
} // namespace jax::cuda

jaxlib/cuda/versions_helpers.h

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ int CusolverGetVersion();
2929
int CublasGetVersion();
3030
int CusparseGetVersion();
3131
size_t CudnnGetVersion();
32+
int CudaComputeCapability(int);
33+
int CudaDeviceCount();
3234

3335
} // namespace jax::cuda
3436

0 commit comments

Comments
 (0)