pal-robotics/brax_training_viewer
Brax Training Viewer for Real-Time Policy Visualization

This project was developed during Google Summer of Code 2025.

Reinforcement Learning (RL) training in Brax is fast, using JAX for efficient parallelized training. However, because of JAX's computation model, training occurs in a highly abstracted and batched manner, making it challenging to inspect agent behavior in real time. Without this tool, users must either wait until training finishes to evaluate policies or extract rollout data manually, which is inefficient and limits debugging. The Brax Training Viewer lets users visualize the evolution of the policy while training is ongoing, bridging the gap between RL research and practical interpretability.

The Brax Training Viewer is an interactive, real-time visualization tool for monitoring reinforcement learning (RL) policies during training, built on the official PPO training function from Brax. The package synchronizes training with a MuJoCo-based viewer, showing the effect of the actions taken by the RL policy across many parallel environments, so users can track the evolution of the policy over time. A toggle lets users enable or disable synchronization: pausing visualization temporarily speeds up training.

As this tool extends the official Brax PPO training function, it ensures seamless compatibility with current reinforcement learning pipelines. This allows for easy implementation without changing existing RL workflows, making it an excellent choice for researchers and practitioners needing to visualize, debug, and analyze RL policies in Brax with minimal configuration.

Outcomes

  • Synchronized MuJoCo Visualization with Brax Training: Extract the actions generated by the policy network at each training step and apply them to a MuJoCo simulation running in parallel.
  • Support for Parallel Environments: The viewer visualizes multiple parallel agents, reflecting the exact training conditions in Brax.
  • Toggleable Rendering: Copying data from the GPU to the CPU slows training, so users can dynamically enable or disable the transfer: high-speed training without visual feedback, or synchronized visualization at the cost of slower training.
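
The render-toggle idea above can be illustrated with a minimal, framework-free sketch. All names here (`sync_enabled`, `transfer_to_host`, `viewer_frames`) are illustrative placeholders, not the actual braxviewer API: a flag checked at every training step decides whether simulation state is copied off the device and forwarded to the viewer.

```python
def transfer_to_host(state, viewer_frames):
    # Stand-in for the expensive GPU -> CPU copy plus the viewer update.
    viewer_frames.append(dict(state))

def training_loop(num_steps, sync_enabled):
    state = {"qpos": 0.0}
    viewer_frames = []  # frames that actually reached the viewer
    for _ in range(num_steps):
        state["qpos"] += 0.01  # placeholder for policy + physics step
        if sync_enabled():
            transfer_to_host(state, viewer_frames)  # slow path
    return viewer_frames

# Sync off: no device-to-host copies, training runs at full speed.
frames_fast = training_loop(100, sync_enabled=lambda: False)
# Sync on: every step pays for one copy so the viewer stays in step.
frames_slow = training_loop(100, sync_enabled=lambda: True)
```

Because the flag is read inside the loop, it can be flipped at any time while training runs, which is what the viewer's enable/disable toggle exploits.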

Usage (Conda)

  • (Optional) install conda to manage virtual environments
  • (Optional) create a virtual environment: `conda create -n test python=3.10`
  • Activate it: `conda activate test`
  • `cd` to the root folder of this repo
  • Run `git submodule update --init --recursive` to pull the modified Brax library
  • Run `pip install .`
  • Run `pip install -r requirements.txt`
  • (Optional) install the hardware-accelerated version of JAX: `pip install -U "jax[cuda12]"`, `pip install -U "jax[cuda11]"`, or `pip install -U "jax[tpu]"`
  • Try the examples in the demo/ folder, e.g. `python demo/cartpole.py`
  • Open a web browser and go to http://127.0.0.1:8000/ to see the viewer

Usage (uv)

uv is an extremely fast Python package installer and resolver, written in Rust.

  • `cd` to the root folder of this repo
  • Run `git submodule update --init --recursive` to pull the modified Brax library
  • Install uv: `pip install uv`
  • Create a virtual environment: `uv venv`
  • Activate it: `source .venv/bin/activate`
  • Install dependencies: `uv pip install -r requirements.txt`
  • Install the project: `uv pip install .`
  • Try the examples in the demo/ folder, e.g. `python demo/training_example.py`
  • Open a web browser and navigate to http://127.0.0.1:8000/ to see the viewer

How to use

Integrating the viewer into your existing Brax training script is straightforward. The core idea is to instantiate the viewer, run it, initialize it with your environment, and pass it to the training function.

For a Single Environment Visualization

This is the simplest use case, ideal for standard Brax environments. The following example demonstrates how to visualize a humanoid agent.

  1. Import necessary modules:

    from brax import envs
    from braxviewer.WebViewer import WebViewer
    from braxviewer.brax.brax.training.agents.ppo import train as ppo
  2. Create the Viewer and Environment:

    # Instantiate the viewer
    viewer = WebViewer()
    
    # Start the viewer server in the background
    viewer.run()
    
    # Get a standard Brax environment
    env = envs.get_environment(env_name='humanoid', backend='positional')
    
    # Initialize the viewer with the environment structure
    viewer.init(env)
  3. Pass Viewer to the Training Function: Modify your train_fn to accept the viewer object.

    # Get the training function for the 'humanoid' environment
    train_fn = ... # Your functools.partial or direct call to ppo.train
    
    # Pass the viewer instance to the train function
    make_inference_fn, params, _ = train_fn(
        environment=env,
        viewer=viewer
    )
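
Putting the three steps together, the overall call shape looks like the sketch below. `ppo.train` and the viewer are replaced by trivial stubs so the wiring is visible on its own, in particular how `functools.partial` fixes the hyperparameters first while `environment` and `viewer` are supplied at call time; the stub bodies are placeholders, not the real Brax or braxviewer behavior.

```python
import functools

class StubViewer:
    """Placeholder for braxviewer.WebViewer: records that init() ran."""
    def __init__(self):
        self.initialized = False
    def run(self):
        pass  # a real viewer starts its server here
    def init(self, env):
        self.initialized = True

def stub_ppo_train(environment=None, viewer=None, num_timesteps=0):
    # A real ppo.train would run PPO and stream states to `viewer`.
    return "make_inference_fn", {"num_timesteps": num_timesteps}, {}

viewer = StubViewer()
viewer.run()

env = "humanoid"  # placeholder for envs.get_environment(...)
viewer.init(env)

# Fix hyperparameters first; pass environment and viewer at call time.
train_fn = functools.partial(stub_ppo_train, num_timesteps=1000)
make_inference_fn, params, _ = train_fn(environment=env, viewer=viewer)
```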

For Batched (Parallel) Environment Visualization

To visualize multiple environments running in parallel, use WebViewerBatched. This viewer arranges the agents in a 3D grid. The following example is based on cartpole_batched.py.

  1. Import modules:

    from braxviewer.WebViewerBatched import WebViewerBatched
    from braxviewer.brax.brax.training.agents.ppo import train as ppo
    # Your custom environment class, e.g., CartPole
  2. Define Grid and Create Viewer: WebViewerBatched requires information about the grid layout.

    num_parallel_envs = 8
    grid_dims = (4, 2, 1)  # 4 columns - x axis, 2 rows - y axis, 1 layer - z axis
    env_offset = (4.0, 4.0, 2.0) # Spacing between envs in (x, y, z)
    
    # Instantiate the batched viewer with grid info
    viewer = WebViewerBatched(
        grid_dims=grid_dims,
        env_offset=env_offset
    )
    viewer.run()
  3. Prepare a Concatenated Environment for Visualization: The viewer needs a single large environment definition that contains all parallel agents.

    # Your robot XML. Examples are in the official Brax repository:
    # https://github.com/google/brax/tree/300b1079363894733fa1090c6bb055b881eb0ac1/brax/envs/assets
    xml_model = "..."
    
    # Create a concatenated XML string for all envs
    concatenated_xml = WebViewerBatched.concatenate_envs_xml(
        xml_string=xml_model, 
        num_envs=num_parallel_envs, 
        grid_dims=grid_dims, 
        env_offset=env_offset
    )
    
    # Create a temporary environment instance from this big XML for initialization
    env_for_visualization_init = CartPole(xml_model=concatenated_xml)
    
    # Initialize the viewer with the concatenated environment
    viewer.init(env_for_visualization_init)
  4. Train on the Original Environment: The actual training still happens on the original, single-agent environment class. Brax's training function handles the parallelization internally.

    # Training uses the original, single environment definition
    env_for_training = CartPole(xml_model=xml_model)
    
    train_fn = functools.partial(
        ppo.train,
        num_envs=num_parallel_envs,
        viewer=viewer,
        # ... other ppo parameters
    )
    
    train_fn(environment=env_for_training)
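
The grid arrangement from step 2 can be reproduced in a few lines of plain Python. `env_position` below is a hypothetical helper showing how a flat environment index could map to an (x, y, z) offset under `grid_dims` and `env_offset`; the actual layout logic lives inside `WebViewerBatched` and may differ in details such as axis order.

```python
def env_position(index, grid_dims, env_offset):
    """Map a flat environment index to a 3D grid offset.

    grid_dims  -- (nx, ny, nz): environments per axis
    env_offset -- (dx, dy, dz): spacing between neighbours per axis
    """
    nx, ny, nz = grid_dims
    dx, dy, dz = env_offset
    x = index % nx           # fastest-varying axis: columns
    y = (index // nx) % ny   # then rows
    z = index // (nx * ny)   # then layers
    return (x * dx, y * dy, z * dz)

# The 8 cartpole environments from the example, on a 4 x 2 x 1 grid
# spaced (4.0, 4.0, 2.0) apart: the first row fills along x, then the
# second row shifts y by 4.0.
positions = [env_position(i, (4, 2, 1), (4.0, 4.0, 2.0)) for i in range(8)]
```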
