Jit-compatible MJX model randomization #2406
-
Intro

Hi! I am a master's student at the University of Tübingen, Germany, and I use MuJoCo/MJX for my research on robotic manipulation for obstacle avoidance in an RL reaching task. I have implemented my own environment, similar to what was done in the MuJoCo MJX Playground, and I train the agent with Brax's PPO implementation.

My setup

MuJoCo/MJX 3.2.7, Brax 0.12.1, Python 3.12, Ubuntu 24.04

My question

Goal

Do domain randomization on the initial state of the environment, namely the robot position and obstacle properties (position, rotation, size, shape).

What is working

Randomizing the initial robot joint positions and the obstacle position + orientation works perfectly for me. It does not change the model definition, and I can set those on episode reset with qpos, qvel, mocap_pos and mocap_quat.

What I struggle with and need help

Randomizing the obstacle size (and possibly shape, but I'll leave that out for this question) requires changing the model definition, which means I have to recompile the model. Sampling the size with jax.numpy and modifying the mjspec with it works fine, as long as I don't jit-compile the function. However, all of this would happen in the reset function of my environment, which is jit-compiled, so the sampling and recompilation would be jit-compiled as well. That makes sampling the size with jax.numpy and modifying the mjspec impossible.

Ideas I came up with:
How would you approach this randomization, and do you have any idea how I can do this? Or is it not possible at all, so that I have to pre-sample and just live with the memory footprint?

What would be my dream solution?

A jit-compatible model compilation function to dynamically change the model between episodes and use it accelerated on GPU/TPU.

Minimal model and/or code that explain my question

Here is a stripped-down example of what I am trying to achieve.

Code:

```python
import jax
import mujoco
from mujoco import mjx

rng1, rng2 = jax.random.split(jax.random.key(0))


def create_model_with_sampled_size(rng):
    # Sample the obstacle size with JAX.
    size = jax.random.uniform(rng, shape=(3,), minval=0.01, maxval=0.1)
    # Build the spec with the sampled size and compile it on the host.
    spec = mujoco.MjSpec()
    spec.worldbody.add_body(
        mocap=True,
        name="obstacle_body",
    ).add_geom(
        name="obstacle_geom",
        size=size,
    )
    model = mjx.put_model(spec.compile())
    return model


# This is possible
model = create_model_with_sampled_size(rng1)

# This is not possible: under jit, `size` is a tracer without concrete values,
# so the spec cannot be built and compiled from it.
fn = jax.jit(create_model_with_sampled_size)
model = fn(rng2)
```
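To make the failure mode above concrete, here is a minimal sketch that needs only JAX and NumPy. It assumes that the root cause is the usual one: inside `jax.jit` the sampled size is an abstract tracer, and any host-side code that needs concrete numbers (here `np.asarray`, standing in for the spec compilation) cannot consume it. The helper names `to_host` and `fails_under_jit` are made up for this illustration.

```python
import jax
import numpy as np


def to_host(rng):
    # Sample with JAX, then hand the values to host-side code.
    size = jax.random.uniform(rng, shape=(3,), minval=0.01, maxval=0.1)
    # np.asarray stands in for any host-side consumer that needs concrete values.
    return np.asarray(size)


def fails_under_jit(fn, *args):
    # Returns True if tracing fn under jit raises a tracer conversion error.
    try:
        jax.jit(fn)(*args)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


key = jax.random.PRNGKey(0)
eager_size = to_host(key)                    # works: concrete array outside jit
jit_rejected = fails_under_jit(to_host, key)  # tracing hits the host conversion
```

Outside jit the function returns a normal NumPy array; under jit the conversion is rejected during tracing, which mirrors the behavior of the example above.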
Replies: 2 comments 2 replies
-
Hello! One approach you could also consider is to simply have multiple objects in your scene already, teleport only the currently relevant one in, and leave the others far away (to keep collision checks with them cheap). Of course, this has many drawbacks. For one, it forces you to heavily discretize your distributions.
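A rough, jit-compatible sketch of this teleporting idea, using only JAX: it assumes the scene already contains a fixed number of pre-built obstacles as mocap bodies, picks one at random, samples a position for it, and parks all others at a far-away spot. The function name `sample_active_obstacle`, the parking position `FAR_AWAY` and the sampling range are assumptions for illustration, not part of MJX.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical parking spot far outside the workspace.
FAR_AWAY = jnp.array([100.0, 100.0, 100.0])


@partial(jax.jit, static_argnames=("num_obstacles",))
def sample_active_obstacle(rng, num_obstacles):
    rng_idx, rng_pos = jax.random.split(rng)
    # Pick which of the pre-built obstacles is used in this episode.
    active = jax.random.randint(rng_idx, (), 0, num_obstacles)
    # Sample a position for the active obstacle (range is an assumption).
    pos = jax.random.uniform(rng_pos, (3,), minval=-0.5, maxval=0.5)
    # The active row gets the sampled position, every other row is parked.
    is_active = jnp.arange(num_obstacles) == active
    mocap_pos = jnp.where(is_active[:, None], pos, FAR_AWAY)
    return active, mocap_pos


active, mocap_pos = sample_active_obstacle(jax.random.PRNGKey(0), num_obstacles=4)
```

The resulting `(num_obstacles, 3)` array could then be written to `mocap_pos` on episode reset, in the same way the obstacle position is already set in the question.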
-
I finally implemented a randomization function, similar to how it is done in MuJoCo Playground, which uses the parallelization of envs with MJX/Brax to randomize the model.

This solution has some limitations. The fields `bvh_aabb`, `geom_aabb` and `geom_rbound_hfield` are not batchable, but they change on recompile when I change the size of the object. However, `bvh_aabb` is restricted to MuJoCo and therefore not used in MJX, and `geom_rbound_hfield` is only accessed if an mjGEOM_HFIELD geometry is used, which I do not intend to do. Therefore, I am fine with them not being batchable. For `geom_aabb` I'm not sure because I couldn't find any usage within MJX, but I guess I am fine with it as well.

Here's a snippet of the code I used in case anyone stumbles across this discussion and is interested. Be careful: it contains a sampling function that I don't provide, and it accesses some variables of my environment.

```python
import jax
import jax.numpy as jnp
from mujoco import mjx

# env, obstacles, num_envs and rng come from my environment and are not shown here.

# Mark every model field as shared (None), then mark the randomized ones as
# batched along axis 0.
in_axes = jax.tree.map(lambda _: None, env.model)
in_axes = in_axes.tree_replace(
    {
        "stat.meaninertia": 0,
        "stat.meansize": 0,
        "stat.extent": 0,
        "body_inertia": 0,
        "dof_M0": 0,
        "geom_size": 0,
        "geom_rbound": 0,
        # Theoretically, the following fields should be replaced as well as they can
        # change, but it is not possible to vmap over them.
        # "bvh_aabb": 0,
        # "geom_aabb": 0,
        # "geom_rbound_hfield": 0,
    }
)

# Give each randomized field a leading num_envs axis, filled with the base values.
model_randomized: mjx.Model = env.model.tree_replace(
    {
        "stat.meaninertia": jnp.repeat(
            jnp.expand_dims(env.model.stat.meaninertia, 0), num_envs, axis=0
        ),
        "stat.meansize": jnp.repeat(
            jnp.expand_dims(env.model.stat.meansize, 0), num_envs, axis=0
        ),
        "stat.extent": jnp.repeat(
            jnp.expand_dims(env.model.stat.extent, 0), num_envs, axis=0
        ),
        "body_inertia": jnp.repeat(
            jnp.expand_dims(env.model.body_inertia, 0), num_envs, axis=0
        ),
        "dof_M0": jnp.repeat(
            jnp.expand_dims(env.model.dof_M0, 0), num_envs, axis=0
        ),
        "geom_size": jnp.repeat(
            jnp.expand_dims(env.model.geom_size, 0), num_envs, axis=0
        ),
        "geom_rbound": jnp.repeat(
            jnp.expand_dims(env.model.geom_rbound, 0), num_envs, axis=0
        ),
    }
)

spec = env.spec
geom = env.spec.geoms[env.obstacle_geom_ids[0]]

# Outside of jit: sample a size per env, recompile on the host, and copy the
# fields that depend on the size into that env's slot.
for idx in range(num_envs):
    rng, rng1 = jax.random.split(rng, num=2)
    size = obstacles.sample_size(
        rng1, mjx.GeomType(geom.type), env.obstacle_size_limits
    )
    geom.size = size
    mj_model = spec.compile()
    model_randomized = model_randomized.tree_replace(
        {
            "stat.meaninertia": model_randomized.stat.meaninertia.at[idx].set(
                mj_model.stat.meaninertia
            ),
            "stat.meansize": model_randomized.stat.meansize.at[idx].set(
                mj_model.stat.meansize
            ),
            "stat.extent": model_randomized.stat.extent.at[idx].set(
                mj_model.stat.extent
            ),
            "body_inertia": model_randomized.body_inertia.at[idx].set(
                mj_model.body_inertia
            ),
            "dof_M0": model_randomized.dof_M0.at[idx].set(mj_model.dof_M0),
            "geom_size": model_randomized.geom_size.at[idx].set(mj_model.geom_size),
            "geom_rbound": model_randomized.geom_rbound.at[idx].set(
                mj_model.geom_rbound
            ),
        }
    )
```

Related issue: #1607
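For readers wondering how a batched model and its `in_axes` are consumed afterwards, here is a small sketch of the general `jax.vmap` pattern. It uses a toy dictionary instead of a real `mjx.Model` so that it runs with JAX alone; the field `gravity`, the function `toy_step` and the scaling of the sizes are stand-ins invented for this example, and a real setup would pass the batched model and `in_axes` to the environment's own vmapped step/reset instead.

```python
import jax
import jax.numpy as jnp

num_envs = 4

# Toy stand-in for mjx.Model: one field to randomize, one shared field.
model = {
    "geom_size": jnp.ones((2, 3)),
    "gravity": jnp.array([0.0, 0.0, -9.81]),
}

# Same idea as in the snippet above: None = shared, 0 = batched on axis 0.
in_axes = jax.tree_util.tree_map(lambda _: None, model)
in_axes = {**in_axes, "geom_size": 0}

# Give geom_size a leading num_envs axis with a different value per env.
batched_model = {
    **model,
    "geom_size": jnp.stack(
        [model["geom_size"] * (i + 1) for i in range(num_envs)]
    ),
}


def toy_step(m, x):
    # Inside vmap, m["geom_size"] has its per-env shape (2, 3) again,
    # while m["gravity"] is the same shared array for every env.
    return x + jnp.sum(m["geom_size"]) * m["gravity"][2]


batched_step = jax.jit(jax.vmap(toy_step, in_axes=(in_axes, 0)))
out = batched_step(batched_model, jnp.zeros(num_envs))
```

Each env sees its own `geom_size`, while fields marked `None` are not copied per env, which is what keeps the memory footprint limited to the randomized fields.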