This project is being developed as part of Google Summer of Code 2025.
Reinforcement learning (RL) training in Brax is highly performant, using JAX for efficient parallelized training. However, due to the nature of JAX's computation model, training occurs in a highly abstracted and batched manner, making it challenging to inspect agent behavior in real time. Currently, users must either wait until training finishes to evaluate policies or extract rollout data manually, which is inefficient and restricts debugging. A Brax Training Viewer would empower users to visualize the evolution of the policy while training is ongoing, bridging the gap between RL research and practical interpretability.
The Brax Training Viewer offers an interactive, real-time visualization tool for monitoring reinforcement learning (RL) policies during training, built on the official PPO training function from Brax. This package lets users synchronize training with a MuJoCo-based viewer, showcasing the effect of the actions taken by the RL policy across numerous parallel environments. Users benefit from real-time action updates, which let them track the evolution of the policy over time. A toggle to enable or disable synchronization is provided for performance, allowing users to pause visualization and temporarily speed up training.
As this tool extends the official Brax PPO training function, it ensures seamless compatibility with current reinforcement learning pipelines. This allows for easy implementation without changing existing RL workflows, making it an excellent choice for researchers and practitioners needing to visualize, debug, and analyze RL policies in Brax with minimal configuration.
- Synchronized MuJoCo Visualization with Brax Training: Extract the actions generated by the policy network at each training step and apply them to a MuJoCo simulation running in parallel.
- Support Parallel Environments: The viewer will visualize multiple parallel agents, reflecting the exact training conditions in Brax.
- Allow Enable/Disable Rendering: Since copying data from GPU to CPU slows down training, users can dynamically enable or disable this transfer, choosing between high-speed training without visual feedback and synchronized visualization at the cost of slower training (see the sketch below).
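A minimal sketch of how such a toggle might look in practice; the method name `set_sync` is an assumption used for illustration, not a confirmed part of the viewer's API:

```python
from braxviewer.WebViewer import WebViewer

viewer = WebViewer()
viewer.run()

# `set_sync` is a hypothetical toggle name used for illustration only.
viewer.set_sync(False)  # disable GPU-to-CPU copies for full-speed training
# ... run a block of training steps ...
viewer.set_sync(True)   # re-enable streaming to inspect the current policy
```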
- (Optionally) install conda to manage Python virtual environments
- (Optionally) create a virtual environment:
```bash
conda create -n test python=3.10
conda activate test
```
- `cd` to the root folder of this repo
- run `git submodule update --init --recursive` to pull the modified Brax library
- run `pip install .`
- run `pip install -r requirements.txt`
- (Optionally) install JAX with hardware acceleration: `pip install -U "jax[cuda12]"` or `pip install -U "jax[cuda11]"` or `pip install -U "jax[tpu]"`
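To verify that the accelerated backend is active, you can ask JAX which devices it sees; this is a quick sanity check, not part of the installation itself:

```python
import jax

# Prints 'gpu' or 'tpu' when hardware acceleration is active, 'cpu' otherwise.
print(jax.default_backend())
# Lists the devices JAX will run on.
print(jax.devices())
```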
- You can try the examples in the `demo/` folder. For example, run `python demo/cartpole.py`
- Open a web browser and go to `http://127.0.0.1:8000/` to see the viewer
uv is an extremely fast Python package installer and resolver, written in Rust.
- `cd` to the root folder of this repo
- run `git submodule update --init --recursive` to pull the modified Brax library
- Install uv: `pip install uv`
- Create a virtual environment: `uv venv`
- Activate the virtual environment: `source .venv/bin/activate`
- Install dependencies: `uv pip install -r requirements.txt`
- Install the project: `uv pip install .`
- You can try the examples provided in the `demo/` folder. For instance, run `python demo/training_example.py`
- Open a web browser and navigate to `http://127.0.0.1:8000/` to see the viewer.
Integrating the viewer into your existing Brax training script is straightforward. The core idea is to instantiate the viewer, run it, initialize it with your environment, and pass it to the training function.
This is the simplest use case, ideal for standard Brax environments. The following example demonstrates how to visualize a humanoid agent.
- Import necessary modules:

```python
from brax import envs
from braxviewer.WebViewer import WebViewer
from braxviewer.brax.brax.training.agents.ppo import train as ppo
```
- Create the Viewer and Environment:

```python
# Instantiate the viewer
viewer = WebViewer()

# Start the viewer server in the background
viewer.run()

# Get a standard Brax environment
env = envs.get_environment(env_name='humanoid', backend='positional')

# Initialize the viewer with the environment structure
viewer.init(env)
```
- Pass Viewer to the Training Function: Modify your `train_fn` to accept the `viewer` object.

```python
# Get the training function for the 'humanoid' environment
train_fn = ...  # Your functools.partial or direct call to ppo.train

# Pass the viewer instance to the train function
make_inference_fn, params, _ = train_fn(
    environment=env,
    viewer=viewer,
)
```
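Putting the three steps together, a minimal end-to-end script might look like the sketch below. The hyperparameter values passed to `functools.partial` are illustrative placeholders, and the sketch assumes the vendored `ppo.train` accepts the standard Brax PPO arguments plus the `viewer` keyword shown above:

```python
import functools

from brax import envs
from braxviewer.WebViewer import WebViewer
from braxviewer.brax.brax.training.agents.ppo import train as ppo

# Start the viewer server and register the environment with it
viewer = WebViewer()
viewer.run()
env = envs.get_environment(env_name='humanoid', backend='positional')
viewer.init(env)

# Illustrative placeholder hyperparameters, not tuned values
train_fn = functools.partial(
    ppo.train,
    num_timesteps=1_000_000,
    num_envs=1024,
    episode_length=1000,
)

# The viewer streams states to the browser while training runs
make_inference_fn, params, _ = train_fn(environment=env, viewer=viewer)
```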
To visualize multiple environments running in parallel, use `WebViewerBatched`. This viewer arranges the agents in a 3D grid. The following example is based on `cartpole_batched.py`.
- Import modules:

```python
from braxviewer.WebViewerBatched import WebViewerBatched
from braxviewer.brax.brax.training.agents.ppo import train as ppo
# Your custom environment class, e.g., CartPole
```
- Define Grid and Create Viewer: `WebViewerBatched` requires information about the grid layout.

```python
num_parallel_envs = 8
grid_dims = (4, 2, 1)  # 4 columns (x axis), 2 rows (y axis), 1 layer (z axis)
env_offset = (4.0, 4.0, 2.0)  # Spacing between envs in (x, y, z)

# Instantiate the batched viewer with grid info
viewer = WebViewerBatched(
    grid_dims=grid_dims,
    env_offset=env_offset,
)
viewer.run()
```
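For intuition about how `grid_dims` and `env_offset` interact, here is a small sketch that maps an environment index to a world-space position, under the assumption of row-major ordering along x, then y, then z; the actual layout used by `WebViewerBatched` may differ:

```python
def env_world_offset(index, grid_dims, env_offset):
    """Hypothetical illustration: map an env index to a grid position,
    assuming row-major ordering along x, then y, then z."""
    cols, rows, _layers = grid_dims
    x = index % cols
    y = (index // cols) % rows
    z = index // (cols * rows)
    return (x * env_offset[0], y * env_offset[1], z * env_offset[2])

# Env 5 in a (4, 2, 1) grid with (4.0, 4.0, 2.0) spacing sits at (4.0, 4.0, 0.0)
print(env_world_offset(5, (4, 2, 1), (4.0, 4.0, 2.0)))
```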
- Prepare a Concatenated Environment for Visualization: The viewer needs a single large environment definition that contains all parallel agents.

```python
# Your robot XML. Examples can be found in the official Brax repository:
# https://github.com/google/brax/tree/300b1079363894733fa1090c6bb055b881eb0ac1/brax/envs/assets
xml_model = "..."

# Create a concatenated XML string for all envs
concatenated_xml = WebViewerBatched.concatenate_envs_xml(
    xml_string=xml_model,
    num_envs=num_parallel_envs,
    grid_dims=grid_dims,
    env_offset=env_offset,
)

# Create a temporary environment instance from this big XML for initialization
env_for_visualization_init = CartPole(xml_model=concatenated_xml)

# Initialize the viewer with the concatenated environment
viewer.init(env_for_visualization_init)
```
- Train on the Original Environment: The actual training still happens on the original, single-agent environment class; Brax's training function handles the parallelization internally.

```python
import functools

# Training uses the original, single environment definition
env_for_training = CartPole(xml_model=xml_model)

train_fn = functools.partial(
    ppo.train,
    num_envs=num_parallel_envs,
    viewer=viewer,
    # ... other ppo parameters
)

train_fn(environment=env_for_training)
```
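As in the basic example, open `http://127.0.0.1:8000/` while training runs to watch all parallel agents arranged in the grid.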