diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index b4a78ab6a02..f350e49b866 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -32,6 +32,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.robohive import RoboHiveEnv parser = ArgumentParser() parser.add_argument( @@ -80,6 +81,16 @@ default="ALE/Pong-v5", help="Gym environment to be run.", ) +LIBS = { + "gym": GymEnv, + "robohive": RoboHiveEnv, +} +parser.add_argument( + "--lib", + default="gym", + help="Lib backend", + choices=list(LIBS.keys()), +) if __name__ == "__main__": args = parser.parse_args() num_workers = args.num_workers @@ -89,7 +100,8 @@ device_count = torch.cuda.device_count() - make_env = EnvCreator(lambda: GymEnv(args.env)) + lib = LIBS[args.lib] + make_env = EnvCreator(lambda: lib(args.env)) if args.worker_parallelism == "collector" or num_workers == 1: action_spec = make_env().action_spec else: diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 8bf1fcf004f..461363a14ef 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -27,6 +27,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.robohive import RoboHiveEnv parser = ArgumentParser() parser.add_argument( @@ -63,6 +64,16 @@ default="ALE/Pong-v5", help="Gym environment to be run.", ) +LIBS = { + "gym": GymEnv, + "robohive": RoboHiveEnv, +} +parser.add_argument( + "--lib", + default="gym", + help="Lib backend", + choices=list(LIBS.keys()), +) if __name__ == "__main__": args = parser.parse_args() num_workers = args.num_workers diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 9f62b86f878..40fdbf1f1d4 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -32,6 +32,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.robohive import RoboHiveEnv parser = ArgumentParser() parser.add_argument( @@ -75,6 +76,16 @@ default="ALE/Pong-v5", help="Gym environment to be run.", ) +LIBS = { + "gym": GymEnv, + "robohive": RoboHiveEnv, +} +parser.add_argument( + "--lib", + default="gym", + help="Lib backend", + choices=list(LIBS.keys()), +) if __name__ == "__main__": args = parser.parse_args() num_workers = args.num_workers diff --git a/examples/ppo/config.yaml b/examples/ppo/config.yaml deleted file mode 100644 index d7840906c92..00000000000 --- a/examples/ppo/config.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# task and env -defaults: - - hydra/job_logging: disabled - -env: - env_name: PongNoFrameskip-v4 - env_task: "" - env_library: gym - frame_skip: 4 - num_envs: 8 - noop: 1 - reward_scaling: 1.0 - from_pixels: True - n_samples_stats: 1000 - device: cuda:0 - -# collector -collector: - frames_per_batch: 4096 - total_frames: 40_000_000 - collector_device: cuda:0 # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: 
wandb - exp_name: ppo_pong_gym - log_interval: 10000 - -# Optim -optim: - device: cuda:0 - lr: 2.5e-4 - weight_decay: 0.0 - lr_scheduler: True - -# loss -loss: - gamma: 0.99 - mini_batch_size: 1024 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.1 - critic_coef: 0.5 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/ppo/config_atari.yaml b/examples/ppo/config_atari.yaml new file mode 100644 index 00000000000..ec22dcd71af --- /dev/null +++ b/examples/ppo/config_atari.yaml @@ -0,0 +1,35 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + +# collector +collector: + frames_per_batch: 4096 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 2.5e-4 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 1024 + ppo_epochs: 3 + gae_lambda: 0.95 + clip_epsilon: 0.1 + anneal_clip_epsilon: True + critic_coef: 1.0 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/ppo/config_example2.yaml b/examples/ppo/config_example2.yaml deleted file mode 100644 index 9d06c8a82ee..00000000000 --- a/examples/ppo/config_example2.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# task and env -env: - env_name: HalfCheetah-v4 - env_task: "" - env_library: gym - frame_skip: 1 - num_envs: 1 - noop: 1 - reward_scaling: 1.0 - from_pixels: False - n_samples_stats: 3 - device: cuda - -# collector -collector: - frames_per_batch: 2048 - total_frames: 1_000_000 - collector_device: cuda # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: ppo_halfcheetah_gym - log_interval: 10000 - -# Optim -optim: - device: cuda - lr: 3e-4 - weight_decay: 1e-4 - lr_scheduler: False - -# loss -loss: - gamma: 0.99 - mini_batch_size: 64 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.2 - critic_coef: 0.5 - entropy_coef: 0.0 - loss_critic_type: l2 diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml new file mode 100644 index 00000000000..cb84b4bea23 --- /dev/null +++ b/examples/ppo/config_mujoco.yaml @@ -0,0 +1,32 @@ +# task and env +env: + env_name: HalfCheetah-v3 + +# collector +collector: + frames_per_batch: 2048 + total_frames: 1_000_000 + +# logger +logger: + backend: wandb + exp_name: Mujoco_Schulman17 + test_interval: 1_000_000 + num_test_episodes: 5 + +# Optim +optim: + lr: 3e-4 + weight_decay: 0.0 + anneal_lr: False + +# loss +loss: + gamma: 0.99 + mini_batch_size: 64 + ppo_epochs: 10 + gae_lambda: 0.95 + clip_epsilon: 0.2 + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 diff --git a/examples/ppo/config_myo.yaml b/examples/ppo/config_myo.yaml new file mode 100644 index 00000000000..4da1b03809d --- /dev/null +++ b/examples/ppo/config_myo.yaml @@ -0,0 +1,33 @@ +# task and env +env: + env_name: myoHandReachRandom-v0 + +# collector +collector: + frames_per_batch: 2048 + total_frames: 1_000_000 + num_envs: 1 + +# logger +logger: + backend: wandb + exp_name: myo_hand_reach + test_interval: 1_000_000 + num_test_episodes: 5 + +# Optim +optim: + lr: 3e-4 + weight_decay: 0.0 + anneal_lr: False + +# loss +loss: + gamma: 0.99 + mini_batch_size: 64 + ppo_epochs: 10 + gae_lambda: 0.95 + clip_epsilon: 0.2 + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py deleted file mode 100644 index 7f532bc0c4d..00000000000 --- a/examples/ppo/ppo.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""PPO Example. - -This is a self-contained example of a PPO training script. - -Both state and pixel-based environments are supported. - -The helper functions are coded in the utils.py associated with this script. -""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import torch - import tqdm - from tensordict import TensorDict - from torchrl.envs.utils import ExplorationType, set_exploration_type - from utils import ( - make_collector, - make_data_buffer, - make_logger, - make_loss, - make_optim, - make_ppo_models, - make_test_env, - ) - - # Correct for frame_skip - cfg.collector.total_frames = cfg.collector.total_frames // cfg.env.frame_skip - cfg.collector.frames_per_batch = ( - cfg.collector.frames_per_batch // cfg.env.frame_skip - ) - mini_batch_size = cfg.loss.mini_batch_size = ( - cfg.loss.mini_batch_size // cfg.env.frame_skip - ) - - model_device = cfg.optim.device - actor, critic, critic_head = make_ppo_models(cfg) - - collector, state_dict = make_collector(cfg, policy=actor) - data_buffer = make_data_buffer(cfg) - loss_module, adv_module = make_loss( - cfg.loss, - actor_network=actor, - value_network=critic, - value_head=critic_head, - ) - optim = make_optim(cfg.optim, loss_module) - - batch_size = cfg.collector.total_frames * cfg.env.num_envs - num_mini_batches = batch_size // mini_batch_size - total_network_updates = ( - (cfg.collector.total_frames // batch_size) - * cfg.loss.ppo_epochs - * num_mini_batches - ) - - scheduler = None - if cfg.optim.lr_scheduler: - scheduler = torch.optim.lr_scheduler.LinearLR( - optim, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1 - ) - - logger = None - if cfg.logger.backend: - logger = make_logger(cfg.logger) - test_env = make_test_env(cfg.env, state_dict) - record_interval = cfg.logger.log_interval - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - collected_frames = 0 - - # Main loop - r0 = None - l0 = None - frame_skip = cfg.env.frame_skip - ppo_epochs = cfg.loss.ppo_epochs - total_done = 0 - for data in collector: - - frames_in_batch = data.numel() - total_done += data.get(("next", "done")).sum() - collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) - - # Log end-of-episode accumulated rewards for training - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if logger is not None and len(episode_rewards) > 0: - logger.log_scalar( - "reward_training", episode_rewards.mean().item(), collected_frames - ) - - losses = TensorDict( - {}, batch_size=[ppo_epochs, -(frames_in_batch // -mini_batch_size)] - ) - for j in range(ppo_epochs): - # Compute GAE - with torch.no_grad(): - data = adv_module(data.to(model_device)).cpu() - - data_reshape = data.reshape(-1) - # Update the data buffer - data_buffer.extend(data_reshape) - - for i, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(model_device) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, i] = loss.detach() - - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=0.5 - ) - losses[j, i]["grad_norm"] = grad_norm - - optim.step() - if scheduler is not None: - scheduler.step() - 
optim.zero_grad() - - # Logging - if r0 is None: - r0 = data["next", "reward"].mean().item() - if l0 is None: - l0 = loss_sum.item() - pbar.set_description( - f"loss: {loss_sum.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['next', 'reward'].mean(): 4.4f} (init={r0: 4.4f})" - ) - if i + 1 != -(frames_in_batch // -mini_batch_size): - print( - f"Should have had {- (frames_in_batch // -mini_batch_size)} iters but had {i}." - ) - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) - if logger is not None: - for key, value in losses.items(): - logger.log_scalar(key, value.item(), collected_frames) - logger.log_scalar("total_done", total_done, collected_frames) - - collector.update_policy_weights_() - - # Test current policy - if ( - logger is not None - and (collected_frames - frames_in_batch) // record_interval - < collected_frames // record_interval - ): - - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - test_env.eval() - actor.eval() - # Generate a complete episode - td_test = test_env.rollout( - policy=actor, - max_steps=10_000_000, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - ).clone() - logger.log_scalar( - "reward_testing", - td_test["next", "reward"].sum().item(), - collected_frames, - ) - actor.train() - del td_test - - -if __name__ == "__main__": - main() diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py new file mode 100644 index 00000000000..5a3783cce78 --- /dev/null +++ b/examples/ppo/ppo_atari.py @@ -0,0 +1,204 @@ +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import numpy as np + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_atari import make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.is_available() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( + actor.to(device), + critic.to(device), + critic_head.to(device), + ) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + 
lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + for data in collector: + + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Train loging + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) + + # Apply episodic end of life + data["done"].copy_(data["end_of_life"]) + data["next", "done"].copy_(data["next", "end_of_life"]) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + for j in range(cfg.loss.ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for i, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1 - (num_network_updates / total_network_updates) + if cfg.optim.anneal_lr: + for g in optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + if cfg.loss.anneal_clip_epsilon: + loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, i] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + logger.log_scalar(key, value.item(), collected_frames) + logger.log_scalar("lr", alpha * cfg.optim.lr, collected_frames) + logger.log_scalar( + "clip_epsilon", alpha * cfg.loss.clip_epsilon, collected_frames + ) + + # Test logging + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if (collected_frames - frames_in_batch) // test_interval < ( + collected_frames // test_interval + ): + actor.eval() + test_rewards = [] + for _ in range(cfg.logger.num_test_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", 
"done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + logger.log_scalar("reward_test", test_rewards.mean(), collected_frames) + actor.train() + + collector.update_policy_weights_() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_atari_pong.png b/examples/ppo/ppo_atari_pong.png deleted file mode 100644 index 639545f29e4..00000000000 Binary files a/examples/ppo/ppo_atari_pong.png and /dev/null differ diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py new file mode 100644 index 00000000000..67d31605b7c --- /dev/null +++ b/examples/ppo/ppo_mujoco.py @@ -0,0 +1,179 @@ +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on MuJoCo Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import numpy as np + import torch.optim + import tqdm + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_mujoco import make_env, make_ppo_models + + # Define paper hyperparameters + device = "cpu" if not torch.cuda.is_available() else "cuda" + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size + total_network_updates = ( + (cfg.collector.total_frames // cfg.collector.frames_per_batch) + * cfg.loss.ppo_epochs + * num_mini_batches + ) + + # Create models (check utils_mujoco.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device), + sampler=sampler, + batch_size=cfg.loss.mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + + # Create logger + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name) + + # Create test environment + test_env = make_env(cfg.env.env_name, device) + test_env.eval() + + # Main loop + collected_frames 
= 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + for data in collector: + + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Train loging + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + for j in range(cfg.loss.ppo_epochs): + + for i, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for g in actor_optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + for g in critic_optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, i] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + logger.log_scalar(key, value.item(), collected_frames) + + # Test logging + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if (collected_frames - frames_in_batch) // cfg.logger.test_interval < ( + collected_frames // cfg.logger.test_interval + ): + actor.eval() + test_rewards = [] + for _ in range(cfg.logger.num_test_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + logger.log_scalar("reward_test", test_rewards.mean(), collected_frames) + actor.train() + + collector.update_policy_weights_() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_mujoco_halfcheetah.png b/examples/ppo/ppo_mujoco_halfcheetah.png deleted file mode 100644 index f168a5d40f3..00000000000 Binary files a/examples/ppo/ppo_mujoco_halfcheetah.png and /dev/null differ diff --git a/examples/ppo/ppo_myo.py b/examples/ppo/ppo_myo.py new file mode 100644 index 00000000000..d460890b837 --- /dev/null +++ b/examples/ppo/ppo_myo.py @@ -0,0 +1,191 @@ +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on MuJoCo Environments. 
+""" +import hydra + +from torchrl.collectors import MultiSyncDataCollector + + +@hydra.main(config_path=".", config_name="config_myo", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import numpy as np + import torch.optim + import tqdm + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_myo import make_env, make_ppo_models + + # Define paper hyperparameters + device = "cpu" if not torch.cuda.is_available() else "cuda" + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size + total_network_updates = ( + (cfg.collector.total_frames // cfg.collector.frames_per_batch) + * cfg.loss.ppo_epochs + * num_mini_batches + ) + + # Create models (check utils_mujoco.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + if cfg.collector.num_envs == 1: + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + else: + collector = MultiSyncDataCollector( + create_env_fn=cfg.collector.num_envs * [make_env(cfg.env.env_name, device)], + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device), + sampler=sampler, + batch_size=cfg.loss.mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + + # Create logger + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name) + + # Create test environment + test_env = make_env(cfg.env.env_name, device) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + for data in collector: + + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Train loging + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) 
+ + for j in range(cfg.loss.ppo_epochs): + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + # Update the data buffer + data_buffer.empty() + data_buffer.extend(data_reshape) + + for i, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for g in actor_optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + for g in critic_optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, i] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + logger.log_scalar(key, value.item(), collected_frames) + + # Test logging + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if (collected_frames - frames_in_batch) // cfg.logger.test_interval < ( + collected_frames // cfg.logger.test_interval + ): + actor.eval() + test_rewards = [] + for _ in range(cfg.logger.num_test_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + logger.log_scalar("reward_test", test_rewards.mean(), collected_frames) + actor.train() + + collector.update_policy_weights_() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/training_curves.md b/examples/ppo/training_curves.md deleted file mode 100644 index d9f99eadb42..00000000000 --- a/examples/ppo/training_curves.md +++ /dev/null @@ -1,13 +0,0 @@ -# PPO Example Results - -## Atari Pong Environment - -We tested the Proximal Policy Optimization (PPO) algorithm on the Atari Pong environment. The hyperparameters used for the training are specified in the config.yaml file and are the same as those used in the original PPO paper (https://arxiv.org/abs/1707.06347). - -![ppo_atari_pong.png](ppo_atari_pong.png) - -## MuJoCo HalfCheetah Environment - -Additionally, we also tested the PPO algorithm on the MuJoCo HalfCheetah environment. The hyperparameters used for the training are specified in the config_example2.yaml file and are also the same as those used in the original PPO paper. However, this implementation uses a shared policy-value architecture. - -![ppo_mujoco_halfcheetah.png](ppo_mujoco_halfcheetah.png) diff --git a/examples/ppo/utils.py b/examples/ppo/utils.py deleted file mode 100644 index 977d8e20b64..00000000000 --- a/examples/ppo/utils.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch.nn -import torch.optim -from tensordict.nn import NormalParamExtractor, TensorDictModule - -from torchrl.collectors import SyncDataCollector -from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement - -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - CatTensors, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ObservationNorm, - ParallelEnv, - Resize, - RewardScaling, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) -from torchrl.objectives import ClipPPOLoss -from torchrl.objectives.value.advantages import GAE -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.envs import LIBS - - -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - - -# ==================================================================== -# Environment utils -# ----------------- - - -def make_base_env(env_cfg, from_pixels=None): - env_library = LIBS[env_cfg.env_library] - env_kwargs = { - "env_name": env_cfg.env_name, - "frame_skip": env_cfg.frame_skip, - "from_pixels": env_cfg.from_pixels - if from_pixels is None - else from_pixels, # for rendering - "pixels_only": False, - "device": env_cfg.device, - } - if env_library is DMControlEnv: - env_task = env_cfg.env_task - env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) - return env - - -def make_transformed_env(base_env, env_cfg): - if env_cfg.noop > 1: - base_env = TransformedEnv(env=base_env, transform=NoopResetEnv(env_cfg.noop)) - from_pixels = env_cfg.from_pixels - if from_pixels: - return make_transformed_env_pixels(base_env, env_cfg) - else: - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_pixels(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - env.append_transform(RewardScaling(0.0, reward_scaling)) - - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - - obs_norm = ObservationNorm(in_keys=["pixels"], standard_normal=True) - env.append_transform(obs_norm) - - env.append_transform(DoubleToFloat()) - return env - - -def make_transformed_env_states(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - - env.append_transform(RewardScaling(0.0, reward_scaling)) - - # we concatenate all the state vectors - # even if there is a single tensor, it'll be renamed in "observation_vector" - selected_keys = [ - key for key in env.observation_spec.keys(True, True) if key != "pixels" - ] - out_key = "observation_vector" - env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - 
env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - # obs_norm = ObservationNorm(in_keys=[out_key]) - # env.append_transform(obs_norm) - - env.append_transform(DoubleToFloat()) - return env - - -def make_parallel_env(env_cfg, state_dict): - num_envs = env_cfg.num_envs - env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg - ) - init_stats(env, 3, env_cfg.from_pixels) - env.load_state_dict(state_dict, strict=False) - return env - - -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(env, env_cfg.n_samples_stats, env_cfg.from_pixels) - state_dict = env.state_dict() - for key in list(state_dict.keys()): - if key.endswith("loc") or key.endswith("scale"): - continue - del state_dict[key] - return state_dict - - -def init_stats(env, n_samples_stats, from_pixels): - for t in env.transform: - if isinstance(t, ObservationNorm): - if from_pixels: - t.init_stats( - n_samples_stats, - cat_dim=-4, - reduce_dim=tuple( - -i for i in range(1, len(t.parent.batch_size) + 5) - ), - keep_dims=(-1, -2, -3), - ) - else: - t.init_stats(n_samples_stats) - - -def make_test_env(env_cfg, state_dict): - env_cfg.num_envs = 1 - env = make_parallel_env(env_cfg, state_dict=state_dict) - return env - - -# ==================================================================== -# Collector and replay buffer -# --------------------------- - - -def make_collector(cfg, policy): - env_cfg = cfg.env - collector_cfg = cfg.collector - collector_class = SyncDataCollector - state_dict = get_stats(env_cfg) - collector = collector_class( - make_parallel_env(env_cfg, state_dict=state_dict), - policy, - frames_per_batch=collector_cfg.frames_per_batch, - total_frames=collector_cfg.total_frames, - device=collector_cfg.collector_device, - storing_device="cpu", - max_frames_per_traj=collector_cfg.max_frames_per_traj, - ) - return collector, state_dict - - -def make_data_buffer(cfg): - cfg_collector = cfg.collector - cfg_loss = cfg.loss - sampler = SamplerWithoutReplacement() - return TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg_collector.frames_per_batch), - sampler=sampler, - batch_size=cfg_loss.mini_batch_size, - ) - - -# ==================================================================== -# Model -# ----- -# -# We give one version of the model for learning from pixels, and one for state. -# TorchRL comes in handy at this point, as the high-level interactions with -# these models is unchanged, regardless of the modality. 
- - -def make_ppo_models(cfg): - - env_cfg = cfg.env - from_pixels = env_cfg.from_pixels - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(proof_environment, 3, env_cfg.from_pixels) - - if not from_pixels: - # we must initialize the observation norm transform - # init_stats( - # proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels - # ) - common_module, policy_module, value_module = make_ppo_modules_state( - proof_environment - ) - else: - common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ).to(cfg.optim.device) - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() - - return actor, critic, critic_head - - -def make_ppo_modules_state(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["observation_vector"].shape - - # Define distribution class and kwargs - continuous_actions = False - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - continuous_actions = True - num_outputs = proof_environment.action_spec.shape[-1] * 2 - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.low, - "max": proof_environment.action_spec.space.high, - "tanh_loc": False, - } - - # Define input keys - in_keys = ["observation_vector"] - shared_features_size = 256 - - # Define a shared Module and TensorDictModule - common_mlp = MLP( - in_features=input_shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=shared_features_size, - num_cells=[64, 64], - ) - common_module = TensorDictModule( - module=common_mlp, - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=shared_features_size, out_features=num_outputs, num_cells=[] - ) - if continuous_actions: - policy_net = torch.nn.Sequential( - policy_net, NormalParamExtractor(scale_lb=1e-2) - ) - - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["loc", "scale"] if continuous_actions else ["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["loc", "scale"] if continuous_actions else ["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP(in_features=shared_features_size, out_features=1, num_cells=[]) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -def make_ppo_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if 
isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.low, - "max": proof_environment.action_spec.space.high, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn = ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - activation_class=torch.nn.ReLU, - num_cells=[256], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - # safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - activation_class=torch.nn.ReLU, - in_features=common_mlp_output.shape[-1], - out_features=1, - num_cells=[256], - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -# ==================================================================== -# PPO Loss -# --------- - - -def make_advantage_module(loss_cfg, value_network): - advantage_module = GAE( - gamma=loss_cfg.gamma, - lmbda=loss_cfg.gae_lamdda, - value_network=value_network, - average_gae=True, - ) - return advantage_module - - -def make_loss(loss_cfg, actor_network, value_network, value_head): - advantage_module = make_advantage_module(loss_cfg, value_network) - loss_module = ClipPPOLoss( - actor=actor_network, - critic=value_head, - clip_epsilon=loss_cfg.clip_epsilon, - loss_critic_type=loss_cfg.loss_critic_type, - entropy_coef=loss_cfg.entropy_coef, - critic_coef=loss_cfg.critic_coef, - normalize_advantage=True, - ) - return loss_module, advantage_module - - -def make_optim(optim_cfg, loss_module): - optim = torch.optim.Adam( - loss_module.parameters(), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - ) - return optim - - -# ==================================================================== -# Logging and recording -# --------------------- - - -def make_logger(logger_cfg): - exp_name = generate_exp_name("PPO", logger_cfg.exp_name) - logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="ppo", experiment_name=exp_name) - return logger diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py new file mode 100644 index 00000000000..d0881f34f2a --- /dev/null +++ 
b/examples/ppo/utils_atari.py @@ -0,0 +1,212 @@ +import gym +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + default_info_dict_reader, + DoubleToFloat, + EnvCreator, + ExplorationType, + GrayScale, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. It helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + + def step(self, action): + obs, rew, done, info = self.env.step(action) + lives = self.env.unwrapped.ale.lives() + info["end_of_life"] = False + if (lives < self.lives) or done: + info["end_of_life"] = True + self.lives = lives + return obs, rew, done, info + + def reset(self, **kwargs): + reset_data = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() + return reset_data + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = gym.make(env_name) + if not is_test: + env = EpisodicLifeEnv(env) + env = GymWrapper( + env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) + return env + + +def make_parallel_env(env_name, device, is_test=False): + num_envs = 8 + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + 
activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py new file mode 100644 index 00000000000..82ee2e71747 --- /dev/null +++ b/examples/ppo/utils_mujoco.py @@ -0,0 +1,114 @@ +import gym +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu"): + env = gym.make(env_name) + env = GymWrapper(env, device=device) + env = TransformedEnv(env) + env.append_transform(RewardSum()) + env.append_transform(VecNorm(in_keys=["observation"])) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# 
-------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic diff --git a/examples/ppo/utils_myo.py b/examples/ppo/utils_myo.py new file mode 100644 index 00000000000..64947c0c487 --- /dev/null +++ b/examples/ppo/utils_myo.py @@ -0,0 +1,116 @@ +import gym +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + TransformedEnv, + VecNorm, CatTensors, ExcludeTransform, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.envs.libs.robohive import RoboHiveEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu"): + env = RoboHiveEnv(env_name, include_info=False, device=device) + env = TransformedEnv(env) + env.append_transform(RewardSum()) + env.append_transform(CatTensors(["qpos", "qvel", "tip_pos", "reach_err"], out_key="observation")) + env.append_transform(ExcludeTransform("time", "state", "rwd_dense", "rwd_dict", "visual_dict")) + env.append_transform(VecNorm(in_keys=["observation"])) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + 
env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 78eefa6d443..473df10060b 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -82,6 +82,12 @@ class RoboHiveEnv(GymEnv): else: CURR_DIR = None + def __init__(self, env_name, include_info: bool=True, **kwargs): + self.include_info = include_info + kwargs["env_name"] = env_name + self._set_gym_args(kwargs) + super().__init__(**kwargs) + @classmethod def register_envs(cls): @@ -304,19 +310,20 @@ def read_obs(self, observation): return super().read_obs(out) def read_info(self, info, tensordict_out): - out = {} - for key, value in info.items(): - if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"): - continue - if isinstance(value, dict): - value = {key: _val for key, _val in value.items() if _val is not None} - value = make_tensordict(value, batch_size=[]) - if value is not None: - out[key] = value - tensordict_out.update(out) - tensordict_out.update( - tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x) - ) + if self.include_info: + out = {} + for key, value in info.items(): + if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"): + continue + if 
isinstance(value, dict): + value = {key: _val for key, _val in value.items() if _val is not None} + value = make_tensordict(value, batch_size=[]) + if value is not None: + out[key] = value + tensordict_out.update(out) + tensordict_out.update( + tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x) + ) return tensordict_out def to(self, *args, **kwargs):
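
The three single-machine collector examples (generic.py, rpc.py, sync.py) now share the same backend-selection pattern. Below is a condensed, self-contained sketch of that pattern; the imports and argument names are taken from the files above, and the robohive backend assumes RoboHive is installed:

from argparse import ArgumentParser

from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.robohive import RoboHiveEnv

# Map the --lib choice to an environment class, as in the examples above.
LIBS = {
    "gym": GymEnv,
    "robohive": RoboHiveEnv,
}

parser = ArgumentParser()
parser.add_argument("--env", default="ALE/Pong-v5", help="Environment to be run.")
parser.add_argument("--lib", default="gym", help="Lib backend", choices=list(LIBS.keys()))

if __name__ == "__main__":
    args = parser.parse_args()
    lib = LIBS[args.lib]
    # EnvCreator defers construction so each collector worker builds its own copy.
    make_env = EnvCreator(lambda: lib(args.env))
    print(make_env().action_spec)

Invoked as, for example: python generic.py --lib robohive --env myoHandReachRandom-v0 (the env id being the one used in config_myo.yaml).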
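
ppo_atari.py anneals both the learning rate and the clip epsilon linearly over the total number of network updates, which it derives from the collector and loss settings. A standalone sketch of that schedule, plugging in the config_atari.yaml values and the frame_skip of 4 hard-coded in the script:

# Standalone sketch of the linear learning-rate / clip-epsilon annealing in
# ppo_atari.py, using the config_atari.yaml values (frame_skip = 4 as in the script).
total_frames = 40_000_000 // 4       # 10_000_000
frames_per_batch = 4096 // 4         # 1024
mini_batch_size = 1024 // 4          # 256
ppo_epochs = 3
lr, clip_epsilon = 2.5e-4, 0.1

num_mini_batches = frames_per_batch // mini_batch_size              # 4
total_network_updates = (
    (total_frames // frames_per_batch) * ppo_epochs * num_mini_batches
)                                                                   # 117_180


def annealed(num_network_updates):
    """Return (lr, clip_epsilon) after num_network_updates optimizer steps."""
    alpha = 1 - (num_network_updates / total_network_updates)
    return lr * alpha, clip_epsilon * alpha


print(annealed(0))                           # (0.00025, 0.1)
print(annealed(total_network_updates // 2))  # (0.000125, 0.05)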
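
ppo_atari.py also pairs the EpisodicLifeEnv wrapper in utils_atari.py, which writes an end_of_life flag into the Gym info dict, with a done-override applied before GAE, so value bootstrapping stops at each lost life while the environment only resets on true game over. The override is just two in-place copies on the collected batch; the sketch below runs them on a fake, hand-built batch (an assumption for illustration) using the key names from the script:

import torch
from tensordict import TensorDict

# Fake collected batch: 5 steps, one life lost at step 1, no true game over.
data = TensorDict(
    {
        "done": torch.zeros(5, 1, dtype=torch.bool),
        "end_of_life": torch.tensor([[False], [False], [True], [False], [False]]),
        "next": TensorDict(
            {
                "done": torch.zeros(5, 1, dtype=torch.bool),
                "end_of_life": torch.tensor([[False], [True], [False], [False], [False]]),
            },
            batch_size=[5],
        ),
    },
    batch_size=[5],
)

# The two lines used in ppo_atari.py before computing GAE:
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])
print(data["next", "done"].squeeze(-1))  # tensor([False,  True, False, False, False])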
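
On the library side, RoboHiveEnv gains an explicit constructor in which env_name is positional and include_info controls the info handling: with include_info=False, read_info returns the tensordict untouched, which is how utils_myo.py keeps the rollout data down to the keys it concatenates into "observation". A minimal usage sketch, assuming RoboHive/MyoSuite is installed:

from torchrl.envs.libs.robohive import RoboHiveEnv

# include_info defaults to True, preserving the previous behaviour; False skips
# merging the RoboHive info dict into the output tensordict.
env = RoboHiveEnv("myoHandReachRandom-v0", include_info=False)
td = env.rollout(3)
print(td)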