offline-rl-agent

Offline RL benchmark project featuring a custom Gym environment, dual observation modes, reward shaping, and real-time PyGame rendering

🚀 Project Overview

NeuroQuant Agent is a fully custom offline reinforcement learning benchmark, built from the ground up with real-time constraints, compression-aware inference, and deployment to latency-constrained environments.

The project begins with a custom-built 10×10 gridworld environment that supports:

  • ๐Ÿ” Directional movement: The agent can turn left, go forward, or turn right relative to its current orientation.
  • ๐Ÿ‘๏ธ Partial observability: Instead of seeing the entire map, the agent receives a 3ร—3 view centered around its position.
  • โ›” Obstacles: Impassable wall tiles block the agent's path and require navigation.
  • ๐ŸŽฏ Goal tile: A single terminal state gives a large positive reward when reached, ending the episode.
  • ๐Ÿ–ฅ๏ธ Real-time PyGame rendering: Each simulation step is rendered at a capped 10 FPS for visual inspection and timing fidelity.

This environment is used as the basis for:

  • Generating offline replay buffers
  • Training offline RL agents using CQL, BCQ, or TD3+BC
  • Benchmarking model compression tradeoffs (quantization, pruning, distillation)
  • Real-time deployment of agents under latency and memory constraints

🧠 Environment Design

The environment is a 10×10 gridworld with directional agent movement, obstacles, and a single terminal goal. Key features (a minimal interaction sketch follows this list):

  • ๐Ÿ” Action space: Turn left, move forward, turn right (relative to current orientation)
  • ๐Ÿ‘๏ธ Partial observability: Agent receives a 3ร—3 window centered on its current location
  • ๐Ÿ”ข Dual observation modes:
    • Image: 3ร—3 local grid (int matrix)
    • Vector: Agent position and goal coordinates as a flat vector
  • ๐ŸŽฏ Reward structure:
    • +10 for reaching the goal (sparse)
    • -0.1 per step (dense penalty)
  • โ›” Obstacles: Defined in the grid and block movement
  • ๐Ÿ–ฅ๏ธ Real-time rendering: PyGame visualization at 10 FPS
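
A minimal interaction sketch, assuming NeuroQuantEnv follows the classic Gym reset/step API. The obs_mode keyword and the action ordering in the comments are illustrative assumptions, not guaranteed by the actual constructor:

from env.neuroquant_env import NeuroQuantEnv

# "obs_mode" is an assumed keyword; the real class may name its observation switch differently.
env = NeuroQuantEnv(obs_mode="vector")        # or "image" for the 3x3 local grid
obs = env.reset()

done, episode_return = False, 0.0
while not done:
    action = env.action_space.sample()        # assumed ordering: 0 = turn left, 1 = forward, 2 = turn right
    obs, reward, done, info = env.step(action)
    episode_return += reward
    env.render()                              # PyGame window, capped at 10 FPS

print(f"Episode return: {episode_return:.1f}")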

🧠 Replay Buffer Generation

We simulate random or scripted agents in the custom Gridworld environment to collect experience data for offline RL training.

Each transition includes:

  • observation
  • action
  • reward
  • next_observation
  • done

These transitions are saved into a compressed .npz buffer (dataset/replay_buffer.npz), which can later be loaded for training Conservative Q-Learning (CQL), TD3+BC, or BCQ agents.

Additional outputs include:

  • ✅ Episode metadata (average reward, length, and total transitions) saved to dataset/metadata.txt
  • 📊 A histogram of reward distribution over episodes saved to dataset/reward_histogram.png

To generate the dataset, run:

python dataset/collect.py --episodes 100

This will generate 10k+ transitions across 100 episodes using a random policy.
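
To sanity-check the buffer after collection, it can be loaded back with NumPy. The array key names below are assumptions about how collect.py names its fields; adjust them to match the actual file:

import numpy as np

buffer = np.load("dataset/replay_buffer.npz")
print(buffer.files)   # assumed keys, e.g. ['observations', 'actions', 'rewards', 'next_observations', 'dones']

observations = buffer["observations"]
rewards = buffer["rewards"]
print(f"{len(observations)} transitions, mean reward per step: {rewards.mean():.3f}")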


๐Ÿ“ Project Structure

offline-rl-agent/
│
├── env/                    # Custom Gym environment (NeuroQuantEnv)
│   └── neuroquant_env.py
│
├── dataset/                # Replay buffer collection + visualizations
│   ├── collect.py          # Random/scripted policy buffer generation
│   ├── viz.py              # t-SNE, reward, and action plots
│   ├── replay_buffer.npz   # (gitignored) Collected transitions
│   ├── reward_histogram.png
│   └── metadata.txt
│
├── docs/
│   └── plots/              # Visual outputs of dataset
│       ├── tsne_obs.png
│       ├── action_distribution.png
│       └── episode_rewards.png
│
├── .gitignore
├── README.md
└── run_env_test.py         # Debug script to manually interact with env

📊 Dataset Visualizations

We visualize the replay buffer to verify coverage and distribution. The plots (saved under docs/plots/) are generated via:

python dataset/viz.py
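
For reference, a sketch of the kind of t-SNE projection of stored observations that viz.py produces; this is not the actual script, and the buffer key names are assumptions:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

buffer = np.load("dataset/replay_buffer.npz")
obs = buffer["observations"]                 # assumed key name
obs = obs.reshape(len(obs), -1)              # flatten image or vector observations

embedded = TSNE(n_components=2, perplexity=30, init="pca").fit_transform(obs)
plt.scatter(embedded[:, 0], embedded[:, 1], s=4, c=buffer["actions"], cmap="viridis")
plt.colorbar(label="action")
plt.title("t-SNE of replay-buffer observations")
plt.savefig("docs/plots/tsne_obs.png", dpi=150)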

📦 Getting Started

# 1. Clone and enter the repo
git clone https://github.com/mansoor-mamnoon/offline-rl-agent.git
cd offline-rl-agent

# 2. Set up virtual environment
python3 -m venv .venv
source .venv/bin/activate

# 3. Install dependencies
pip install -r requirements.txt

# 4. Run environment manually
python env/run_env_test.py

# 5. Collect dataset
python dataset/collect.py --episodes 100

# 6. Visualize dataset
python dataset/viz.py

Training the CQL Agent

We implement a Conservative Q-Learning (CQL) agent in PyTorch. The agent is trained offline on a replay buffer generated by a scripted or random policy.

Key Features:

  • Vector observation space (4D: [agent_x, agent_y, goal_x, goal_y])
  • Discrete action space with 3 actions
  • Bellman loss, conservative loss, and optional behavior cloning (BC) loss (see the sketch below)
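
A minimal sketch of how these three terms combine for a discrete-action CQL update. The function and tensor names are illustrative and do not mirror agent/train.py exactly:

import torch
import torch.nn.functional as F

def cql_losses(q_net, target_q_net, policy, batch, gamma=0.99, cql_alpha=1.0, bc_weight=0.1):
    s, a, r, s2, done = batch  # states, actions (long), rewards, next states, done flags (float 0/1)

    q_all = q_net(s)                                    # [B, 3] Q-values over the 3 actions
    q_sa = q_all.gather(1, a.unsqueeze(1)).squeeze(1)   # Q(s, a) for the dataset actions

    with torch.no_grad():
        target = r + gamma * (1 - done) * target_q_net(s2).max(dim=1).values
    bellman_loss = F.mse_loss(q_sa, target)

    # Conservative term: push down Q on all actions, push up Q on dataset actions.
    conservative_loss = (torch.logsumexp(q_all, dim=1) - q_sa).mean()

    # Optional behavior cloning term on the separate policy head.
    bc_loss = F.cross_entropy(policy(s), a)

    return bellman_loss + cql_alpha * conservative_loss + bc_weight * bc_loss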

Run training:

python agent/train.py

Training logs print loss components every 100 epochs.



🧠 Training Loss Visualization

Below is the training loss of the Conservative Q-Learning (CQL) agent across 1000 epochs:

CQL Training Losses

  • Bellman Loss measures TD error between predicted Q and target Q.
  • Conservative Loss regularizes Q-values to avoid overestimation.
  • Behavior Cloning Loss aligns the policy to dataset behavior.

These curves help validate that learning is progressing smoothly.

Logging, Evaluation, and Checkpointing

To monitor training progress and ensure the CQL agent is learning effectively, we implemented:

✅ Features Added

  • ๐Ÿ” Evaluation Loop:
    • Every 100 epochs, the agent is evaluated on a held-out batch of offline transitions.
    • Evaluation metrics:
      • Policy Accuracy: how often the agent matches actions from the dataset.
      • Average Q-Value: the mean predicted return across sampled transitions.
  • ๐Ÿ“‰ Loss Logging:
    • Training losses logged per epoch:
      • Bellman loss (temporal difference)
      • Conservative loss (Q regularization)
      • Behavior cloning (BC) loss
  • ๐Ÿ’พ Checkpointing:
    • Automatically saves the q_net and policy when policy accuracy improves.
    • Saved to: checkpoints/best_q.pt and checkpoints/best_policy.pt
  • ๐Ÿ“Š TensorBoard Integration:

To run TensorBoard:

tensorboard --logdir=logs

You can monitor live training and evaluation updates in your browser at:
👉 http://localhost:6006
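
The overall pattern looks roughly like the sketch below. Here train_one_epoch and evaluate stand in for the real helpers in agent/train.py and are assumptions about its structure:

import os
import torch
from torch.utils.tensorboard import SummaryWriter

def run_training(q_net, policy, train_one_epoch, evaluate, num_epochs=1000):
    """Illustrative logging/checkpointing wrapper; not the actual agent/train.py loop."""
    writer = SummaryWriter(log_dir="logs")
    os.makedirs("checkpoints", exist_ok=True)
    best_accuracy = 0.0

    for epoch in range(num_epochs):
        losses = train_one_epoch(epoch)        # assumed to return a dict of loss components
        for name, value in losses.items():
            writer.add_scalar(f"loss/{name}", value, epoch)

        if epoch % 100 == 0:
            accuracy, avg_q = evaluate(epoch)  # policy accuracy and mean predicted Q
            writer.add_scalar("eval/policy_accuracy", accuracy, epoch)
            writer.add_scalar("eval/avg_q", avg_q, epoch)
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                torch.save(q_net.state_dict(), "checkpoints/best_q.pt")
                torch.save(policy.state_dict(), "checkpoints/best_policy.pt")

    writer.close()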

📂 Files Modified

  • agent/train.py: Main training loop updated with:
    • Evaluation every 100 epochs
    • TensorBoard logging of loss and accuracy metrics
    • Checkpoint saving logic for best-performing policy
  • checkpoints/: Directory created to store .pt model weights


🧠 Model Compression: Quantization + Pruning

We implemented model compression techniques to reduce memory usage and inference latency of the offline RL agent without sacrificing reward. Two approaches were explored:

🔧 Techniques Used

  1. Static Quantization using PyTorch's torch.quantization pipeline.
  2. Structured Pruning (via torch.nn.utils.prune.ln_structured) to remove 30–60% of neurons from linear layers (see the sketch after this list).
  3. Unstructured Pruning (optional) to sparsify weights within layers for additional compression.
  4. Fine-tuning after pruning to recover performance.
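
A sketch of the structured-pruning step, reusing the PolicyNetwork class referenced later in this README. The 30% ratio, the layer selection, and the output path are illustrative:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# PolicyNetwork is the agent's policy class; its import path depends on the repo layout.
model = PolicyNetwork(state_dim=4, action_dim=3)
model.load_state_dict(torch.load("checkpoints/best_policy.pt"))

for module in model.modules():
    if isinstance(module, nn.Linear):
        # Drop 30% of output neurons (rows of the weight matrix), ranked by L2 norm.
        prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)
        prune.remove(module, "weight")   # make the pruning permanent by baking in the mask

torch.save(model.state_dict(), "checkpoints/pruned_policy.pt")   # illustrative output path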

📈 Visualization of Tradeoffs

Each point below represents a model version, plotted by reward and latency, with bubble size representing memory usage.

Compression Tradeoff

💻 macOS Silicon (M1/M2) Warning

Static quantization (using torch.quantization.convert()) is not currently supported on macOS ARM (M1/M2 chips). You may see the following error:

NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). 'quantized::linear' is only available for these backends: [MPS, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

✅ Fix: Use Dynamic Quantization on macOS

To avoid this, switch to dynamic quantization, which works on macOS and still gives performance benefits on CPUs:

import torch
from torch.quantization import quantize_dynamic

# Load the trained float policy first, then quantize its Linear layers.
policy = PolicyNetwork(state_dim=4, action_dim=3)
policy.load_state_dict(torch.load("checkpoints/best_policy.pt"))

quantized_model = quantize_dynamic(
    policy,
    {torch.nn.Linear},   # only Linear layers are dynamically quantized
    dtype=torch.qint8
)

No need for qconfig, prepare(), or convert(): just load the float weights, quantize, and run.

🧪 Results (Sample)

Model                 Reward   Latency (ms)   Memory (MB)
Original              8.30     0.23           341.86
Pruned                8.30     0.24           341.86
Quantized (Dynamic)   ~8.30    ~0.20          ~330.00

With higher pruning ratios or quantization + pruning combinations, further improvements can be achieved.

📂 Code Locations

  • Compression logic: agent/compress.py
  • Evaluation and plotting: part of compress.py (runs automatically)
  • Trained models saved in: checkpoints/
  • Visualization saved to: docs/plots/compression_tradeoff.png

โ–ถ๏ธ Run it via:

python agent/compress.py

๐Ÿ” Distillation-Based Compression

We implement knowledge distillation to compress a large policy model (BigMLP) into a smaller student model (SmallMLP). The student is trained on soft labels from the teacher's output logits using KL divergence loss.
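
A minimal sketch of that distillation step. The temperature value and function shape are illustrative rather than a copy of the actual training code:

import torch
import torch.nn.functional as F

def distill_step(teacher, student, states, optimizer, temperature=2.0):
    """One gradient step on the KL divergence between softened teacher and student outputs."""
    with torch.no_grad():
        teacher_logits = teacher(states)

    student_logits = student(states)
    # KL(teacher || student) on temperature-softened distributions, scaled by T^2.
    loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()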

📈 Distillation Loss Curve

Distillation Loss

KL Divergence between teacher and student softmax outputs across training epochs.

🎯 Final Reward of Student Policy

Student Reward

Average reward over 10 episodes after distillation. Evaluation performed using NeuroQuantEnv.


๐Ÿ“ Outputs

  • Trained student saved at: checkpoints/small_mlp_distilled.pt
  • Loss log: logs/distill_loss.log
  • Reward log: logs/student_reward.log

We built a real-time inference loop to benchmark the performance of compressed models in a live environment. The goal was to deploy a distilled model (SmallMLP) and verify whether it could sustain high-speed, low-latency decision-making under realistic constraints. A minimal sketch of such a loop appears after the objectives below.

🎯 Objectives

  • Load a compressed, lightweight model (small_mlp_distilled.pt)
  • Step through NeuroQuantEnv in a real-time loop (≤ 100 ms per frame)
  • Log per-frame:
    • ✅ Inference latency (ms)
    • ✅ Memory usage (MB)
    • ✅ Actions taken
  • Display live FPS and latency in the terminal
  • Save and plot performance metrics
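
A sketch of that loop. The import paths, constructor arguments, and observation-to-tensor handling are assumptions; the real engine lives in inference/run_realtime_inference.py:

import time
import psutil
import torch

# NeuroQuantEnv lives in env/neuroquant_env.py; SmallMLP's import path is an assumption.
from env.neuroquant_env import NeuroQuantEnv
from agent.models import SmallMLP

student = SmallMLP(state_dim=4, action_dim=3)       # constructor args assumed
student.load_state_dict(torch.load("checkpoints/small_mlp_distilled.pt"))
student.eval()

env = NeuroQuantEnv(obs_mode="vector")              # keyword assumed, as in the earlier sketch
obs, done = env.reset(), False
process = psutil.Process()

while not done:
    start = time.perf_counter()
    with torch.no_grad():
        action = student(torch.as_tensor(obs, dtype=torch.float32)).argmax().item()
    obs, reward, done, info = env.step(action)

    latency_ms = (time.perf_counter() - start) * 1000
    print(f"Latency: {latency_ms:.2f} ms | FPS: {1000 / latency_ms:.2f} | "
          f"Mem: {process.memory_info().rss / 1e6:.2f} MB | Action: {action}")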

📈 Key Results

  • 🧠 Model: SmallMLP distilled from BigMLP
  • 🚀 Average Latency: 1.08 ms
  • 🎞️ Average FPS: 929.04
  • 🧠 Memory Usage: ~194.55 MB

๐Ÿ–ฅ๏ธ Sample Output

🎮 Starting real-time inference loop...
[Frame 1] Latency: 1.15 ms | FPS: 869.19 | Mem: 194.42 MB | Action: 2
[Frame 10] Latency: 1.05 ms | FPS: 950.23 | Mem: 194.55 MB | Action: 0
...
🎯 Real-Time Inference Complete
🕒 Total Time: 1.88 s
📈 Avg Latency: 1.08 ms | Avg FPS: 929.04

📊 Real-Time Inference Visualizations

  • ๐Ÿ” Each point = one environment step
  • ๐Ÿ“‰ Latency remained stable across steps (~1ms)
  • ๐ŸŽฏ FPS consistently exceeded 900

📂 Files Involved

inference/run_realtime_inference.py   # Real-time engine
checkpoints/small_mlp_distilled.pt    # Compressed model
logs/day11_metrics.csv                # Per-frame metrics
docs/plots/day11_latency.png          # Latency graph
docs/plots/day11_fps.png              # FPS graph

We upgraded the inference engine to simulate a realistic deployment environment where latency spikes trigger automatic shutdown, and all runtime metrics are logged for post-analysis. A sketch of this watchdog-and-logging wrapper follows the objectives below.

🎯 Objectives

  • โŒ Shutdown if inference latency > 150 ms
  • โœ… Log each frame with:
    • Timestamp
    • Inference Latency
    • Memory Usage
    • Cumulative Reward
  • โœ… Write to: results/session_X.csv
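
A sketch of that wrapper, assuming the field order and the psutil-based memory measurement roughly match the real script; the auto-numbered session path is supplied by the caller:

import csv
import time
import psutil

LATENCY_BUDGET_MS = 150.0

def run_logged_session(env, policy_fn, csv_path):
    """Step the environment, log one CSV row per frame, and abort on a latency spike."""
    process = psutil.Process()
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "latency_ms", "memory_mb", "cumulative_reward"])

        obs, done, total_reward = env.reset(), False, 0.0
        while not done:
            start = time.perf_counter()
            action = policy_fn(obs)
            obs, reward, done, _ = env.step(action)
            total_reward += reward

            latency_ms = (time.perf_counter() - start) * 1000
            memory_mb = process.memory_info().rss / 1e6
            writer.writerow([time.time(), f"{latency_ms:.3f}", f"{memory_mb:.2f}", f"{total_reward:.2f}"])

            if latency_ms > LATENCY_BUDGET_MS:
                print(f"Latency {latency_ms:.1f} ms exceeded the {LATENCY_BUDGET_MS:.0f} ms budget; shutting down.")
                break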

📊 Logs Visualized

  • ๐Ÿ” Each point = 1 environment step
  • ๐Ÿšจ Inference aborted if latency >150ms
  • ๐Ÿง  All inference logs saved for replayability

📂 Files Involved

inference/run_realtime_inference.py     # Real-time agent w/ shutdown & logging
results/session_X.csv                   # Per-frame logs (auto-numbered)
scripts/plot_day12_session.py           # Plotting script
docs/plots/day12_latency.png            # Per-frame latency plot
docs/plots/day12_reward.png             # Cumulative reward plot

🎥 Demos + GIFs

The environment supports saving full episodes as GIFs using the render_episode_gif() function.
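
For reference, an illustrative sketch of how such an episode GIF can be written with Pillow; the repo's render_episode_gif() may differ in name, signature, and internals:

import numpy as np
from PIL import Image

def save_episode_gif(env, policy_fn, path="docs/replays/test_run.gif", fps=10):
    """Roll out one episode and write the rendered frames to an animated GIF."""
    frames, obs, done = [], env.reset(), False
    while not done:
        frame = env.render(mode="rgb_array")                # assumes an rgb_array render mode
        frames.append(Image.fromarray(np.asarray(frame, dtype=np.uint8)))
        obs, _, done, _ = env.step(policy_fn(obs))

    frames[0].save(path, save_all=True, append_images=frames[1:],
                   duration=int(1000 / fps), loop=0)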

Sample run saved to docs/replays/test_run.gif.
