NeuroQuant Agent is a fully custom offline reinforcement learning benchmark, built from the ground up around real-time constraints, compression-aware inference, and deployment to latency-constrained environments.
The project begins with a custom-built 10×10 gridworld environment that supports:
- Directional movement: the agent can turn left, go forward, or turn right relative to its current orientation.
- Partial observability: instead of seeing the entire map, the agent receives a 3×3 view centered on its position.
- Obstacles: impassable wall tiles block the agent's path and require navigation.
- Goal tile: a single terminal state gives a large positive reward when reached, ending the episode.
- Real-time PyGame rendering: each simulation step is rendered at a capped 10 FPS for visual inspection and timing fidelity.
This environment is used as the basis for:
- Generating offline replay buffers
- Training offline RL agents using CQL, BCQ, or TD3+BC
- Benchmarking model compression tradeoffs (quantization, pruning, distillation)
- Real-time deployment of agents under latency and memory constraints
The environment is a 10×10 gridworld with directional agent movement, obstacles, and a single terminal goal. Key features (a minimal interaction sketch follows the list):
- Action space: turn left, move forward, turn right (relative to current orientation)
- Partial observability: agent receives a 3×3 window centered on its current location
- Dual observation modes:
  - Image: 3×3 local grid (int matrix)
  - Vector: agent position and goal coordinates as a flat vector
- Reward structure:
  - +10 for reaching the goal (sparse)
  - -0.1 per step (dense penalty)
- Obstacles: defined in the grid and block movement
- Real-time rendering: PyGame visualization at 10 FPS
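A minimal interaction sketch, assuming `NeuroQuantEnv` follows the classic Gym API (`reset()` returns an observation and `step()` returns a 4-tuple) and that the default constructor works as shown; adjust if the actual signatures differ:

```python
import random

from env.neuroquant_env import NeuroQuantEnv  # module path per the repo layout below

env = NeuroQuantEnv()                  # constructor arguments are assumptions
obs = env.reset()                      # 3x3 local view or 4-D vector, depending on mode
done = False
while not done:
    action = random.choice([0, 1, 2])          # turn left / move forward / turn right
    obs, reward, done, info = env.step(action)
    env.render()                                # PyGame window, capped at 10 FPS
```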
We simulate random or scripted agents in the custom Gridworld environment to collect experience data for offline RL training.
Each transition includes:
- `observation`
- `action`
- `reward`
- `next_observation`
- `done`

These transitions are saved into a compressed `.npz` buffer (`dataset/replay_buffer.npz`), which can later be loaded for training Conservative Q-Learning (CQL), TD3+BC, or BCQ agents.
Additional outputs include:
- Episode metadata (average reward, length, and total transitions) saved to `dataset/metadata.txt`
- A histogram of the reward distribution over episodes saved to `dataset/reward_histogram.png`
To generate the dataset, run:

```bash
python dataset/collect.py --episodes 100
```
This will generate 10k+ transitions across 100 episodes using a random policy.
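As a quick sanity check, the saved buffer can be inspected with NumPy (a sketch; the exact key names inside `replay_buffer.npz` are assumptions based on the transition fields listed above):

```python
import numpy as np

data = np.load("dataset/replay_buffer.npz")
print(data.files)                      # expected keys, e.g. observation, action, reward, ...

print("transitions:", len(data["action"]))
print("mean reward per step:", data["reward"].mean())
print("episodes ended:", int(data["done"].sum()))
```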
```
offline-rl-agent/
│
├── env/                      # Custom Gym environment (NeuroQuantEnv)
│   └── neuroquant_env.py
│
├── dataset/                  # Replay buffer collection + visualizations
│   ├── collect.py            # Random/scripted policy buffer generation
│   ├── viz.py                # t-SNE, reward, and action plots
│   ├── replay_buffer.npz     # (gitignored) Collected transitions
│   ├── reward_histogram.png
│   └── metadata.txt
│
├── docs/
│   └── plots/                # Visual outputs of dataset
│       ├── tsne_obs.png
│       ├── action_distribution.png
│       └── episode_rewards.png
│
├── .gitignore
├── README.md
└── run_env_test.py           # Debug script to manually interact with env
```
We visualize the replay buffer to verify coverage and distribution:
- t-SNE of Observations: clusters state embeddings in 2D
- Action Distribution: histogram over agent actions
- Episode Reward Distribution: how returns are spread across episodes
These plots are generated via:

```bash
python dataset/viz.py
```
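For reference, the t-SNE plot can be reproduced with scikit-learn along these lines (a sketch, not the exact code in `dataset/viz.py`; buffer key names and plot styling are assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

data = np.load("dataset/replay_buffer.npz")
obs = data["observation"]
obs = obs.reshape(len(obs), -1)               # flatten 3x3 windows if in image mode

# Project observations to 2D and color each point by its immediate reward
emb = TSNE(n_components=2, perplexity=30).fit_transform(obs)
plt.scatter(emb[:, 0], emb[:, 1], s=3, c=data["reward"], cmap="viridis")
plt.title("t-SNE of observations")
plt.savefig("docs/plots/tsne_obs.png", dpi=150)
```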
```bash
# 1. Clone and enter the repo
git clone https://github.com/mansoor-mamnoon/offline-rl-agent.git
cd offline-rl-agent

# 2. Set up virtual environment
python3 -m venv .venv
source .venv/bin/activate

# 3. Install dependencies
pip install -r requirements.txt

# 4. Run environment manually
python env/run_env_test.py

# 5. Collect dataset
python dataset/collect.py --episodes 100

# 6. Visualize dataset
python dataset/viz.py
```
We implement a Conservative Q-Learning (CQL) agent using PyTorch. The agent is trained offline on a replay buffer generated from a scripted or random policy.
Key Features:
- Vector observation space (4D: [agent_x, agent_y, goal_x, goal_y])
- Discrete action space with 3 actions
- Bellman loss, conservative loss, and optional behavior cloning (BC) loss
Run training:

```bash
python agent/train.py
```
Training logs print loss components every 100 epochs.
Below is the training loss of the Conservative Q-Learning (CQL) agent across 1000 epochs:
- Bellman Loss measures TD error between predicted Q and target Q.
- Conservative Loss regularizes Q-values to avoid overestimation.
- Behavior Cloning Loss aligns the policy to dataset behavior.
These curves help validate that learning is progressing smoothly.
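For intuition, the three loss terms combine roughly as follows (a simplified sketch of discrete-action CQL, not the exact code in `agent/train.py`; the weights `alpha` and `bc_weight` are assumptions):

```python
import torch
import torch.nn.functional as F

def cql_loss(q_net, target_q_net, batch, gamma=0.99, alpha=1.0, bc_weight=0.1):
    """Bellman + conservative (+ optional BC) loss on a batch of offline transitions."""
    obs, act, rew, next_obs, done = batch                     # tensors from the buffer
    act = act.long()

    # Bellman (TD) loss against a frozen target network
    q_pred = q_net(obs).gather(1, act.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        q_next = target_q_net(next_obs).max(dim=1).values
        q_target = rew + gamma * (1.0 - done) * q_next
    bellman = F.mse_loss(q_pred, q_target)

    # Conservative term: penalize large Q-values on out-of-dataset actions
    conservative = (torch.logsumexp(q_net(obs), dim=1) - q_pred).mean()

    # Optional behavior-cloning term that keeps the greedy policy close to the data
    bc = F.cross_entropy(q_net(obs), act)

    return bellman + alpha * conservative + bc_weight * bc
```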
To monitor training progress and ensure the CQL agent is learning effectively, we implemented:

- Evaluation Loop:
  - Every 100 epochs, the agent is evaluated on a held-out batch of offline transitions.
  - Evaluation metrics:
    - Policy Accuracy: how often the agent matches actions from the dataset.
    - Average Q-Value: the mean predicted return across sampled transitions.
- Loss Logging:
  - Training losses logged per epoch:
    - Bellman loss (temporal difference)
    - Conservative loss (Q regularization)
    - Behavior cloning (BC) loss
- Checkpointing:
  - Automatically saves the `q_net` and `policy` when policy accuracy improves.
  - Saved to `checkpoints/best_q.pt` and `checkpoints/best_policy.pt`.
- TensorBoard Integration:
  - Visualizations include the training losses and the evaluation metrics (policy accuracy, average Q-value).

To run TensorBoard:

```bash
tensorboard --logdir=logs
```

You can monitor live training and evaluation updates in your browser at http://localhost:6006.
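The logging itself uses the standard `torch.utils.tensorboard` writer; a minimal sketch of the pattern (tag names and the placeholder values are assumptions):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs")

for epoch in range(1000):
    # Placeholder values; in agent/train.py these come from the actual loss terms.
    bellman_loss, conservative_loss, bc_loss = 0.5, 0.3, 0.1
    writer.add_scalar("loss/bellman", bellman_loss, epoch)
    writer.add_scalar("loss/conservative", conservative_loss, epoch)
    writer.add_scalar("loss/bc", bc_loss, epoch)

    if epoch % 100 == 0:
        policy_accuracy, avg_q_value = 0.8, 4.2   # placeholder evaluation metrics
        writer.add_scalar("eval/policy_accuracy", policy_accuracy, epoch)
        writer.add_scalar("eval/avg_q_value", avg_q_value, epoch)

writer.close()
```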
- `agent/train.py`: main training loop updated with:
  - Evaluation every 100 epochs
  - TensorBoard logging of loss and accuracy metrics
  - Checkpoint saving logic for the best-performing policy
- `checkpoints/`: directory created to store `.pt` model weights
We implemented model compression techniques to reduce the memory usage and inference latency of the offline RL agent without sacrificing reward. The following techniques were explored:

- Static Quantization using PyTorch's `torch.quantization` pipeline.
- Structured Pruning (via `torch.nn.utils.prune.ln_structured`) to remove 30–60% of neurons from linear layers (see the sketch after this list).
- Unstructured Pruning (optional) to sparsify weights within layers for additional compression.
- Fine-tuning after pruning to recover performance.
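A minimal structured-pruning sketch using the utilities named above (the import path for `PolicyNetwork` and the 30% ratio are illustrative assumptions):

```python
import torch
import torch.nn.utils.prune as prune

from agent.train import PolicyNetwork   # import path is an assumption

policy = PolicyNetwork(state_dim=4, action_dim=3)
policy.load_state_dict(torch.load("checkpoints/best_policy.pt"))

# Structured pruning: remove 30% of output neurons (rows) from every Linear layer,
# ranked by their L2 norm, then make the pruning permanent.
for module in policy.modules():
    if isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)
        prune.remove(module, "weight")

# Fine-tune on the offline buffer afterwards to recover any lost reward.
```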
Each point below represents a model version, plotted by reward and latency, with bubble size representing memory usage.
Static quantization (using `torch.quantization.convert()`) is not currently supported on macOS ARM (M1/M2 chips). You may see the following error:

```
NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). 'quantized::linear' is only available for these backends: [MPS, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
```
To avoid this, switch to dynamic quantization, which works on macOS and still gives performance benefits on CPUs:
```python
import torch
from torch.quantization import quantize_dynamic

# Load the trained float weights first, then quantize the Linear layers to int8.
# (Loading a float state_dict into an already-quantized model fails, since dynamic
# quantization replaces the Linear modules with packed int8 versions.)
policy = PolicyNetwork(state_dim=4, action_dim=3)
policy.load_state_dict(torch.load("checkpoints/best_policy.pt"))

quantized_model = quantize_dynamic(policy, {torch.nn.Linear}, dtype=torch.qint8)
```
No need for `qconfig`, `prepare()`, or `convert()`: just quantize and run.
Results (Sample)

| Model | Reward | Latency (ms) | Memory (MB) |
|---|---|---|---|
| Original | 8.30 | 0.23 | 341.86 |
| Pruned | 8.30 | 0.24 | 341.86 |
| Quantized (Dynamic) | ~8.30 | ~0.20 | ~330.00 |
With higher pruning ratios or quantization + pruning combinations, further improvements can be achieved.
Code Locations

- Compression logic: `agent/compress.py`
- Evaluation and plotting: part of `compress.py` (runs automatically)
- Trained models saved in: `checkpoints/`
- Visualization saved to: `docs/plots/compression_tradeoff.png`

To run compression and evaluation:

```bash
python agent/compress.py
```
We implement knowledge distillation to compress a large policy model (BigMLP) into a smaller student model (SmallMLP). The student is trained on soft labels from the teacher's output logits using KL divergence loss.
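A sketch of the distillation objective (the temperature value and variable names follow common practice and are not taken from the repo):

```python
import torch
import torch.nn.functional as F

def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student outputs."""
    t = temperature
    teacher_probs = F.softmax(teacher_logits / t, dim=-1)
    student_log_probs = F.log_softmax(student_logits / t, dim=-1)
    # Scale by t^2 so gradients keep a comparable magnitude across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (t * t)
```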
Distillation loss: KL divergence between teacher and student softmax outputs across training epochs.

Student reward: average reward over 10 episodes after distillation, evaluated in `NeuroQuantEnv`.
- Trained student saved at: `checkpoints/small_mlp_distilled.pt`
- Loss log: `logs/distill_loss.log`
- Reward log: `logs/student_reward.log`
We built a real-time inference loop to benchmark the performance of compressed models in a live environment. The goal was to deploy a distilled model (`SmallMLP`) and verify whether it could sustain high-speed, low-latency decision-making under realistic constraints.
- Load a compressed, lightweight model (`small_mlp_distilled.pt`)
- Step through `NeuroQuantEnv` in a real-time loop (≤ 100 ms per frame; see the sketch below)
- Log per frame:
  - Inference latency (ms)
  - Memory usage (MB)
  - Actions taken
- Display live FPS and latency in the terminal
- Save and plot performance metrics
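The core of such a loop is short; a sketch under the same assumptions as before (classic Gym `step()` API, `psutil` for memory measurement), not the exact contents of `inference/run_realtime_inference.py`:

```python
import time

import psutil
import torch

def run_realtime(model, env, frame_budget_ms=100.0):
    """Step the environment with the model while measuring per-frame latency and memory."""
    process = psutil.Process()
    obs, done = env.reset(), False
    while not done:
        start = time.perf_counter()
        with torch.no_grad():
            action = model(torch.as_tensor(obs, dtype=torch.float32)).argmax().item()
        obs, reward, done, info = env.step(action)

        latency_ms = (time.perf_counter() - start) * 1000.0
        mem_mb = process.memory_info().rss / 1e6
        print(f"Latency: {latency_ms:.2f} ms | Mem: {mem_mb:.2f} MB | Action: {action}")

        # Sleep out the rest of the frame budget to hold a steady frame rate.
        time.sleep(max(0.0, (frame_budget_ms - latency_ms) / 1000.0))
```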
- Model: `SmallMLP` distilled from `BigMLP`
- Average Latency: 1.08 ms
- Average FPS: 929.04
- Memory Usage: ~194.55 MB
```
Starting real-time inference loop...
[Frame 1] Latency: 1.15 ms | FPS: 869.19 | Mem: 194.42 MB | Action: 2
[Frame 10] Latency: 1.05 ms | FPS: 950.23 | Mem: 194.55 MB | Action: 0
...
Real-Time Inference Complete
Total Time: 1.88 s
Avg Latency: 1.08 ms | Avg FPS: 929.04
```
- Each point = one environment step
- Latency remained stable across steps (~1 ms)
- FPS consistently exceeded 900
```
inference/run_realtime_inference.py   # Real-time engine
checkpoints/small_mlp_distilled.pt    # Compressed model
logs/day11_metrics.csv                # Per-frame metrics
docs/plots/day11_latency.png          # Latency graph
docs/plots/day11_fps.png              # FPS graph
```
We upgraded the inference engine to simulate a realistic deployment environment where latency spikes trigger automatic shutdown, and all runtime metrics are logged for post-analysis.
- Shutdown if inference latency > 150 ms (see the sketch below)
- Log each frame with:
  - Timestamp
  - Inference latency
  - Memory usage
  - Cumulative reward
- Write to: `results/session_X.csv`
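A sketch of the shutdown and per-frame CSV logging (column names and the example session filename are assumptions):

```python
import csv
import os
import time

LATENCY_LIMIT_MS = 150.0
FIELDS = ["timestamp", "latency_ms", "memory_mb", "cumulative_reward"]

os.makedirs("results", exist_ok=True)
with open("results/session_1.csv", "w", newline="") as f:     # session number illustrative
    writer = csv.DictWriter(f, fieldnames=FIELDS)
    writer.writeheader()

    # Inside the real inference loop these values come from each frame.
    latency_ms, mem_mb, cum_reward = 1.1, 194.5, 0.0
    if latency_ms > LATENCY_LIMIT_MS:
        raise SystemExit(f"Latency spike ({latency_ms:.1f} ms) exceeded budget; aborting.")
    writer.writerow({
        "timestamp": time.time(),
        "latency_ms": round(latency_ms, 3),
        "memory_mb": round(mem_mb, 2),
        "cumulative_reward": round(cum_reward, 3),
    })
```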
- Each point = 1 environment step
- Inference aborted if latency > 150 ms
- All inference logs saved for replayability
```
inference/run_realtime_inference.py   # Real-time agent w/ shutdown & logging
results/session_X.csv                 # Per-frame logs (auto-numbered)
scripts/plot_day12_session.py         # Plotting script
docs/plots/day12_latency.png          # Per-frame latency plot
docs/plots/day12_reward.png           # Cumulative reward plot
```
The environment supports saving full episodes as GIFs using the `render_episode_gif()` function.
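A hypothetical call, purely illustrative since the actual signature of `render_episode_gif()` may differ:

```python
from env.neuroquant_env import NeuroQuantEnv, render_episode_gif  # import path assumed

env = NeuroQuantEnv()
# Argument names below are guesses; check the function definition in the repo.
render_episode_gif(env, policy=None, out_path="docs/plots/episode.gif")
```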