diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 26979e2ae96..6ac6e001cb4 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -138,7 +138,7 @@ CQL CQLLoss DT ----- +-- .. autosummary:: :toctree: generated/ @@ -148,7 +148,7 @@ DT OnlineDTLoss TD3 ----- +--- .. autosummary:: :toctree: generated/ @@ -156,6 +156,15 @@ TD3 TD3Loss +TQC +--- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TQCLoss + PPO --- diff --git a/examples/tqc/config.yaml b/examples/tqc/config.yaml new file mode 100755 index 00000000000..e5e634a2107 --- /dev/null +++ b/examples/tqc/config.yaml @@ -0,0 +1,52 @@ +# environment and task +env: + name: MountainCarContinuous-v0 + task: "" + exp_name: ${env.name}_TQC + library: gymnasium + max_episode_steps: 1_000 + seed: 42 + +# collector +collector: + total_frames: 1_000_000 + init_random_frames: 25_000 + frames_per_batch: 1_000 + collector_device: cpu + env_per_collector: 1 + reset_at_each_iter: False + +# replay buffer +replay_buffer: + size: 1_000_000 + prb: False # use prioritized experience replay + scratch_dir: + +# optim +optim: + utd_ratio: 1.0 + gamma: 0.99 + lr: 3.0e-4 + weight_decay: 0.0 + batch_size: 256 + target_update_polyak: 0.995 + alpha_init: 1.0 + adam_eps: 1.0e-8 + +# network +network: + actor_hidden_sizes: [256, 256] + critic_hidden_sizes: [512, 512, 512] + n_quantiles: 25 + n_nets: 5 + top_quantiles_to_drop_per_net: 2 + activation: relu + default_policy_scale: 1.0 + scale_lb: 0.1 + device: cuda + +# logging +logger: + backend: wandb + mode: online + eval_iter: 25_000 diff --git a/examples/tqc/tqc.py b/examples/tqc/tqc.py new file mode 100644 index 00000000000..fb3199fd10c --- /dev/null +++ b/examples/tqc/tqc.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""TQC Example. + +This is a simple self-contained example of a TQC training script. + +The implementation is based on the implementation of SAC in the examples +directory. TQC was introduced in + +"Controlling Overestimation Bias with Truncated Mixture of Continuous +Distributional Quantile Critics" (Arsenii Kuznetsov, Pavel Shvechikov, +Alexander Grishin, Dmitry Vetrov, 2020) + +Available from https://proceedings.mlr.press/v119/kuznetsov20a.html. + +Oftentimes, we follow the naming conventions used in the original TQC +PyTorch implementation, to facilitate the comparison with the present +implementation. Original PyTorch TQC code is available here: + +https://github.com/SamsungLabs/tqc_pytorch/tree/master + +All hyperparameters are set to the values used in the original +implementation. + +The helper functions are coded in the utils.py associated with this script. 
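+
+As a concrete example of the truncation step: with the config.yaml shipped
+alongside this script (n_nets=5 critic networks, n_quantiles=25 atoms each,
+top_quantiles_to_drop_per_net=2), the 5 * 25 = 125 pooled target quantiles
+are sorted and the largest 2 * 5 = 10 are discarded before the Bellman
+target is computed; this truncation is what controls the overestimation bias.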
+""" + +import time + +import hydra +import numpy as np +import torch +import torch.cuda +import tqdm +from tensordict import TensorDict +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_collector, + make_environment, + make_loss_module, + make_replay_buffer, + make_tqc_agent, + make_tqc_optimizer, +) + + +@hydra.main(version_base="1.1", config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + device = torch.device(cfg.network.device) + + # Create logger + exp_name = generate_exp_name("SAC", cfg.env.exp_name) + logger = None + # TO-DO: Add logging back in before pushing to git repo + # if cfg.logger.backend: + # logger = get_logger( + # logger_type=cfg.logger.backend, + # logger_name="sac_logging/wandb", + # experiment_name=exp_name, + # wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + # ) + + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg) + + # Create agent + model, exploration_policy = make_tqc_agent(cfg, train_env, eval_env, device) + + # Create SAC loss + loss_module, target_net_updater = make_loss_module(cfg, model) + + # Create off-policy collector + collector = make_collector(cfg, train_env, exploration_policy) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + device=device, + ) + + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_tqc_optimizer(cfg, loss_module) + + # Main loop + start_time = time.time() + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + + sampling_start = time.time() + for i, tensordict in enumerate(collector): + + sampling_time = time.time() - sampling_start + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(tensordict.numel()) + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + losses = TensorDict( + {}, + batch_size=[ + num_updates, + ], + ) + for i in range(num_updates): + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample().clone() + + # Compute loss + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_critic"] + alpha_loss = loss_td["loss_alpha"] + + # Update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + losses[i] = loss_td.select( + "loss_actor", "loss_critic", "loss_alpha" + ).detach() + + # Update qnet_target params + target_net_updater.step() + + # 
Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/critic_loss"] = ( + losses.get("loss_critic").mean().item() + ) + metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() + metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() + metrics_to_log["train/alpha"] = loss_td["alpha"].item() + metrics_to_log["train/entropy"] = loss_td["entropy"].item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + + sampling_start = time.time() + + collector.shutdown() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/tqc/utils.py b/examples/tqc/utils.py new file mode 100644 index 00000000000..9a89b9cd314 --- /dev/null +++ b/examples/tqc/utils.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
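+
+"""Factory helpers used by the TQC example script (``tqc.py``).
+
+This module builds the training and evaluation environments, the data
+collector, the replay buffer, the actor and the truncated-quantile critic,
+the TQCLoss with its target-network updater, and the optimizers.
+"""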
+ +import tempfile +from contextlib import nullcontext +from typing import Tuple + +import torch +from tensordict.nn import InteractionType, TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import nn, optim +from torchrl.collectors import SyncDataCollector +from torchrl.data import ( + CompositeSpec, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator +from torchrl.modules.distributions import TanhNormal +from torchrl.objectives import SoftUpdate, TQCLoss +from torchrl.objectives.utils import ( + _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(task, device="cpu"): + with set_gym_backend("gym"): + return GymEnv( + task, + device=device, + ) + + +def apply_env_transforms(env, max_episode_steps=1000): + transformed_env = TransformedEnv( + env, + Compose( + InitTracker(), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda: env_maker(task=cfg.env.name)), + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda: env_maker(task=cfg.env.name)), + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=cfg.collector.collector_device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1_000_000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, +): + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return 
replay_buffer + + +# ==================================================================== +# Model architecture for critic +# ----------------------------- + + +class TQC_Critic(nn.Module): + def __init__(self, cfg): + super().__init__() + self.nets = [] + qvalue_net_kwargs = { + "num_cells": cfg.network.critic_hidden_sizes, + "out_features": cfg.network.n_quantiles, + "activation_class": get_activation(cfg), + } + for i in range(cfg.network.n_nets): + net = MLP(**qvalue_net_kwargs) + self.add_module(f"critic_net_{i}", net) + self.nets.append(net) + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = (torch.cat([*inputs], -1),) + quantiles = torch.stack( + tuple(net(*inputs) for net in self.nets), dim=-2 + ) # batch x n_nets x n_quantiles + return quantiles + + +# ==================================================================== +# Model +# ----- + + +def make_tqc_agent(cfg, train_env, eval_env, device): + """Make TQC agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.actor_hidden_sizes, + "out_features": 2 * action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.low, + "max": action_spec.space.high, + "tanh_loc": False, # can be omitted since this is default value + } + + actor_extractor = NormalParamExtractor( + scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ) + actor_net = nn.Sequential(actor_net, actor_extractor) + + in_keys_actor = in_keys + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=True, + ) + + # Define Critic Network + qvalue_net = TQC_Critic(cfg) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + eval_env.close() + + return model, model[0] + + +# ==================================================================== +# TQC Loss +# -------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create TQC loss + top_quantiles_to_drop = ( + cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets + ) + loss_module = TQCLoss( + actor_network=model[0], + qvalue_network=model[1], + top_quantiles_to_drop=top_quantiles_to_drop, + alpha_init=cfg.optim.alpha_init, + ) + loss_module.make_value_estimator( + value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) + return loss_module, target_net_updater + + +def make_tqc_optimizer(cfg, loss_module): + critic_params = list(loss_module.critic_params.flatten_keys().values()) + actor_params = list(loss_module.actor_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + 
lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=3.0e-4, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# ------------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/test/test_cost.py b/test/test_cost.py index 5c1a7dbc41c..ef0d1acacde 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -113,6 +113,7 @@ QMixerLoss, SACLoss, TD3Loss, + TQCLoss, ) from torchrl.objectives.common import LossModule from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated @@ -2440,6 +2441,594 @@ def test_td3_notensordict( assert loss_qvalue == loss_val_td["loss_qvalue"] +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestTQC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + in_keys=["loc", "scale"], + spec=action_spec, + return_log_prob=True, + ) + return actor.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + action_key="action", + observation_key="observation", + ): + # Actor + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + value = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + 
}, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + return_log_prob=True, + ), + ) + value_head = Mod( + value, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + value = Seq(common, value_head) + return actor, value, common, td + + def _create_mock_data_td3( + self, + batch=8, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", + terminated_key="terminated", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_td3( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next": { + "observation": next_obs * mask.to(obs.dtype), + "reward": reward * mask.to(obs.dtype), + "done": done, + "terminated": terminated, + }, + "collector": {"mask": mask}, + "action": action * mask.to(obs.dtype), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", [ValueEstimators.TD0, None]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_tqc( + self, + device, + td_est, + use_action_spec, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3(device=device) + if use_action_spec: + action_spec = actor.spec + else: + action_spec = None + loss_fn = TQCLoss( + actor, + value, + action_spec=action_spec, + ) + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with _check_td_steady(td): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.critic_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_params.values(True, True) + ) + # check that losses 
are independent
+        for k in loss.keys():
+            if not k.startswith("loss"):
+                continue
+            loss[k].sum().backward(retain_graph=True)
+            if k == "loss_actor":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                assert not any(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+            elif k == "loss_critic":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+                assert not any(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+            elif k == "loss_alpha":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                assert loss_fn.log_alpha.grad.norm() > 0
+            else:
+                raise NotImplementedError(k)
+            loss_fn.zero_grad()
+
+        sum([item for _, item in loss.items()]).backward()
+        named_parameters = list(loss_fn.named_parameters())
+        named_buffers = list(loss_fn.named_buffers())
+
+        assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+        assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+        for name, p in named_parameters:
+            if not name.startswith("target_"):
+                assert (
+                    p.grad is not None and p.grad.norm() > 0.0
+                ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+            else:
+                assert (
+                    p.grad is None or p.grad.norm() == 0.0
+                ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("use_action_spec", [True, False])
+    def test_tqc_state_dict(
+        self,
+        device,
+        use_action_spec,
+    ):
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        if use_action_spec:
+            action_spec = actor.spec
+        else:
+            action_spec = None  # exercise construction without an explicit spec
+        loss_fn = TQCLoss(
+            actor,
+            value,
+            action_spec=action_spec,
+        )
+        sd = loss_fn.state_dict()
+        loss_fn2 = TQCLoss(
+            actor,
+            value,
+            action_spec=action_spec,
+        )
+        loss_fn2.load_state_dict(sd)
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("separate_losses", [False, True])
+    def test_tqc_separate_losses(
+        self,
+        device,
+        separate_losses,
+        n_act=4,
+    ):
+        torch.manual_seed(self.seed)
+        actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+        loss_fn = TQCLoss(
+            actor,
+            value,
+            action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+            loss_function="l2",
+            separate_losses=separate_losses,
+        )
+        with pytest.warns(UserWarning, match="No target network updater has been"):
+            loss = loss_fn(td)
+
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.critic_params.values(True, True)
+        )
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.actor_params.values(True, True)
+        )
+        # check that losses are independent
+        for k in loss.keys():
+            if not k.startswith("loss"):
+                continue
+            loss[k].sum().backward(retain_graph=True)
+            if k == "loss_actor":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                assert not any(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in 
loss_fn.actor_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_params.values(True, True) + ) + if separate_losses: + common_layers_no = len(list(common.parameters())) + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + critic_layers = itertools.islice( + loss_fn.critic_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in critic_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.critic_params.values(True, True) + ) + + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("device", get_default_devices()) + def test_tqc_batcher( + self, n, delay_actor, delay_qvalue, device, policy_noise, noise_clip, gamma=0.9 + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_td3(device=device) + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if (delay_qvalue or delay_actor) + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + + sum([item for _, item in loss_ms.items()]).backward() + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = loss_fn.target_actor_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_critic_params.clone().values( + include_nested=True, leaves_only=True + ) + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = loss_fn.target_actor_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_critic_params.clone().values( + include_nested=True, leaves_only=True + ) + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == 
p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_tqc_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + value = self._create_mock_value() + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["state_action_value_test"]) + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_tqc_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3( + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) + loss_val_td = loss(td) + torch.manual_seed(0) + loss_val = loss(**kwargs) + for i in loss_val: + assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" + + for i, key in enumerate(loss.out_keys): + torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) + if torch.__version__ >= "2.0.0": + loss_actor, loss_qvalue = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, 
loss_qvalue = loss(**kwargs) + return + + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 023b22ba3c4..88cdf397df1 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -17,6 +17,7 @@ from .reinforce import ReinforceLoss from .sac import DiscreteSACLoss, SACLoss from .td3 import TD3Loss +from .tqc import TQCLoss from .utils import ( default_value_kwargs, distance_loss, diff --git a/torchrl/objectives/tqc.py b/torchrl/objectives/tqc.py new file mode 100644 index 00000000000..3bbebd73d82 --- /dev/null +++ b/torchrl/objectives/tqc.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from dataclasses import dataclass + +import numpy as np +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import InteractionType, set_interaction_type, TensorDictModule +from tensordict.utils import NestedKey +from torchrl.data import CompositeSpec +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ValueEstimators + + +class TQCLoss(LossModule): + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"advantage"``. + value (NestedKey): The input tensordict key where the state value is expected. + Will be used for the underlying value estimator. Defaults to ``"state_value"``. + state_action_value (NestedKey): The input tensordict key where the + state action value is expected. Defaults to ``"state_action_value"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. + Defaults to ``"_log_prob"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
+        """
+
+        action: NestedKey = "action"
+        value: NestedKey = "state_value"
+        state_action_value: NestedKey = "state_action_value"
+        log_prob: NestedKey = "sample_log_prob"
+        priority: NestedKey = "td_error"
+        reward: NestedKey = "reward"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+
+    default_keys = _AcceptedKeys()
+    default_value_estimator = ValueEstimators.TD0
+
+    def __init__(
+        self,
+        actor_network: TensorDictModule,
+        qvalue_network: TensorDictModule,
+        top_quantiles_to_drop: int = 10,
+        alpha_init: float = 1.0,
+        # no need to pass device, should be handled by actor/qvalue nets
+        # device: torch.device,
+        # gamma should be passed to the value estimator construction
+        # for consistency with other losses
+        # gamma: float=None,
+        target_entropy=None,
+        action_spec=None,
+    ):
+        super().__init__()
+
+        self.convert_to_functional(
+            actor_network,
+            "actor",
+            create_target_params=False,
+            funs_to_decorate=["forward", "get_dist"],
+        )
+
+        self.convert_to_functional(
+            qvalue_network,
+            "critic",
+            create_target_params=True,  # Create a target critic network
+        )
+
+        # self.device = device
+        for p in self.parameters():
+            device = p.device
+            break
+        else:
+            # this should never be reached unless both networks have 0 parameters
+            raise RuntimeError(
+                "Cannot infer the device: neither the actor nor the critic "
+                "network has any parameters."
+            )
+        self.log_alpha = torch.nn.Parameter(
+            torch.tensor([np.log(alpha_init)], requires_grad=True, device=device)
+        )
+        self.top_quantiles_to_drop = top_quantiles_to_drop
+        self.target_entropy = target_entropy
+        self._action_spec = action_spec
+        self.make_value_estimator()
+
+    @property
+    def target_entropy(self):
+        target_entropy = self.__dict__.get("_target_entropy", None)
+        if target_entropy is None:
+            # Compute target entropy
+            action_spec = self._action_spec
+            if action_spec is None:
+                action_spec = getattr(self.actor, "spec", None)
+            if action_spec is None:
+                raise RuntimeError(
+                    "Could not deduce the action spec from either the actor "
+                    "network or the constructor arguments. "
+                    "Please provide the target entropy during construction."
+ ) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + action_container_len = len(action_spec.shape) + + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[action_container_len:] + .numel() + ) + self.target_entropy = target_entropy + return target_entropy + + @target_entropy.setter + def target_entropy(self, value): + if value is not None: + value = float(value) + self._target_entropy = value + + @property + def alpha(self): + return self.log_alpha.exp().detach() + + def value_loss(self, tensordict): + tensordict_copy = tensordict.clone(False) + td_next = tensordict_copy.get("next") + reward = td_next.get(self.tensor_keys.reward) + not_done = td_next.get(self.tensor_keys.done).logical_not() + alpha = self.alpha + + # Q-loss + with torch.no_grad(): + # get policy action + self.actor(td_next, params=self.actor_params) + self.critic(td_next, params=self.target_critic_params) + next_log_pi = td_next.get(self.tensor_keys.log_prob) + next_log_pi = torch.unsqueeze(next_log_pi, dim=-1) + + # compute and cut quantiles at the next state + next_z = td_next.get(self.tensor_keys.state_action_value) + sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1)) + sorted_z_part = sorted_z[..., : -self.top_quantiles_to_drop] + + # compute target + # --- Note --- + # This is computed manually here, since the built-in value estimators in the library + # currently do not support a critic of a shape different from the reward. + # ------------ + target = reward + not_done * self.gamma * ( + sorted_z_part - alpha * next_log_pi + ) + + self.critic(tensordict_copy, params=self.critic_params) + cur_z = tensordict_copy.get(self.tensor_keys.state_action_value) + critic_loss = quantile_huber_loss_f(cur_z, target) + metadata = {} + return critic_loss, metadata + + def actor_loss(self, tensordict): + tensordict_copy = tensordict.clone(False) + alpha = self.alpha + self.actor(tensordict_copy, params=self.actor_params) + self.critic(tensordict_copy, params=self.critic_params) + new_log_pi = tensordict_copy.get(self.tensor_keys.log_prob) + tensordict.set(self.tensor_keys.log_prob, new_log_pi) + actor_loss = ( + alpha * new_log_pi + - tensordict_copy.get(self.tensor_keys.state_action_value) + .mean(-1) + .mean(-1, keepdim=True) + ).mean() + metadata = {} + return actor_loss, metadata + + def alpha_loss(self, tensordict): + log_prob = tensordict.get(self.tensor_keys.log_prob) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean() + return alpha_loss, {} + + def entropy(self, tensordict): + with set_exploration_type(ExplorationType.RANDOM): + dist = self.actor.get_dist( + tensordict, + params=self.actor_params, + ) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm).detach() + entropy = -log_prob.mean() + return entropy + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + critic_loss, metadata_value = self.value_loss(tensordict) + actor_loss, metadata_actor = self.actor_loss( + tensordict + ) # Compute actor loss AFTER critic loss + alpha_loss, metadata_alpha = self.alpha_loss(tensordict) + metadata = { + "alpha": self.alpha, + "entropy": self.entropy(tensordict), + } + metadata.update(metadata_alpha) + metadata.update(metadata_value) + metadata.update(metadata_actor) + losses = { + "loss_critic": critic_loss, + "loss_actor": actor_loss, + "loss_alpha": alpha_loss, + } + losses.update(metadata) + return TensorDict(losses, batch_size=[]) + + def 
make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        """Value estimator setter for TQC.
+
+        The only value estimator supported is ``ValueEstimators.TD0``.
+
+        This method can also be used to set the ``gamma`` factor.
+
+        Args:
+            value_type (ValueEstimators, optional): the value estimator to be used.
+                Will raise an exception if it differs from ``ValueEstimators.TD0``.
+            gamma (float, optional): the gamma factor for the target computation.
+                Defaults to 0.99.
+        """
+        if value_type not in (ValueEstimators.TD0, None):
+            raise NotImplementedError(
+                f"Value type {value_type} is not currently implemented."
+            )
+        self.gamma = hyperparams.pop("gamma", 0.99)
+
+
+# ====================================================================
+# Quantile Huber Loss
+# -------------------
+
+
+def quantile_huber_loss_f(quantiles, samples):
+    """Quantile Huber loss from the original PyTorch TQC implementation.
+
+    See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
+
+    quantiles is assumed to be of shape [batch size, n_nets, n_quantiles]
+    samples is assumed to be of shape [batch size, n_samples]
+    Arbitrary batch sizes are allowed.
+    """
+    pairwise_delta = (
+        samples[..., None, None, :] - quantiles[..., None]
+    )  # batch x n_nets x n_quantiles x n_samples
+    abs_pairwise_delta = torch.abs(pairwise_delta)
+    huber_loss = torch.where(
+        abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5
+    )
+    n_quantiles = quantiles.shape[-1]
+    tau = (
+        torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles
+        + 1 / 2 / n_quantiles
+    )
+    loss = (
+        torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss
+    ).mean()
+    return loss
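For reference, a minimal sketch of the shape conventions expected by ``quantile_huber_loss_f``, using the import path added in this diff; the tensors are random placeholders and the 5 x 25 / 10 figures mirror the defaults in ``examples/tqc/config.yaml``:

import torch

from torchrl.objectives.tqc import quantile_huber_loss_f

batch, n_nets, n_quantiles = 8, 5, 25
top_quantiles_to_drop = 2 * n_nets  # as computed in examples/tqc/utils.py

# current quantile estimates: one distribution of 25 atoms per critic net
cur_z = torch.randn(batch, n_nets, n_quantiles)
# truncated Bellman target: 5 * 25 = 125 pooled quantiles with the largest 10 dropped
target = torch.randn(batch, n_nets * n_quantiles - top_quantiles_to_drop)

loss = quantile_huber_loss_f(cur_z, target)
print(loss)  # a scalar, averaged over batch, nets, quantiles and target samples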