From c10adf14f6394a5f952cb55c20c56291f335d881 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 10 Jan 2025 14:12:20 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_cost.py | 16 ++++----- torchrl/objectives/ppo.py | 72 ++++++++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 7c7c97eedfc..5d54246419d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8158,18 +8158,19 @@ def _create_seq_mock_data_ppo( obs = total_obs[:, :T] next_obs = total_obs[:, 1:] if atoms: - action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( - -1, 1 - ) + action_shape = (batch, T, atoms, action_dim) else: - action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + action_shape = (batch, T, action_dim) + params_mean = torch.randn(action_shape, device=device) / 10 + params_scale = torch.rand(action_shape, device=device) / 10 + action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp( + -1, 1 + ) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) - params_mean = torch.randn_like(action) / 10 - params_scale = torch.rand_like(action) / 10 loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( @@ -8184,9 +8185,6 @@ def _create_seq_mock_data_ppo( }, "collector": {"mask": mask}, action_key: {"action1": action} if composite_action_dist else action, - sample_log_prob_key: ( - torch.randn_like(action[..., 1]) / 10 - ).masked_fill_(~mask, 0.0), }, device=device, names=[None, "time"], diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 5411687eb5e..175c5cddfdd 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -510,17 +510,28 @@ def _log_weight( # current log_prob of actions action = _maybe_get_or_select(tensordict, self.tensor_keys.action) + is_composite = None + if all(key in tensordict for key in self.actor_network.dist_params_keys): + prev_dist = self.actor_network.build_dist_from_params(tensordict.detach()) + kwargs, is_composite = _get_composite_kwargs(prev_dist) + if is_composite: + prev_log_prob = prev_dist.log_prob(tensordict, **kwargs) + else: + prev_log_prob = prev_dist.log_prob(action, **kwargs) + print('prev_log_prob', prev_log_prob) + else: + try: + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob + ) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, tensordict, err) + with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): - dist = self.actor_network.get_dist(tensordict) + current_dist = self.actor_network.get_dist(tensordict) - try: - prev_log_prob = _maybe_get_or_select( - tensordict, self.tensor_keys.sample_log_prob - ) - except KeyError as err: - raise _make_lp_get_error(self.tensor_keys, tensordict, err) if prev_log_prob.requires_grad: raise RuntimeError( @@ -532,25 +543,11 @@ def _log_weight( f"tensordict stored {self.tensor_keys.action} requires grad." ) if isinstance(action, torch.Tensor): - log_prob = dist.log_prob(action) + log_prob = current_dist.log_prob(action) else: - if isinstance(dist, CompositeDistribution): - is_composite = True - aggregate = dist.aggregate_probabilities - if aggregate is None: - aggregate = False - include_sum = dist.include_sum - if include_sum is None: - include_sum = False - kwargs = { - "inplace": False, - "aggregate_probabilities": aggregate, - "include_sum": include_sum, - } - else: - is_composite = False - kwargs = {} - log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) + if is_composite is None: + kwargs, is_composite = _get_composite_kwargs(current_dist) + log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs) if ( is_composite and not is_tensor_collection(prev_log_prob) @@ -564,7 +561,7 @@ def _log_weight( if is_tensor_collection(kl_approx): kl_approx = _sum_td_features(kl_approx) - return log_weight, dist, kl_approx + return log_weight, current_dist, kl_approx def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``.""" @@ -640,6 +637,9 @@ def _cached_critic_network_params_detached(self): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) + + log_weight, dist, kl_approx = self._log_weight(tensordict) + advantage = tensordict.get(self.tensor_keys.advantage, None) if advantage is None: self.value_estimator( @@ -653,7 +653,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: scale = advantage.std().clamp_min(1e-6) advantage = (advantage - loc) / scale - log_weight, dist, kl_approx = self._log_weight(tensordict) if is_tensor_collection(log_weight): log_weight = _sum_td_features(log_weight) log_weight = log_weight.view(advantage.shape) @@ -1295,3 +1294,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err): return KeyError(result) result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=)." return KeyError(result) + +def _get_composite_kwargs(current_dist): + if isinstance(current_dist, CompositeDistribution): + is_composite = True + aggregate = current_dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = current_dist.include_sum + if include_sum is None: + include_sum = False + kwargs = { + "inplace": False, + "aggregate_probabilities": aggregate, + "include_sum": include_sum, + } + else: + is_composite = False + kwargs = {} + return kwargs, is_composite