Is there a way to tell JAX to only use the CPU and not the GPU? I've previously used |
Answered by
selamw1
Mar 17, 2025
Even after JAX has been imported, you can still force it to use the CPU using the following methods:
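# Set the default device globally through the config
import jax
import jax.numpy as jnp  # needed for the jnp calls below

jax.config.update("jax_default_device", jax.devices("cpu")[0])

# Create an array and perform a computation on the CPU
x = jnp.ones((3, 3))
y = jnp.linalg.inv(x + jnp.eye(3))
print(y.device)
# TFRT_CPU_0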

# Alternatively, place data explicitly and scope the default device
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

# Explicitly place an array on the CPU
x = jax.device_put(jnp.array([1, 2, 3]), cpu_device)

# Run computation on CPU
with jax.default_device(cpu_device):
    y = jnp.dot(x, x)

print(y.device)
# TFRT_CPU_0
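
If you want to confirm where a result actually ended up, a small wrapper around the context-manager approach can help. The following is only a sketch: `run_on_cpu` and `platforms_of` are hypothetical helper names, not part of JAX, and it assumes a JAX version where arrays expose a `devices()` method and devices expose a `platform` attribute.

```python
import jax
import jax.numpy as jnp


def run_on_cpu(fn, *args):
    """Hypothetical helper: call fn with the first CPU device as the default device."""
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        return fn(*args)


def platforms_of(array):
    """Hypothetical helper: return the platform names of the devices holding `array`."""
    return {device.platform for device in array.devices()}


# Arrays created inside fn are uncommitted, so they follow the scoped default device.
result = run_on_cpu(lambda n: jnp.ones((n, n)) @ jnp.eye(n), 3)
print(platforms_of(result))
```

Checking `platform` rather than the printed device name avoids depending on names such as `TFRT_CPU_0`, which may differ between JAX versions.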
Answer selected by
cool-RR