|
18 | 18 | TensorDictParams,
|
19 | 19 | )
|
20 | 20 | from tensordict.nn import (
|
| 21 | + CompositeDistribution, |
21 | 22 | dispatch,
|
22 | 23 | ProbabilisticTensorDictModule,
|
23 | 24 | ProbabilisticTensorDictSequential,
|
|
33 | 34 | _clip_value_loss,
|
34 | 35 | _GAMMA_LMBDA_DEPREC_ERROR,
|
35 | 36 | _reduce,
|
| 37 | + _sum_td_features, |
36 | 38 | default_value_kwargs,
|
37 | 39 | distance_loss,
|
38 | 40 | ValueEstimators,
|
@@ -462,9 +464,13 @@ def reset(self) -> None:
|
462 | 464 |
|
463 | 465 | def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
|
464 | 466 | try:
|
465 |
| - entropy = dist.entropy() |
| 467 | + if isinstance(dist, CompositeDistribution): |
| 468 | + kwargs = {"aggregate_probabilities": False, "include_sum": False} |
| 469 | + else: |
| 470 | + kwargs = {} |
| 471 | + entropy = dist.entropy(**kwargs) |
466 | 472 | if is_tensor_collection(entropy):
|
467 |
| - entropy = entropy.get(dist.entropy_key) |
| 473 | + entropy = _sum_td_features(entropy) |
468 | 474 | except NotImplementedError:
|
469 | 475 | x = dist.rsample((self.samples_mc_entropy,))
|
470 | 476 | log_prob = dist.log_prob(x)
|
@@ -497,13 +503,20 @@ def _log_weight(
|
497 | 503 | if isinstance(action, torch.Tensor):
|
498 | 504 | log_prob = dist.log_prob(action)
|
499 | 505 | else:
|
500 |
| - maybe_log_prob = dist.log_prob(tensordict) |
501 |
| - if not isinstance(maybe_log_prob, torch.Tensor): |
502 |
| - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not |
503 |
| - # be a tensor |
504 |
| - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) |
| 506 | + if isinstance(dist, CompositeDistribution): |
| 507 | + is_composite = True |
| 508 | + kwargs = { |
| 509 | + "inplace": False, |
| 510 | + "aggregate_probabilities": False, |
| 511 | + "include_sum": False, |
| 512 | + } |
505 | 513 | else:
|
506 |
| - log_prob = maybe_log_prob |
| 514 | + is_composite = False |
| 515 | + kwargs = {} |
| 516 | + log_prob = dist.log_prob(tensordict, **kwargs) |
| 517 | + if is_composite and not isinstance(prev_log_prob, TensorDict): |
| 518 | + log_prob = _sum_td_features(log_prob) |
| 519 | + log_prob.view_as(prev_log_prob) |
507 | 520 |
|
508 | 521 | log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
|
509 | 522 | kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
|
@@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
598 | 611 | advantage = (advantage - loc) / scale
|
599 | 612 |
|
600 | 613 | log_weight, dist, kl_approx = self._log_weight(tensordict)
|
| 614 | + if is_tensor_collection(log_weight): |
| 615 | + log_weight = _sum_td_features(log_weight) |
| 616 | + log_weight = log_weight.view(advantage.shape) |
601 | 617 | neg_loss = log_weight.exp() * advantage
|
602 | 618 | td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
|
603 | 619 | if self.entropy_bonus:
|
@@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
|
1149 | 1165 | kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
|
1150 | 1166 | except NotImplementedError:
|
1151 | 1167 | x = previous_dist.sample((self.samples_mc_kl,))
|
1152 |
| - previous_log_prob = previous_dist.log_prob(x) |
1153 |
| - current_log_prob = current_dist.log_prob(x) |
| 1168 | + if isinstance(previous_dist, CompositeDistribution): |
| 1169 | + kwargs = { |
| 1170 | + "aggregate_probabilities": False, |
| 1171 | + "inplace": False, |
| 1172 | + "include_sum": False, |
| 1173 | + } |
| 1174 | + else: |
| 1175 | + kwargs = {} |
| 1176 | + previous_log_prob = previous_dist.log_prob(x, **kwargs) |
| 1177 | + current_log_prob = current_dist.log_prob(x, **kwargs) |
1154 | 1178 | if is_tensor_collection(current_log_prob):
|
1155 |
| - previous_log_prob = previous_log_prob.get( |
1156 |
| - self.tensor_keys.sample_log_prob |
1157 |
| - ) |
1158 |
| - current_log_prob = current_log_prob.get( |
1159 |
| - self.tensor_keys.sample_log_prob |
1160 |
| - ) |
1161 |
| - |
| 1179 | + previous_log_prob = _sum_td_features(previous_log_prob) |
| 1180 | + current_log_prob = _sum_td_features(current_log_prob) |
1162 | 1181 | kl = (previous_log_prob - current_log_prob).mean(0)
|
1163 | 1182 | kl = kl.unsqueeze(-1)
|
1164 | 1183 | neg_loss = neg_loss - self.beta * kl
|
|
0 commit comments