
Specify CPU backend after JAX loaded #27159

Answered by selamw1
cool-RR asked this question in Q&A

Even after JAX has been imported, you can still make it run computations on the CPU with either of the following methods:

  1. Set a Default Device for JAX Operations
  • You can set the default device to the CPU by updating the jax_default_device configuration option:
import jax
import jax.numpy as jnp

# Make the first CPU device the default for new arrays and computations
jax.config.update("jax_default_device", jax.devices("cpu")[0])

# Create an array and perform a computation on the CPU
x = jnp.ones((3, 3))
y = jnp.linalg.inv(x + jnp.eye(3))
print(y.device)
# TFRT_CPU_0
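The default-device setting should also apply to jitted functions whose inputs have not been explicitly placed on another device. The sketch below assumes a recent JAX release (0.4.x or later, where Array.devices() returns a set of devices); the helper name solve is hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
# Route new arrays and uncommitted computations to the CPU from here on.
jax.config.update("jax_default_device", cpu)

@jax.jit
def solve(a):
    # Hypothetical helper: invert (a + I); compiled for and run on the default device.
    return jnp.linalg.inv(a + jnp.eye(a.shape[0]))

y = solve(jnp.ones((3, 3)))
print(y.devices())  # expected: a set containing only the CPU device
```

For a 3×3 all-ones input, the result is I - a/4, i.e. 0.75 on the diagonal and -0.25 elsewhere, which makes it easy to sanity-check the numbers as well as the placement.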
  2. Manually Specify CPU for Computation
  • If you want to run specific computations on the CPU while keeping the GPU as the default for everything else, you can explicitly place arrays and computations on the CPU:
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]
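The original snippet is cut off here, so the following is only a sketch of two common ways to pin specific work to the CPU without changing the global default: committing an array with jax.device_put, and scoping the default device with the jax.default_device context manager. It assumes JAX 0.4.x or later; outside these scopes, a GPU (if present) remains the default.

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

# Option A: commit the input to the CPU. Operations on committed arrays run on
# the device they are committed to; uncommitted operands (like jnp.eye here)
# are moved to that device automatically.
x = jax.device_put(jnp.ones((3, 3)), cpu_device)
y = jnp.linalg.inv(x + jnp.eye(3))

# Option B: make the CPU the default device only inside this block.
with jax.default_device(cpu_device):
    z = jnp.linalg.inv(jnp.ones((3, 3)) + jnp.eye(3))

print(y.devices(), z.devices())
```

Option A suits the case where only a few arrays should live on the CPU; Option B suits a whole block of code that should run there.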

Replies: 1 comment 1 reply

1 reply
@cool-RR

Answer selected by cool-RR
Category
Q&A
2 participants