# [Feature, Example] A3C Atari Implementation for TorchRL #3001
**README** (new file):

# Reproducing Asynchronous Advantage Actor-Critic (A3C) Algorithm Results

This repository contains scripts for training agents with the Asynchronous Advantage Actor-Critic (A3C) algorithm on Atari environments. We follow the original paper, [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783) by Mnih et al. (2016), and implement A3C with a fixed number of steps during the collection phase.

## Examples Structure

Please note that, for the sake of simplicity, each example is independent of the others. Each example contains the following files (a layout sketch follows the list):

1. **Main Script:** The definition of the algorithm components and the training loop (e.g. `a3c_atari.py`).

2. **Utils File:** A utility file containing various helper functions, mainly for creating the environment and the models (e.g. `utils_atari.py`).

3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. `config_atari.yaml`).
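For orientation, this is the layout the list above implies (the directory name is assumed):

```text
a3c/
├── a3c_atari.py       # algorithm components and training loop
├── utils_atari.py     # environment and model helpers
└── config_atari.yaml  # default hyperparameters
```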
## Running the Examples

You can execute the A3C algorithm on Atari environments by running the following command:
```bash
python a3c_atari.py
```
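Since the script is launched through Hydra, any value in `config_atari.yaml` can also be overridden on the command line; for example (the environment name and learning rate here are purely illustrative values):

```bash
python a3c_atari.py env.env_name=BreakoutNoFrameskip-v4 optim.lr=0.0002
```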
**`a3c_atari.py`** (new file):
```python
from __future__ import annotations

from copy import deepcopy

import hydra
import torch

import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim
import tqdm

from omegaconf import DictConfig

from torchrl.collectors import SyncDataCollector
from torchrl.objectives import A2CLoss
from torchrl.objectives.value.advantages import GAE

from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import make_parallel_env, make_ppo_models


torch.set_float32_matmul_precision("high")


class SharedAdam(torch.optim.Adam):
    """Adam whose state tensors live in shared memory, so all worker
    processes read and update the same optimizer statistics."""

    def __init__(self, params, **kwargs):
        super().__init__(params, **kwargs)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = torch.zeros(1)
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)
                state["exp_avg"].share_memory_()
                state["exp_avg_sq"].share_memory_()
                state["step"].share_memory_()


class A3CWorker(mp.Process):
    def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
        super().__init__()
        self.name = name
        self.cfg = cfg

        self.optimizer = optimizer

        self.device = cfg.loss.device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        self.frame_skip = 4
        self.total_frames = cfg.collector.total_frames // self.frame_skip
        self.frames_per_batch = cfg.collector.frames_per_batch // self.frame_skip
        self.mini_batch_size = cfg.loss.mini_batch_size // self.frame_skip
        self.test_interval = cfg.logger.test_interval // self.frame_skip

        self.global_actor = global_actor
        self.global_critic = global_critic
        self.local_actor = deepcopy(global_actor)
        self.local_critic = deepcopy(global_critic)

        self.logger = logger

        self.adv_module = GAE(
            gamma=cfg.loss.gamma,
            lmbda=cfg.loss.gae_lambda,
            value_network=self.local_critic,
            average_gae=True,
            vectorized=not cfg.compile.compile,
            device=self.device,
        )
        self.loss_module = A2CLoss(
            actor_network=self.local_actor,
            critic_network=self.local_critic,
            loss_critic_type=cfg.loss.loss_critic_type,
            entropy_coef=cfg.loss.entropy_coef,
            critic_coef=cfg.loss.critic_coef,
        )

        self.adv_module.set_keys(done="end-of-life", terminated="end-of-life")
        self.loss_module.set_keys(done="end-of-life", terminated="end-of-life")

    def update(self, batch, max_grad_norm=None):
        if max_grad_norm is None:
            max_grad_norm = self.cfg.optim.max_grad_norm

        loss = self.loss_module(batch)
        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
        loss_sum.backward()

        # Copy the locally computed gradients onto the shared (global) parameters,
        # so the shared optimizer step below updates the global model.
        for local_param, global_param in zip(
            self.local_actor.parameters(), self.global_actor.parameters()
        ):
            global_param._grad = local_param.grad

        for local_param, global_param in zip(
            self.local_critic.parameters(), self.global_critic.parameters()
        ):
            global_param._grad = local_param.grad

        gn = torch.nn.utils.clip_grad_norm_(
            self.loss_module.parameters(), max_norm=max_grad_norm
        )
```
> **Review comment** (on lines +110 to +122, the gradient-copy block above): can you explain what we do here? What do we use the `_grad` for?
>
> **Author reply:** `_grad` is used to store the gradients for each parameter.
```python
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)

        return (
            loss.select("loss_critic", "loss_entropy", "loss_objective")
            .detach()
            .set("grad_norm", gn)
        )

    def run(self):
        cfg = self.cfg

        collector = SyncDataCollector(
            create_env_fn=make_parallel_env(
                cfg.env.env_name,
                num_envs=cfg.env.num_envs,
                device=self.device,
                gym_backend=cfg.env.backend,
            ),
            policy=self.local_actor,
            frames_per_batch=self.frames_per_batch,
            total_frames=self.total_frames,
            device=self.device,
            storing_device=self.device,
            policy_device=self.device,
            compile_policy=False,
            cudagraph_policy=False,
        )

        collected_frames = 0
        num_network_updates = 0
        pbar = tqdm.tqdm(total=self.total_frames)
        num_mini_batches = self.frames_per_batch // self.mini_batch_size
        total_network_updates = (
            self.total_frames // self.frames_per_batch
        ) * num_mini_batches
        lr = cfg.optim.lr

        c_iter = iter(collector)
        total_iter = len(collector)

        for _ in range(total_iter):
            data = next(c_iter)

            metrics_to_log = {}
            frames_in_batch = data.numel()
            collected_frames += self.frames_per_batch * self.frame_skip
            pbar.update(frames_in_batch)

            episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
            if len(episode_rewards) > 0:
                episode_length = data["next", "step_count"][data["next", "terminated"]]
                metrics_to_log["train/reward"] = episode_rewards.mean().item()
                metrics_to_log[
                    "train/episode_length"
                ] = episode_length.sum().item() / len(episode_length)

            with torch.no_grad():
                data = self.adv_module(data)
            data_reshape = data.reshape(-1)
            losses = []

            mini_batches = data_reshape.split(self.mini_batch_size)
```
> **Review comment:** To shuffle things a bit I usually rely on a replay buffer instance rather than just splitting the data.
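A minimal sketch of the reviewer's suggestion, using TorchRL's replay buffer with a without-replacement sampler (the pattern used in the A2C/PPO examples); the sizes and the `obs` field are placeholders standing in for the flattened collector output (`data_reshape` above):

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

frames_per_batch, mini_batch_size = 200, 20  # placeholder sizes

# A buffer that yields shuffled mini-batches instead of an ordered split.
data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch),
    sampler=SamplerWithoutReplacement(),
    batch_size=mini_batch_size,
)

# Stand-in for the flattened batch; refill once per collection phase.
data = TensorDict({"obs": torch.randn(frames_per_batch, 4)}, [frames_per_batch])
data_buffer.extend(data)

for batch in data_buffer:  # 10 shuffled mini-batches of 20 frames each
    assert batch.batch_size == (mini_batch_size,)
```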
```python
            for batch in mini_batches:
                alpha = 1.0
                if cfg.optim.anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in self.optimizer.param_groups:
                        group["lr"] = lr * alpha

                num_network_updates += 1
                loss = self.update(batch).clone()
                losses.append(loss)

            losses = torch.stack(losses).float().mean()

            for key, value in losses.items():
                metrics_to_log[f"train/{key}"] = value.item()

            metrics_to_log["train/lr"] = lr * alpha
            if self.logger:
                for key, value in metrics_to_log.items():
                    self.logger.log_scalar(key, value, collected_frames)
        collector.shutdown()


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: DictConfig):
    global_actor, global_critic, global_critic_head = make_ppo_models(
        cfg.env.env_name, device=cfg.loss.device, gym_backend=cfg.env.backend
    )
    global_model = nn.ModuleList([global_actor, global_critic_head])
    global_model.share_memory()
    optimizer = SharedAdam(global_model.parameters(), lr=cfg.optim.lr)

    num_workers = cfg.multiprocessing.num_workers
    if num_workers is None:
        num_workers = mp.cpu_count()
```
> **Review comment:** we should have way fewer workers - I think we need users to tell us how many.
>
> **Author reply:** That can be configured in the config_atari. You want me to explicitly set it to some constant here?
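As the author notes, the worker count lives in the config; it can also be pinned per run from the command line (a Hydra override, with the key as in `config_atari.yaml` below):

```bash
python a3c_atari.py multiprocessing.num_workers=4
```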
```python
    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
        logger = get_logger(
            cfg.logger.backend,
            logger_name="a3c",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )
```
> **Review comment:** What I usually see is that the logger is only passed to the first worker.
>
> **Author reply:** Oh yea, I did that because I thought logging any single worker should be a good representative of the global model since anyway the weights are being copied. Logging all the workers might not be really useful but that can be done as well.
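A sketch of the reviewer's suggestion, as a one-line change to the worker construction below (only the first worker receives the logger):

```python
workers = [
    A3CWorker(
        f"worker_{i}",
        cfg,
        global_actor,
        global_critic,
        optimizer,
        logger if i == 0 else None,  # only worker_0 logs
    )
    for i in range(num_workers)
]
```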
```python
    workers = [
        A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger)
        for i in range(num_workers)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
```
**`config_atari.yaml`** (new file):
```yaml
# Environment
env:
  env_name: PongNoFrameskip-v4
  backend: gymnasium
  num_envs: 1

# Collector
collector:
  frames_per_batch: 800
  total_frames: 40_000_000

# Logger
logger:
  backend: wandb
  project_name: torchrl_example_a2c
  group_name: null
  exp_name: Atari_Schulman17
  test_interval: 40_000_000
  num_test_episodes: 3
  video: False

# Optim
optim:
  lr: 0.0001
  eps: 1.0e-8
  weight_decay: 0.0
  max_grad_norm: 40.0
  anneal_lr: True

# Loss
loss:
  gamma: 0.99
  mini_batch_size: 80
  gae_lambda: 0.95
  critic_coef: 0.25
  entropy_coef: 0.01
  loss_critic_type: l2
  device:

# Compile
compile:
  compile: False
  compile_mode:
  cudagraphs: False

# Multiprocessing
multiprocessing:
  num_workers: 16
```
> **Review comment:** shouldn't we move this to the utils file?
>
> **Author reply:** Sure, will do it.