Production-Quality Reinforcement Learning Framework

Python 3.8+ | PyTorch | License: MIT

A complete, modular, and production-ready reinforcement learning framework built with PyTorch and Gymnasium. Supports DQN, PPO, and A2C algorithms with parallel environment training and interactive visualization.

Features

  • Multiple RL Algorithms

    • Deep Q-Network (DQN) with Double DQN and Dueling DQN variants
    • Proximal Policy Optimization (PPO) with GAE and clipping
    • Advantage Actor-Critic (A2C) with parallel environments
  • Production-Ready Components

    • Vectorized environments for parallel training
    • Comprehensive logging with TensorBoard integration
    • Automatic checkpointing and model saving
    • Evaluation and visualization tools
    • Command-line interface for easy usage
  • Best Practices

    • Modular and extensible architecture
    • Type hints and documentation
    • Configuration management via YAML
    • Reproducible experiments with seed control
    • GPU/CPU/MPS device support

Installation

Quick Install

pip install -r requirements.txt
pip install -e .

Development Install

pip install -r requirements.txt
pip install -e ".[dev]"

Quick Start

1. Train an Agent

Train a PPO agent on CartPole:

python examples/train_cartpole.py

This will train the agent and display progress. Training takes about 1-2 minutes.

2. Watch Your Trained Agent (with Pygame!)

After training, watch your agent play with a beautiful graphical interface:

python play_trained_agent.py logs/cartpole_ppo_example

This opens an interactive pygame window showing:

  • Live agent gameplay at a smooth 30 FPS
  • Real-time reward and step counters
  • Episode statistics and averages
  • Controls: SPACE to continue between episodes, ESC to quit

3. Visualize Training Progress

View training curves and metrics:

python visualize_results.py logs/cartpole_ppo_example

Create a demo video:

python visualize_results.py logs/cartpole_ppo_example --show-agent --save-video demo.mp4

Advanced Training

Train a DQN agent on LunarLander:

python examples/train_lunarlander.py

Or use custom configurations:

python -m rl_framework.cli train --config configs/cartpole_dqn.yaml

Generate summary report:

python -m rl_framework.cli plot \
  --log-dir logs/cartpole_dqn \
  --summary \
  --output summary.png

Project Structure

rl-gym-framework/
├── src/
│   └── rl_framework/
│       ├── __init__.py
│       ├── config.py              # Configuration management
│       ├── trainer.py             # Training pipeline
│       ├── evaluator.py           # Evaluation utilities
│       ├── visualization.py       # Plotting tools
│       ├── cli.py                 # Command-line interface
│       ├── agents/
│       │   ├── __init__.py
│       │   ├── base.py            # Base agent class
│       │   ├── dqn.py             # DQN agent
│       │   ├── ppo.py             # PPO agent
│       │   └── a2c.py             # A2C agent
│       ├── networks/
│       │   ├── __init__.py
│       │   ├── base.py            # Network utilities
│       │   ├── dqn_network.py     # Q-networks
│       │   └── policy_network.py  # Actor-Critic networks
│       ├── envs/
│       │   ├── __init__.py
│       │   ├── wrappers.py        # Environment wrappers
│       │   ├── vec_env.py         # Vectorized environments
│       │   └── env_factory.py     # Environment creation
│       └── utils/
│           ├── __init__.py
│           ├── replay_buffer.py   # Experience replay
│           ├── logger.py          # Logging utilities
│           └── running_stats.py   # Statistics tracking
├── configs/                       # Configuration files
│   ├── cartpole_dqn.yaml
│   ├── cartpole_ppo.yaml
│   ├── lunarlander_dqn.yaml
│   ├── lunarlander_ppo.yaml
│   └── acrobot_a2c.yaml
├── tests/                         # Unit tests
├── examples/                      # Example scripts
├── requirements.txt
├── setup.py
├── pyproject.toml
└── README.md

Configuration

Configuration files use YAML format. Example structure:

algorithm: PPO
experiment_name: my_experiment

env:
  name: CartPole-v1
  num_envs: 4
  gamma: 0.99

network:
  hidden_dims: [64, 64]
  activation: tanh

ppo:
  n_steps: 2048
  batch_size: 64
  learning_rate: 0.0003
  clip_range: 0.2

training:
  total_timesteps: 200000
  eval_frequency: 10000
  seed: 42
  device: auto
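
The device: auto setting selects the best available backend at runtime. A minimal sketch of how such a resolver is commonly written, using a hypothetical resolve_device helper (the framework's actual logic may differ):

import torch

def resolve_device(device: str = "auto") -> torch.device:
    """Prefer CUDA, then Apple MPS, then fall back to CPU when set to 'auto'."""
    if device != "auto":
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")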

Algorithms

DQN (Deep Q-Network)

Off-policy value-based method. Supports:

  • Double DQN for reduced overestimation
  • Dueling DQN for better value estimation
  • Prioritized replay (optional)

Best for: Discrete action spaces, sample-efficient learning
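
The core of Double DQN is computing the bootstrap target with the online network choosing the next action and the target network scoring it. A minimal sketch, assuming q_net and target_net map observations to per-action Q-values (illustrative, not the framework's exact implementation):

import torch

def double_dqn_target(q_net, target_net, next_obs, rewards, dones, gamma=0.99):
    """Compute r + gamma * Q_target(s', argmax_a Q_online(s', a)) for non-terminal steps."""
    with torch.no_grad():
        next_actions = q_net(next_obs).argmax(dim=1, keepdim=True)        # select with online net
        next_q = target_net(next_obs).gather(1, next_actions).squeeze(1)  # evaluate with target net
        return rewards + gamma * (1.0 - dones) * next_q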

PPO (Proximal Policy Optimization)

On-policy policy gradient method. Features:

  • Clipped surrogate objective
  • Generalized Advantage Estimation (GAE)
  • Multiple epochs per batch
  • KL divergence early stopping

Best for: Both discrete and continuous actions, stable training
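
The clipped surrogate objective keeps the updated policy close to the one that gathered the data. A minimal sketch of the core loss term, assuming log-probabilities and advantages have already been computed (illustrative only):

import torch

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Negative clipped surrogate objective: maximize min(ratio * A, clip(ratio) * A)."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()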

A2C (Advantage Actor-Critic)

Synchronous on-policy method. Features:

  • Parallel environment execution
  • Advantage estimation
  • Shared actor-critic network

Best for: Fast training with parallel envs, discrete/continuous actions
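
The A2C update combines a policy-gradient term weighted by the advantage, a value-regression term, and an entropy bonus. A rough sketch of the combined loss, assuming a torch.distributions action distribution and matching value/return tensors (illustrative):

import torch
import torch.nn.functional as F

def a2c_loss(dist, actions, values, returns, value_coef=0.5, entropy_coef=0.01):
    """Policy gradient + weighted value loss - entropy bonus; advantages are treated as constants."""
    advantages = (returns - values).detach()
    policy_loss = -(dist.log_prob(actions) * advantages).mean()
    value_loss = F.mse_loss(values, returns)
    entropy = dist.entropy().mean()
    return policy_loss + value_coef * value_loss - entropy_coef * entropy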

Advanced Usage

Custom Configuration

Override config parameters via CLI:

python -m rl_framework.cli train \
  --config configs/cartpole_ppo.yaml \
  --total-timesteps 500000 \
  --seed 123 \
  --experiment-name my_run

Multi-Environment Training

Configure parallel environments in YAML:

env:
  num_envs: 8  # Run 8 parallel environments
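
Conceptually, a synchronous vectorized environment steps its sub-environments in lockstep and auto-resets any that finish, so every rollout batch stays full. A simplified sketch of the idea (the framework's envs/vec_env.py may be implemented quite differently):

import numpy as np
import gymnasium as gym

class SimpleSyncVecEnv:
    """Minimal synchronous vectorized env: batched reset/step over N sub-environments."""
    def __init__(self, env_id, num_envs):
        self.envs = [gym.make(env_id) for _ in range(num_envs)]

    def reset(self):
        return np.stack([env.reset()[0] for env in self.envs])

    def step(self, actions):
        obs, rewards, dones = [], [], []
        for env, action in zip(self.envs, actions):
            o, r, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            if done:
                o, _ = env.reset()  # auto-reset so the batch stays full
            obs.append(o)
            rewards.append(r)
            dones.append(done)
        return np.stack(obs), np.array(rewards), np.array(dones, dtype=np.float32)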

Custom Networks

Modify network architecture in config:

network:
  hidden_dims: [256, 256, 128]
  activation: relu
  use_layer_norm: true

Video Recording

Record agent gameplay:

python -m rl_framework.cli eval \
  --config configs/lunarlander_ppo.yaml \
  --checkpoint checkpoints/lunarlander_ppo/checkpoint_1000000 \
  --record-video videos/lunarlander.mp4

Performance Tips

  1. Use Parallel Environments: For on-policy methods (PPO, A2C), increase num_envs
  2. GPU Acceleration: Set device: cuda for large networks
  3. Hyperparameter Tuning: Start with provided configs and adjust
  4. Observation Normalization: Enable for continuous state spaces (see the sketch after this list)
  5. Frame Stacking: Use for partial observability
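
Observation normalization (tip 4) typically keeps a running mean and variance of every observation seen so far. A minimal sketch of such a normalizer; the framework's utils/running_stats.py presumably does something along these lines, but this is illustrative, not its actual code:

import numpy as np

class RunningObsNormalizer:
    """Track a running mean/variance of observations and normalize with them."""
    def __init__(self, shape, eps=1e-8):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = eps

    def update(self, batch):
        # Welford-style parallel update from a batch of observations
        batch_mean, batch_var, batch_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.mean = self.mean + delta * batch_count / total
        self.var = m2 / total
        self.count = total

    def normalize(self, obs):
        return (obs - self.mean) / np.sqrt(self.var + 1e-8)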

Example Training Times

Hardware: NVIDIA RTX 3080, AMD Ryzen 9 5900X

Environment      Algorithm   Timesteps   Time      Final Reward
CartPole-v1      DQN         100K        ~2 min    500
CartPole-v1      PPO         200K        ~3 min    500
LunarLander-v2   DQN         500K        ~15 min   250+
LunarLander-v2   PPO         1M          ~20 min   280+

API Usage

Python API

from rl_framework import Config, Trainer

# Load configuration
config = Config.from_yaml("configs/cartpole_ppo.yaml")

# Customize if needed
config.training.total_timesteps = 500000
config.training.seed = 42

# Create and run trainer
trainer = Trainer(config)
trainer.train()
trainer.close()

Custom Training Loop

from rl_framework.agents import PPOAgent
from rl_framework.envs import make_vec_env
from rl_framework.config import Config

config = Config.from_yaml("configs/cartpole_ppo.yaml")

# Create environment
env = make_vec_env(config.env, seed=42)

# Create agent
agent = PPOAgent(
    observation_dim=env.observation_space.shape[0],
    action_dim=env.action_space.n,
    ppo_config=config.ppo,
    network_config=config.network,
    env_config=config.env,
    device="cuda",
)

# Training loop
obs = env.reset()
for step in range(1000000):
    actions, values, log_probs = agent.select_action(obs)
    obs, rewards, dones, infos = env.step(actions)
    # ... collect rollouts and train
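
The elided part of the loop usually buffers (obs, actions, rewards, values, log_probs, dones) for n_steps and then turns them into advantages and returns. A generic GAE sketch over such a buffer, not tied to this framework's exact API:

import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over arrays of shape (n_steps, num_envs)."""
    n_steps = len(rewards)
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        next_value = last_value if t == n_steps - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns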

Monitoring Training

TensorBoard

Launch TensorBoard to monitor training:

tensorboard --logdir logs/

Metrics tracked:

  • Episode rewards and lengths
  • Loss values (policy, value, entropy)
  • Learning rates
  • Exploration rates (DQN)
  • KL divergence (PPO)
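
These scalars are written through TensorBoard's standard summary API; for reference, logging a custom metric alongside them looks roughly like this (the log directory below is just a placeholder):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/my_experiment")  # placeholder log directory
for step, episode_reward in enumerate([200.0, 350.0, 500.0]):  # dummy values
    writer.add_scalar("rollout/episode_reward", episode_reward, global_step=step)
writer.close()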

Console Output

Real-time progress with:

  • Training steps and progress percentage
  • Steps per second (SPS)
  • Recent evaluation rewards

Testing

Run unit tests:

pytest tests/

With coverage:

pytest tests/ --cov=src/rl_framework --cov-report=html

Supported Environments

Tested with Gymnasium environments:

  • Classic Control: CartPole, Acrobot, MountainCar, Pendulum
  • Box2D: LunarLander, BipedalWalker, CarRacing
  • Atari: All Atari games (with ale-py)
  • Custom: Easily add your own Gym-compatible environments

Troubleshooting

Common Issues

Issue: Out of memory errors
Solution: Reduce batch_size, buffer_size, or num_envs

Issue: Training unstable/diverging
Solution: Reduce learning_rate, enable normalize_obs, adjust clip_range

Issue: Slow training
Solution: Increase num_envs, use GPU (device: cuda), reduce eval_frequency

Issue: Poor performance
Solution: Tune hyperparameters, increase total_timesteps, check reward shaping

Citation

If you use this framework in your research, please cite:

@software{rl_gym_framework,
  title = {Production-Quality Reinforcement Learning Framework},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/rl-gym-framework}
}

Contributing

Contributions are welcome! Please:

  1. Fork the repository
  2. Create a feature branch
  3. Add tests for new features
  4. Submit a pull request

License

MIT License - see LICENSE file for details

Acknowledgments

  • Built on PyTorch and Gymnasium
  • Inspired by Stable-Baselines3 and CleanRL
  • Algorithm implementations based on original papers
