From 5d3824133cc4d055903174a0b56322020b91c97e Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Sun, 15 Jun 2025 20:24:37 +0530 Subject: [PATCH 1/5] Add code for A3C --- sota-implementations/a3c/a3c_atari.py | 227 +++++++++++++++++++ sota-implementations/a3c/config_atari.yaml | 46 ++++ sota-implementations/a3c/utils_atari.py | 241 +++++++++++++++++++++ 3 files changed, 514 insertions(+) create mode 100644 sota-implementations/a3c/a3c_atari.py create mode 100644 sota-implementations/a3c/config_atari.yaml create mode 100644 sota-implementations/a3c/utils_atari.py diff --git a/sota-implementations/a3c/a3c_atari.py b/sota-implementations/a3c/a3c_atari.py new file mode 100644 index 00000000000..f4ed4f198f6 --- /dev/null +++ b/sota-implementations/a3c/a3c_atari.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from copy import deepcopy + +import hydra +import torch + +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim +import tqdm + +from torchrl.collectors import SyncDataCollector +from torchrl.objectives import A2CLoss +from torchrl.objectives.value.advantages import GAE + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils_atari import make_parallel_env, make_ppo_models + + +torch.set_float32_matmul_precision("high") + + +class SharedAdam(torch.optim.Adam): + def __init__(self, params, **kwargs): + super().__init__(params, **kwargs) + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["step"] = torch.zeros(1) + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = torch.zeros_like(p.data) + state["exp_avg"].share_memory_() + state["exp_avg_sq"].share_memory_() + state["step"].share_memory_() + + +class A3CWorker(mp.Process): + def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None): + super().__init__() + self.name = name + self.cfg = cfg + + self.optimizer = optimizer + + self.device = cfg.loss.device or torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) + + self.frame_skip = 4 + self.total_frames = cfg.collector.total_frames // self.frame_skip + self.frames_per_batch = cfg.collector.frames_per_batch // self.frame_skip + self.mini_batch_size = cfg.loss.mini_batch_size // self.frame_skip + self.test_interval = cfg.logger.test_interval // self.frame_skip + + self.global_actor = global_actor + self.global_critic = global_critic + self.local_actor = deepcopy(global_actor) + self.local_critic = deepcopy(global_critic) + + self.logger = logger + + self.adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=self.local_critic, + average_gae=True, + vectorized=not cfg.compile.compile, + device=self.device, + ) + self.loss_module = A2CLoss( + actor_network=self.local_actor, + critic_network=self.local_critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + + self.adv_module.set_keys(done="end-of-life", terminated="end-of-life") + self.loss_module.set_keys(done="end-of-life", terminated="end-of-life") + + def update(self, batch, max_grad_norm=None): + if max_grad_norm is None: + max_grad_norm = self.cfg.optim.max_grad_norm + + loss = self.loss_module(batch) + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss_sum.backward() + + for local_param, global_param in zip( + self.local_actor.parameters(), self.global_actor.parameters() + ): + global_param._grad = local_param.grad + + for local_param, 
global_param in zip( + self.local_critic.parameters(), self.global_critic.parameters() + ): + global_param._grad = local_param.grad + + gn = torch.nn.utils.clip_grad_norm_( + self.loss_module.parameters(), max_norm=max_grad_norm + ) + + self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + def run(self): + cfg = self.cfg + + collector = SyncDataCollector( + create_env_fn=make_parallel_env( + cfg.env.env_name, + num_envs=cfg.env.num_envs, + device=self.device, + gym_backend=cfg.env.backend, + ), + policy=self.local_actor, + frames_per_batch=self.frames_per_batch, + total_frames=self.total_frames, + device=self.device, + storing_device=self.device, + policy_device=self.device, + compile_policy=False, + cudagraph_policy=False, + ) + + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=self.total_frames) + num_mini_batches = self.frames_per_batch // self.mini_batch_size + total_network_updates = ( + self.total_frames // self.frames_per_batch + ) * num_mini_batches + lr = cfg.optim.lr + + c_iter = iter(collector) + total_iter = len(collector) + + for _ in range(total_iter): + data = next(c_iter) + + metrics_to_log = {} + frames_in_batch = data.numel() + collected_frames += self.frames_per_batch * self.frame_skip + pbar.update(frames_in_batch) + + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "terminated"]] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log[ + "train/episode_length" + ] = episode_length.sum().item() / len(episode_length) + + with torch.no_grad(): + data = self.adv_module(data) + data_reshape = data.reshape(-1) + losses = [] + + mini_batches = data_reshape.split(self.mini_batch_size) + for batch in mini_batches: + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in self.optimizer.param_groups: + group["lr"] = lr * alpha + + num_network_updates += 1 + loss = self.update(batch).clone() + losses.append(loss) + + losses = torch.stack(losses).float().mean() + + for key, value in losses.items(): + metrics_to_log[f"train/{key}"] = value.item() + + metrics_to_log["train/lr"] = lr * alpha + if self.logger: + for key, value in metrics_to_log.items(): + self.logger.log_scalar(key, value, collected_frames) + collector.shutdown() + + +@hydra.main(config_path="", config_name="config_atari", version_base="1.1") +def main(cfg: DictConfig): # noqa: F821 + + global_actor, global_critic, global_critic_head = make_ppo_models( + cfg.env.env_name, device=cfg.loss.device, gym_backend=cfg.env.backend + ) + global_model = nn.ModuleList([global_actor, global_critic_head]) + global_model.share_memory() + optimizer = SharedAdam(global_model.parameters(), lr=cfg.optim.lr) + + num_workers = cfg.multiprocessing.num_workers + + if num_workers is None: + num_workers = mp.cpu_count() + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, + logger_name="a3c", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + workers = [ + A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger) + for i in range(num_workers) + ] + 
[w.start() for w in workers] + [w.join() for w in workers] + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/a3c/config_atari.yaml b/sota-implementations/a3c/config_atari.yaml new file mode 100644 index 00000000000..a4dbb0def31 --- /dev/null +++ b/sota-implementations/a3c/config_atari.yaml @@ -0,0 +1,46 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + backend: gymnasium + num_envs: 1 + +# collector +collector: + frames_per_batch: 800 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + project_name: torchrl_example_a2c + group_name: null + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + video: False + +# Optim +optim: + lr: 0.0001 + eps: 1.0e-8 + weight_decay: 0.0 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 80 + gae_lambda: 0.95 + critic_coef: 0.25 + entropy_coef: 0.01 + loss_critic_type: l2 + device: + +compile: + compile: False + compile_mode: + cudagraphs: False + +multiprocessing: + num_workers: 16 diff --git a/sota-implementations/a3c/utils_atari.py b/sota-implementations/a3c/utils_atari.py new file mode 100644 index 00000000000..655e1cec79f --- /dev/null +++ b/sota-implementations/a3c/utils_atari.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import ale_py # noqa: F401 + +import numpy as np +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data.tensor_specs import CategoricalBox +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EndOfLifeTransform, + EnvCreator, + ExplorationType, + GrayScale, + GymEnv, + NoopResetEnv, + ParallelEnv, + Resize, + RewardSum, + set_gym_backend, + SignTransform, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) +from torchrl.record import VideoRecorder + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", + gym_backend="gymnasium", + frame_skip=4, + device="cpu", + is_test=False, +): + with set_gym_backend(gym_backend): + env = GymEnv( + env_name, + frame_skip=frame_skip, + from_pixels=True, + pixels_only=False, + device=device, + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + env.append_transform(EndOfLifeTransform()) + return env + + +def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False): + env = ParallelEnv( + num_envs, + EnvCreator( + lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test), + ), + serial_for_single=True, + device=device, + ) + env = TransformedEnv(env) + env.append_transform(DoubleToFloat()) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(SignTransform(in_keys=["reward"])) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# 
==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment, device): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox): + num_outputs = proof_environment.action_spec_unbatched.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec_unbatched.shape + distribution_class = TanhNormal + distribution_kwargs = { + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + device=device, + ) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + device=device, + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + device=device, + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=proof_environment.full_action_spec_unbatched.to(device), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + device=device, + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name, device, gym_backend): + + proof_environment = make_parallel_env( + env_name, num_envs=1, device="cpu", gym_backend=gym_backend + ) + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment, device=device + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head + + +# ==================================================================== +# Evaluation utils +# 
-------------------------------------------------------------------- + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_env.apply(dump_video) + del td_test + return test_rewards.mean() From 748b673ac3f582cd12c87678387e989f392afd74 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Sun, 15 Jun 2025 22:03:04 +0530 Subject: [PATCH 2/5] Add readme --- sota-implementations/a3c/README.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 sota-implementations/a3c/README.md diff --git a/sota-implementations/a3c/README.md b/sota-implementations/a3c/README.md new file mode 100644 index 00000000000..473db57f32e --- /dev/null +++ b/sota-implementations/a3c/README.md @@ -0,0 +1,21 @@ +# Reproducing Asynchronous Advantage Actor Critic (A3C) Algorithm Results + +This repository contains scripts that enable training agents using the Asynchronous Advantage Actor Critic (A3C) Algorithm on Atari environments. We follow the original paper [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783) by Mnih et al. (2016) to implement the A3C algorithm with a fixed number of steps during the collection phase. + +## Examples Structure + +Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: + +1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. `a3c_atari.py`). + +2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. `utils_atari.py`). + +3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. `config_atari.yaml`). + +## Running the Examples + +You can execute the A3C algorithm on Atari environments by running the following command: + +```bash +python a3c_atari.py +``` From ecbec8b8e8337ec9aa51d92782d26e5eb4ae50e4 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Mon, 7 Jul 2025 01:23:20 +0530 Subject: [PATCH 3/5] log only worker-0 stats --- sota-implementations/a3c/a3c_atari.py | 5 ++++- sota-implementations/a3c/config_atari.yaml | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sota-implementations/a3c/a3c_atari.py b/sota-implementations/a3c/a3c_atari.py index f4ed4f198f6..931c9e14b25 100644 --- a/sota-implementations/a3c/a3c_atari.py +++ b/sota-implementations/a3c/a3c_atari.py @@ -181,7 +181,10 @@ def run(self): metrics_to_log[f"train/{key}"] = value.item() metrics_to_log["train/lr"] = lr * alpha - if self.logger: + + # Logging only on the first worker in the dashboard. + # Alternatively, you can use a distributed logger, or aggregate metrics from all workers. 
+ if self.logger and self.name == "worker_0": for key, value in metrics_to_log.items(): self.logger.log_scalar(key, value, collected_frames) collector.shutdown() diff --git a/sota-implementations/a3c/config_atari.yaml b/sota-implementations/a3c/config_atari.yaml index a4dbb0def31..3950bb4aa4a 100644 --- a/sota-implementations/a3c/config_atari.yaml +++ b/sota-implementations/a3c/config_atari.yaml @@ -7,14 +7,14 @@ env: # collector collector: frames_per_batch: 800 - total_frames: 40_000_000 + total_frames: 40_000_00 # logger logger: backend: wandb - project_name: torchrl_example_a2c + project_name: torchrl_example_a3c group_name: null - exp_name: Atari_Schulman17 + exp_name: a3c_atari_training test_interval: 40_000_000 num_test_episodes: 3 video: False From 72eea77229a4d44f68794e1a298c75fa10c9918e Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Mon, 7 Jul 2025 01:27:30 +0530 Subject: [PATCH 4/5] Add linux test script file, and sota-check for a3c --- .../unittest/linux_sota/scripts/test_sota.py | 8 ++++++ sota-check/run_a3c_atari.sh | 27 +++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 sota-check/run_a3c_atari.sh diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index f3513bcfca2..b6c3bbec2b0 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -15,6 +15,14 @@ ), "Composite LP must be set to False. Run this test with COMPOSITE_LP_AGGREGATE=0" commands = { + "a3c_atari": """python sota-implementations/a3c/a3c_atari.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=20 \ + logger.backend= \ + env.backend=gym \ + multiprocessing.num_workers=4 +""", "dt": """python sota-implementations/decision_transformer/dt.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ diff --git a/sota-check/run_a3c_atari.sh b/sota-check/run_a3c_atari.sh new file mode 100644 index 00000000000..68b21563f45 --- /dev/null +++ b/sota-check/run_a3c_atari.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +#SBATCH --job-name=a2c_atari +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/a2c_atari_%j.txt +#SBATCH --error=slurm_errors/a2c_atari_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="a3c_atari" + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/a3c/a3c_atari.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+  echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+  echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
+fi

From d95de87db6f1ee4f5cc4ed985a84ae70d443ebea Mon Sep 17 00:00:00 2001
From: Simeet Nayan
Date: Mon, 7 Jul 2025 01:30:47 +0530
Subject: [PATCH 5/5] modify sota-check a3c

---
 sota-check/run_a3c_atari.sh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sota-check/run_a3c_atari.sh b/sota-check/run_a3c_atari.sh
index 68b21563f45..061d17c46eb 100644
--- a/sota-check/run_a3c_atari.sh
+++ b/sota-check/run_a3c_atari.sh
@@ -1,11 +1,11 @@
 #!/bin/bash
 
-#SBATCH --job-name=a2c_atari
+#SBATCH --job-name=a3c_atari
 #SBATCH --ntasks=32
 #SBATCH --cpus-per-task=1
 #SBATCH --gres=gpu:1
-#SBATCH --output=slurm_logs/a2c_atari_%j.txt
+#SBATCH --output=slurm_logs/a3c_atari_%j.txt
-#SBATCH --error=slurm_errors/a2c_atari_%j.txt
+#SBATCH --error=slurm_errors/a3c_atari_%j.txt
 
 current_commit=$(git rev-parse --short HEAD)
 project_name="torchrl-example-check-$current_commit"
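
The core mechanic this series adds in `a3c_atari.py` is the pairing of a shared-memory Adam optimizer (`SharedAdam`) with workers that compute gradients on a private copy of the model, point the global parameters' `.grad` at those local gradients, and then step the shared optimizer. Below is a minimal, self-contained sketch of that update pattern only; the toy `nn.Linear` model, random mini-batch, MSE loss, and two-process setup are illustrative assumptions, not code from this patch (which uses the TorchRL actor/critic modules, `A2CLoss`, and `SyncDataCollector`).

```python
# Sketch of the shared-gradient update pattern used in a3c_atari.py.
# Assumes Linux and the default "fork" start method; model and loss are toy stand-ins.
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.nn as nn


class SharedAdam(torch.optim.Adam):
    """Adam whose per-parameter state lives in shared memory, as in the patch."""

    def __init__(self, params, **kwargs):
        super().__init__(params, **kwargs)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = torch.zeros(1)
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)
                for t in state.values():
                    t.share_memory_()


def worker(global_model, optimizer):
    local_model = deepcopy(global_model)            # worker-private copy
    x, y = torch.randn(32, 4), torch.randn(32, 2)   # hypothetical mini-batch
    loss = nn.functional.mse_loss(local_model(x), y)
    loss.backward()                                 # gradients land on the local copy
    # Hand the local gradients to the global parameters (same tensors, no copy),
    # clip them, then step the shared optimizer so the update hits shared memory.
    for lp, gp in zip(local_model.parameters(), global_model.parameters()):
        gp._grad = lp.grad
    torch.nn.utils.clip_grad_norm_(local_model.parameters(), max_norm=40.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)


if __name__ == "__main__":
    global_model = nn.Linear(4, 2)                  # stand-in for the shared actor/critic
    global_model.share_memory()
    opt = SharedAdam(global_model.parameters(), lr=1e-4)
    procs = [mp.Process(target=worker, args=(global_model, opt)) for _ in range(2)]
    [p.start() for p in procs]
    [p.join() for p in procs]
```

In the patch itself these same steps live in `A3CWorker.update`, with `A2CLoss` supplying the loss terms and `cfg.optim.max_grad_norm` (40.0 in `config_atari.yaml`) as the clipping threshold.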