A complete, modular, and production-ready reinforcement learning framework built with PyTorch and Gymnasium. Supports DQN, PPO, and A2C algorithms with parallel environment training and interactive visualization.
**Multiple RL Algorithms**
- Deep Q-Network (DQN) with Double DQN and Dueling DQN variants
- Proximal Policy Optimization (PPO) with GAE and clipping
- Advantage Actor-Critic (A2C) with parallel environments
**Production-Ready Components**
- Vectorized environments for parallel training
- Comprehensive logging with TensorBoard integration
- Automatic checkpointing and model saving
- Evaluation and visualization tools
- Command-line interface for easy usage
**Best Practices**
- Modular and extensible architecture
- Type hints and documentation
- Configuration management via YAML
- Reproducible experiments with seed control
- GPU/CPU/MPS device support
```bash
pip install -r requirements.txt
pip install -e .
```

Or, for development:

```bash
pip install -r requirements.txt
pip install -e ".[dev]"
```

Train a PPO agent on CartPole:

```bash
python examples/train_cartpole.py
```

This will train the agent and display progress. Training takes about 1-2 minutes.
After training, watch your agent play with a beautiful graphical interface:
```bash
python play_trained_agent.py logs/cartpole_ppo_example
```

This opens an interactive pygame window showing:
- Live agent gameplay with smooth 30 FPS
- Real-time reward and step counters
- Episode statistics and averages
- Press SPACE to continue between episodes
- Press ESC to quit
View training curves and metrics:
```bash
python visualize_results.py logs/cartpole_ppo_example
```

Create a demo video:

```bash
python visualize_results.py logs/cartpole_ppo_example --show-agent --save-video demo.mp4
```

Train a DQN agent on LunarLander:

```bash
python examples/train_lunarlander.py
```

Or use custom configurations:

```bash
python -m rl_framework.cli train --config configs/cartpole_dqn.yaml
```

Generate summary report:

```bash
python -m rl_framework.cli plot \
    --log-dir logs/cartpole_dqn \
    --summary \
    --output summary.png
```

```
rl-gym-framework/
├── src/
│ └── rl_framework/
│ ├── __init__.py
│ ├── config.py # Configuration management
│ ├── trainer.py # Training pipeline
│ ├── evaluator.py # Evaluation utilities
│ ├── visualization.py # Plotting tools
│ ├── cli.py # Command-line interface
│ ├── agents/
│ │ ├── __init__.py
│ │ ├── base.py # Base agent class
│ │ ├── dqn.py # DQN agent
│ │ ├── ppo.py # PPO agent
│ │ └── a2c.py # A2C agent
│ ├── networks/
│ │ ├── __init__.py
│ │ ├── base.py # Network utilities
│ │ ├── dqn_network.py # Q-networks
│ │ └── policy_network.py # Actor-Critic networks
│ ├── envs/
│ │ ├── __init__.py
│ │ ├── wrappers.py # Environment wrappers
│ │ ├── vec_env.py # Vectorized environments
│ │ └── env_factory.py # Environment creation
│ └── utils/
│ ├── __init__.py
│ ├── replay_buffer.py # Experience replay
│ ├── logger.py # Logging utilities
│ └── running_stats.py # Statistics tracking
├── configs/ # Configuration files
│ ├── cartpole_dqn.yaml
│ ├── cartpole_ppo.yaml
│ ├── lunarlander_dqn.yaml
│ ├── lunarlander_ppo.yaml
│ └── acrobot_a2c.yaml
├── tests/ # Unit tests
├── examples/ # Example scripts
├── requirements.txt
├── setup.py
├── pyproject.toml
└── README.md
```
Configuration files use YAML format. Example structure:
```yaml
algorithm: PPO
experiment_name: my_experiment

env:
  name: CartPole-v1
  num_envs: 4
  gamma: 0.99

network:
  hidden_dims: [64, 64]
  activation: tanh

ppo:
  n_steps: 2048
  batch_size: 64
  learning_rate: 0.0003
  clip_range: 0.2

training:
  total_timesteps: 200000
  eval_frequency: 10000
  seed: 42
  device: auto
```

**DQN (Deep Q-Network)**

Off-policy value-based method. Supports:
- Double DQN for reduced overestimation
- Dueling DQN for better value estimation
- Prioritized replay (optional)
Best for: Discrete action spaces, sample-efficient learning
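For intuition, Double DQN decouples action selection from action evaluation: the online network picks the greedy next action and the target network scores it. A minimal PyTorch sketch of the target computation (illustrative only, not this framework's `dqn.py`; argument names are assumptions):

```python
import torch
import torch.nn as nn

def double_dqn_targets(
    q_net: nn.Module,           # online Q-network: states -> Q-values per action
    target_net: nn.Module,      # periodically synced target network
    rewards: torch.Tensor,      # shape (batch,)
    next_states: torch.Tensor,  # shape (batch, obs_dim)
    dones: torch.Tensor,        # shape (batch,), 1.0 where the episode terminated
    gamma: float = 0.99,
) -> torch.Tensor:
    """Double DQN targets: select the next action with the online network,
    evaluate it with the target network to reduce overestimation bias."""
    with torch.no_grad():
        next_actions = q_net(next_states).argmax(dim=1, keepdim=True)
        next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        return rewards + gamma * (1.0 - dones) * next_q
```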
**PPO (Proximal Policy Optimization)**

On-policy policy gradient method. Features:
- Clipped surrogate objective
- Generalized Advantage Estimation (GAE)
- Multiple epochs per batch
- KL divergence early stopping
Best for: Both discrete and continuous actions, stable training
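The clipped surrogate objective keeps each update close to the policy that collected the data by clipping the probability ratio. A minimal sketch (illustrative; names are assumptions, not this framework's API):

```python
import torch

def ppo_clip_loss(
    new_log_probs: torch.Tensor,  # log pi_theta(a|s) under the current policy
    old_log_probs: torch.Tensor,  # log probs recorded at rollout time
    advantages: torch.Tensor,     # advantage estimates (e.g. from GAE)
    clip_range: float = 0.2,
) -> torch.Tensor:
    """Clipped surrogate objective, returned as a loss to minimize."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()
```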
**A2C (Advantage Actor-Critic)**

Synchronous on-policy method. Features:
- Parallel environment execution
- Advantage estimation
- Shared actor-critic network
Best for: Fast training with parallel envs, discrete/continuous actions
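Both PPO and A2C build on advantage estimates; PPO uses Generalized Advantage Estimation (GAE). A self-contained sketch of GAE over one rollout (illustrative, independent of the framework's internals):

```python
import numpy as np

def compute_gae(
    rewards: np.ndarray,   # shape (T,)
    values: np.ndarray,    # shape (T + 1,), includes the bootstrap value
    dones: np.ndarray,     # shape (T,), 1.0 where the episode ended
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> np.ndarray:
    """Generalized Advantage Estimation over a single rollout."""
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    return advantages
```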
Override config parameters via CLI:
```bash
python -m rl_framework.cli train \
    --config configs/cartpole_ppo.yaml \
    --total-timesteps 500000 \
    --seed 123 \
    --experiment-name my_run
```

Configure parallel environments in YAML:

```yaml
env:
  num_envs: 8  # Run 8 parallel environments
```
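Under the hood, vectorized training steps several environment copies in lockstep. Here is a Gymnasium-only sketch of the idea (the framework ships its own `vec_env.py`, which may differ):

```python
import gymnasium as gym

# Step 8 CartPole instances in lockstep with Gymnasium's synchronous vector API.
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(8)])
obs, infos = envs.reset(seed=42)
actions = envs.action_space.sample()  # one action per sub-environment
obs, rewards, terminations, truncations, infos = envs.step(actions)
envs.close()
```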
Modify network architecture in config:

```yaml
network:
  hidden_dims: [256, 256, 128]
  activation: relu
  use_layer_norm: true
```
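As a rough illustration of how a `hidden_dims`/`activation` spec like this can map to a PyTorch module (a sketch only; the framework's `networks/base.py` may be structured differently):

```python
import torch.nn as nn

def build_mlp(
    input_dim: int,
    hidden_dims: list[int],
    output_dim: int,
    activation: str = "relu",
    use_layer_norm: bool = False,
) -> nn.Sequential:
    """Build a simple MLP from a config-style spec (illustrative sketch)."""
    act = {"relu": nn.ReLU, "tanh": nn.Tanh}[activation]
    layers: list[nn.Module] = []
    dims = [input_dim, *hidden_dims]
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        if use_layer_norm:
            layers.append(nn.LayerNorm(out_dim))
        layers.append(act())
    layers.append(nn.Linear(dims[-1], output_dim))
    return nn.Sequential(*layers)
```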
Record agent gameplay:

```bash
python -m rl_framework.cli eval \
    --config configs/lunarlander_ppo.yaml \
    --checkpoint checkpoints/lunarlander_ppo/checkpoint_1000000 \
    --record-video videos/lunarlander.mp4
```

- **Use Parallel Environments**: For on-policy methods (PPO, A2C), increase `num_envs`
- **GPU Acceleration**: Set `device: cuda` for large networks
- **Hyperparameter Tuning**: Start with the provided configs and adjust
- **Observation Normalization**: Enable for continuous state spaces (see the sketch below)
- **Frame Stacking**: Use for partial observability
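Observation normalization typically keeps a running mean/variance that is updated from each batch of observations. A minimal sketch (illustrative; the framework's `utils/running_stats.py` may differ):

```python
import numpy as np

class RunningMeanStd:
    """Running mean/variance for observation normalization (illustrative sketch)."""

    def __init__(self, shape: tuple[int, ...]):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 1e-4

    def update(self, batch: np.ndarray) -> None:
        # Combine batch statistics with the running estimates (parallel-variance update).
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        self.var = (m_a + m_b + delta**2 * self.count * batch_count / total) / total
        self.count = total

    def normalize(self, obs: np.ndarray) -> np.ndarray:
        return (obs - self.mean) / np.sqrt(self.var + 1e-8)
```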
Hardware: NVIDIA RTX 3080, AMD Ryzen 9 5900X
| Environment | Algorithm | Timesteps | Time | Final Reward |
|---|---|---|---|---|
| CartPole-v1 | DQN | 100K | ~2 min | 500 |
| CartPole-v1 | PPO | 200K | ~3 min | 500 |
| LunarLander-v2 | DQN | 500K | ~15 min | 250+ |
| LunarLander-v2 | PPO | 1M | ~20 min | 280+ |
```python
from rl_framework import Config, Trainer

# Load configuration
config = Config.from_yaml("configs/cartpole_ppo.yaml")

# Customize if needed
config.training.total_timesteps = 500000
config.training.seed = 42

# Create and run trainer
trainer = Trainer(config)
trainer.train()
trainer.close()
```

For a custom training loop, build the environment and agent directly:

```python
from rl_framework.agents import PPOAgent
from rl_framework.envs import make_vec_env
from rl_framework.config import Config

config = Config.from_yaml("configs/cartpole_ppo.yaml")

# Create environment
env = make_vec_env(config.env, seed=42)

# Create agent
agent = PPOAgent(
    observation_dim=env.observation_space.shape[0],
    action_dim=env.action_space.n,
    ppo_config=config.ppo,
    network_config=config.network,
    env_config=config.env,
    device="cuda",
)

# Training loop
obs = env.reset()
for step in range(1000000):
    actions, values, log_probs = agent.select_action(obs)
    obs, rewards, dones, infos = env.step(actions)
    # ... collect rollouts and train
```

Launch TensorBoard to monitor training:
```bash
tensorboard --logdir logs/
```

Metrics tracked:
- Episode rewards and lengths
- Loss values (policy, value, entropy)
- Learning rates
- Exploration rates (DQN)
- KL divergence (PPO)
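For reference, metrics like these can be written with `torch.utils.tensorboard`; a minimal sketch (the framework's logger is assumed to wrap something similar, and the tag names and log directory here are illustrative):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/my_experiment")  # hypothetical log dir
writer.add_scalar("rollout/episode_reward", 200.0, global_step=10_000)
writer.add_scalar("train/policy_loss", 0.02, global_step=10_000)
writer.close()
```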
Console output shows real-time progress with:
- Training steps and progress percentage
- Steps per second (SPS)
- Recent evaluation rewards
Run unit tests:
```bash
pytest tests/
```

With coverage:

```bash
pytest tests/ --cov=src/rl_framework --cov-report=html
```

Tested with Gymnasium environments:
- Classic Control: CartPole, Acrobot, MountainCar, Pendulum
- Box2D: LunarLander, BipedalWalker, CarRacing
- Atari: All Atari games (via `ale-py`)
- Custom: Easily add your own Gym-compatible environments
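To add your own environment, implement the standard Gymnasium interface (`reset` returning `(obs, info)` and `step` returning a 5-tuple). A minimal skeleton with hypothetical spaces and dynamics:

```python
import gymnasium as gym
import numpy as np

class MyCustomEnv(gym.Env):
    """Minimal Gym-compatible environment skeleton (illustrative only)."""

    def __init__(self):
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        obs = np.zeros(4, dtype=np.float32)
        return obs, {}

    def step(self, action):
        # Placeholder dynamics: random next observation, constant reward.
        obs = self.np_random.uniform(-1.0, 1.0, size=4).astype(np.float32)
        reward, terminated, truncated = 1.0, False, False
        return obs, reward, terminated, truncated, {}
```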
- **Issue**: Out of memory errors. **Solution**: Reduce `batch_size`, `buffer_size`, or `num_envs`.
- **Issue**: Training unstable or diverging. **Solution**: Reduce `learning_rate`, enable `normalize_obs`, adjust `clip_range`.
- **Issue**: Slow training. **Solution**: Increase `num_envs`, use a GPU (`device: cuda`), reduce `eval_frequency`.
- **Issue**: Poor performance. **Solution**: Tune hyperparameters, increase `total_timesteps`, check reward shaping.
If you use this framework in your research, please cite:
```bibtex
@software{rl_gym_framework,
  title = {Production-Quality Reinforcement Learning Framework},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/rl-gym-framework}
}
```

Contributions are welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new features
- Submit a pull request
MIT License - see LICENSE file for details
- Built on PyTorch and Gymnasium
- Inspired by Stable-Baselines3 and CleanRL
- Algorithm implementations based on original papers
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: your.email@example.com