Skip to content

Commit

Permalink
Fix benchmark script
Browse files Browse the repository at this point in the history
  • Loading branch information
StafaH committed Nov 15, 2024
1 parent 6723ce8 commit b0d0332
Showing 1 changed file with 54 additions and 23 deletions.
77 changes: 54 additions & 23 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""Variation of mjx testspeed for benchmarking batch renderings"""

import os
import time
from typing import Tuple
import jax
import jax.numpy as jp
import numpy as np
import mujoco
from mujoco import mjx
from mujoco.mjx._src.test_util import _measure
from mujoco.mjx._src import io
from mujoco.mjx._src import forward
from etils import epath
import functools

from madrona_mjx.renderer import BatchRenderer

Expand Down Expand Up @@ -47,7 +51,6 @@ def load_model(path: str):
model = mujoco.MjModel.from_xml_string(xml, assets)
return model


def benchmark(
m: mujoco.MjModel,
nstep: int = 1000,
Expand Down Expand Up @@ -81,37 +84,65 @@ def benchmark(
np.array([0, 1, 2]), False, args.use_raytracer,
None)

@jax.pmap
def init(key):
key = jax.random.split(key, batch_size // jax.device_count())
rng = jax.random.PRNGKey(seed=2)
rng, *key = jax.random.split(rng, args.num_worlds + 1)

def dr(sys, rng):
"""Randomizes the mjx.Model."""
@jax.vmap
def random_init(key):
d = io.make_data(m)
qvel = 0.01 * jax.random.normal(key, shape=(m.nv,))
d = d.replace(qvel=qvel)
rt, rgb, depth = renderer.init(mjx_data)
return d, rt
def rand(rng):
rng, color_rng = jax.random.split(rng, 2)
new_color = jax.random.uniform(color_rng, (1,), minval=0.0, maxval=0.4)
geom_rgba = sys.geom_rgba.at[0, 0:1].set(new_color)
geom_matid = sys.geom_matid.at[:].set(-1)
geom_matid = geom_matid.at[0].set(-2)

return geom_rgba, geom_matid

geom_rgba, geom_matid = rand(rng)

return random_init(key)
in_axes = jax.tree_util.tree_map(lambda x: None, sys)
in_axes = in_axes.tree_replace({
'geom_rgba': 0,
'geom_matid': 0,
})

key = jax.random.split(jax.random.key(0), jax.device_count())
d, rt = init(key)
jax.block_until_ready(d)
sys = sys.tree_replace({
'geom_rgba': geom_rgba,
'geom_matid': geom_matid,
})

@jax.pmap
def unroll(d):
return sys, in_axes

randomization_rng = jax.random.split(rng, args.num_worlds)
v_randomization_fn = functools.partial(dr, rng=randomization_rng)

v_mjx_model, v_in_axes = v_randomization_fn(m)

def init(rng, sys):
def init_(rng, sys):
data = mjx.make_data(sys)
data.replace(qpos=0.01 * jax.random.uniform(rng, shape=(sys.nq,)))
data = mjx.forward(sys, data)
render_token, rgb, depth = renderer.init(data, sys)
return data, render_token, rgb, depth
return jax.vmap(init_, in_axes=[0, v_in_axes])(rng, sys)

v_mjx_data, render_token, rgb, depth = init(jp.asarray(key), v_mjx_model)
jax.block_until_ready(v_mjx_data)

@jax.jit
def unroll(v_data):
@jax.vmap
def step(d, _):
d = forward.step(m, d)
_, rgb, depth = renderer.render(rt, d)
return d, rgb, None

d, rgb, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)
_, rgb, depth = renderer.render(render_token, d)
return d, None
d, _ = jax.lax.scan(step, v_data, None, length=nstep, unroll=unroll_steps)

return d, rgb
return d

jit_time, run_time = _measure(unroll, d)
jit_time, run_time = _measure(unroll, v_mjx_data)
steps = nstep * batch_size

return jit_time, run_time, steps
Expand All @@ -122,7 +153,7 @@ def step(d, _):
model = load_model(args.mjcf)

print(f'Rolling out {args.nstep} steps at dt = {model.opt.timestep:.3f}...')
jit_time, run_time, steps = mjx.benchmark(
jit_time, run_time, steps = benchmark(
model,
args.nstep,
args.num_worlds,
Expand Down

0 comments on commit b0d0332

Please sign in to comment.