pmap-ing slower with new CPU runtime #26616

Open
lockwo opened this issue Feb 19, 2025 · 0 comments

Labels
bug Something isn't working

Comments
lockwo commented Feb 19, 2025

Description

Something I noticed while using diffrax is that adaptive solvers are much slower when pmap-ing the integration with the new runtime on CPUs (pmap-ing is used instead of sharding for the reason described in #26586). I adapted the code from that issue to show that pmap-ing is also slower on the new runtime.

import os
import multiprocessing as mp

# Expose one XLA CPU device per core so pmap has multiple devices to map over.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    mp.cpu_count()
)

import jax
import jax.numpy as jnp

# Toy stand-in for an adaptive solver: an outer while loop that repeatedly
# runs a data-dependent inner while loop.
def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 2

    def inner_loop_body(state):
        t, y, theta = state
        return (t + 0.1, y + 0.1, theta)
    
    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5000
    
    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]

batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

num_devices = len(jax.devices())

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

# vmap over the per-device batch, pmap over devices.
fn = jax.jit(jax.vmap(solve))
pmap_fn = jax.pmap(fn)

# Warm-up call so compilation time is excluded from the measurement.
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

import time

start_time = time.time()
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.6f} seconds")

with jax 0.4.31:
Elapsed time: 0.002367 seconds

with jax 0.4.33:
Elapsed time: 0.015590 seconds

with jax 0.5.0:
Elapsed time: 0.018911 seconds

This example is of course trivial, but it represents the core subroutine of adaptive SDE solvers. For now this can be worked around by disabling the new CPU thunk runtime, but I'm reporting it so that it can hopefully be fixed in the future :).
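For completeness, a minimal sketch of the workaround, assuming --xla_cpu_use_thunk_runtime is still the XLA flag that toggles the new CPU runtime (check against your jaxlib version); it must be set before jax is imported:

import os

# Assumption: this flag name disables the new CPU thunk runtime.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
)

import jax  # import only after XLA_FLAGS is set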

System info (python version, jaxlib version, accelerator, etc.)

Multiple jax versions (0.4.31, 0.4.33, 0.5.0), CPU only; tested on Mac and Colab.
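In case it helps triage, the full environment details can be dumped with jax's built-in helper (available on recent jax versions):

import jax

# Prints python, jax/jaxlib, numpy versions and the visible devices.
jax.print_environment_info()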

lockwo added the bug label on Feb 19, 2025