
BenchNetRL: The Right Network for the Right RL Task

💻 Code · 📄 Paper · 🚩 Issues


Directory Structure

BenchNetRL/
├── README.md               # This file
├── requirements.txt        # Python dependencies
├── env_utils.py            # Environment wrappers and creators
├── exp_utils.py            # Experiment argument parsing and logging utilities
├── gae.py                  # Generalized Advantage Estimation implementation
├── layers.py               # Neural network layer utilities and transformer modules
├── ppo.py                  # Vanilla PPO implementation
├── ppo_lstm.py             # PPO with LSTM / GRU recurrent policies
├── ppo_mamba.py            # PPO with Mamba / Mamba-2 recurrent SSM
├── ppo_trxl.py             # PPO with Transformer-XL (TrXL) / GTrXL memory
│
├── envs/                    # Custom environment implementations for quick memory tests
│   ├── poc_memory_env.py    # Proof-of-concept memory environment (PocMemoryEnv)
│   └── pom_env.py           # Proof-of-memory Gym environment (PoMEnv)
│
└── scripts/                 # Baseline experiment scripts
    └── baselines/
        ├── atari.sh           # Atari benchmark commands
        ├── classic_control.sh # Classic control benchmark commands
        ├── minigrid.sh        # MiniGrid benchmark commands
        └── mujoco.sh          # MuJoCo benchmark commands

βš™οΈ Installation

Clone the repository:

git clone https://github.com/SafeRL-Lab/BenchNetRL.git
cd BenchNetRL

Create a Python environment (conda or virtualenv recommended):

python -m venv venv
source venv/bin/activate  # on Linux/Mac
venv\Scripts\activate   # on Windows

Install dependencies:

pip install -r requirements.txt

Ensure the CUDA toolkit is installed

Before installing CUDA-enabled PyTorch, make sure you have NVIDIA's CUDA toolkit installed and your drivers up to date (you can check with nvcc --version and nvidia-smi). You can download and install CUDA 12.4 from NVIDIA:

  1. Visit the CUDA Toolkit Archive: https://developer.nvidia.com/cuda-toolkit-archive

  2. Select CUDA Toolkit 12.4 for your operating system and follow the installation guide.

Install the Mamba SSM library:

The Mamba and Mamba-2 recurrent state-space models are required for the Mamba-based scripts (e.g. ppo_mamba.py). These modules are not included in this repository and must be installed separately. Ensure you are on a Linux system with a compatible CUDA version.

Follow the instructions in the Mamba repository to install it:

git clone https://github.com/state-spaces/mamba.git

Note: Mamba requires Linux and specific CUDA drivers. Please refer to the Mamba repository for installation details and supported CUDA versions.
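
Once Mamba is installed, you can confirm the CUDA kernels load with a tiny forward pass. This is a minimal sketch based on the upstream mamba_ssm example (it assumes a CUDA-capable GPU and a successful install; the shapes and hyperparameters are illustrative):

import torch
from mamba_ssm import Mamba  # provided by the separately installed mamba-ssm package

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim, device="cuda")
model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
y = model(x)  # output shape matches the input: (batch, length, dim)
print(y.shape)  # torch.Size([2, 64, 16])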

Fixing PyTorch CUDA build issues

If you encounter an AttributeError related to torch.cuda.reset_peak_memory_stats, it means you have a CPU-only or incompatible PyTorch build. To resolve:

Uninstall any existing torch packages:

pip uninstall -y torch torchvision torchaudio

Reinstall CUDA-enabled PyTorch (matching your CUDA toolkit, e.g. 12.4):

pip install --index-url https://download.pytorch.org/whl/cu124 \
  torch torchvision torchaudio

Verify CUDA is available:

python - <<EOF
import torch
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
EOF

Optional guard in ppo.py: if your setup still lacks the function, open ppo.py and replace:

torch.cuda.reset_peak_memory_stats()

with:

if torch.cuda.is_available() and hasattr(torch.cuda, "reset_peak_memory_stats"):
    torch.cuda.reset_peak_memory_stats()

🚀 Usage

Running the Experiments

Use the provided scripts under scripts/baselines/ to launch the baseline experiments. For example:

bash scripts/baselines/atari.sh

Running Custom Experiments

Example command for PPO + Mamba on Breakout:

python ppo_mamba.py \
  --gym-id ALE/Breakout-v5 \
  --total-timesteps 10000000 \
  --num-envs 16 \
  --num-minibatches 8 \
  --hidden-dim 450 \
  --expand 1 \
  --track \
  --wandb-project-name atari-bench \
  --exp-name ppo_mamba

Replace the script name (ppo.py, ppo_lstm.py, ppo_mamba.py, ppo_trxl.py) and flags as needed.
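
Before kicking off a multi-hour run, it can help to check that the target environment loads in your install. A minimal sketch (assumes gymnasium and ale-py are available, e.g. via requirements.txt; the register_envs call is only needed on recent gymnasium/ale-py releases, hence the guard):

import gymnasium as gym
import ale_py

# Newer gymnasium/ale-py versions require explicit registration of ALE environments.
if hasattr(gym, "register_envs"):
    gym.register_envs(ale_py)

env = gym.make("ALE/Breakout-v5")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs.shape, reward, terminated, truncated)
env.close()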

File Descriptions

  • env_utils.py: Wraps Gym environments with preprocessing such as frame stacking, observation masking, and video recording.

  • exp_utils.py: Command-line argument parsing and logging setup.

  • gae.py: Advantage and return computation (GAE); a sketch follows this list.

  • layers.py: layer_init, attention modules, Transformer, SSM interfaces.

  • ppo.py, ppo_lstm.py, ppo_mamba.py, ppo_trxl.py: PPO implementations (vanilla, LSTM/GRU, Mamba/Mamba-2, Transformer-XL/GTrXL).

  • envs/: Custom memory-focused Gym environments.

  • scripts/baselines/: Shell scripts for reproducible benchmarks.
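
For reference, gae.py computes advantages with the standard GAE(λ) recursion. The sketch below is illustrative (the function name and signature are mine, not the repository's exact API); here dones[t] marks that the episode ended at step t:

import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # rewards, values, dones: (T, num_envs) tensors; last_value: (num_envs,)
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(last_value)
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]  # cut the bootstrap at episode ends
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + values  # targets for the value-function loss
    return advantages, returns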

📈 Performance Metrics - Average Across 9 Environments

Architectures

  • PPO-1: Standard PPO with 1-frame observation (no frame stacking).
  • PPO-4: PPO with 4-frame observation stacking (temporal context via stacked frames).
  • LSTM, GRU, TrXL, GTrXL, Mamba, Mamba-2: Sequence-based models with varying architectures to capture temporal dependencies in environment dynamics.
| Metric | PPO-1 | PPO-4 | LSTM | GRU | TrXL | GTrXL | Mamba | Mamba-2 |
|---|---|---|---|---|---|---|---|---|
| Steps Per Second (↑) | 3539 | 3305 | 604 | 701 | 1856 | 1890 | 2734 | 2455 |
| Training Time (min) (↓) | 16.59 | 18.84 | 121.90 | 91.04 | 30.33 | 29.42 | 21.20 | 22.97 |
| Inference Latency (ms) (↓) | 0.856 | 0.899 | 1.006 | 0.971 | 2.171 | 2.147 | 1.304 | 1.489 |
| GPU Mem. Allocated (GB) (↓) | 0.035 | 0.660 | 0.194 | 0.194 | 1.765 | 1.330 | 0.217 | 0.219 |
| GPU Mem. Reserved (GB) (↓) | 0.327 | 0.983 | 0.343 | 0.349 | 5.508 | 4.968 | 0.362 | 0.662 |
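
Figures of this kind can be collected with PyTorch's built-in instrumentation. The snippet below is an illustrative sketch of how latency and peak-memory numbers are typically measured (the Linear layer stands in for a policy network; this is not the repository's actual measurement code):

import time
import torch

device = torch.device("cuda")
model = torch.nn.Linear(512, 512).to(device)  # stand-in for a policy network
x = torch.randn(64, 512, device=device)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    y = model(x)  # stand-in for one inference step
torch.cuda.synchronize()  # wait for the GPU before stopping the clock
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"Inference latency: {elapsed_ms:.3f} ms")
print(f"GPU mem. allocated: {torch.cuda.max_memory_allocated() / 1e9:.3f} GB")
print(f"GPU mem. reserved:  {torch.cuda.max_memory_reserved() / 1e9:.3f} GB")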

Below are key performance metrics visualized by architecture group.

🟦 PPO | 🟧 Classic Seq | 🟩 Transformers | 🟥 Mamba

📝 Each architecture is color-coded by family for quick reference.

📊 Results

Learning-curve plots are provided for each benchmark suite: MuJoCo Environments, Atari Environments, MiniGrid Environments, and OpenAI Gym Environments. (See the figures in the repository for the full plots.)
📄 Citation

If you find this repository useful, please cite the paper:

@article{ivan2025benchnetrl,
  title={RLBenchNet: The Right Network for the Right Reinforcement Learning Task},
  author={Smirnov, Ivan and Gu, Shangding},
  journal={arXiv preprint},
  year={2025}
}

About

🔥 Systematic Benchmarking of Neural Network Architectures in Reinforcement Learning.
