Skip to content

Commit 50ecb15

Browse files
author
Vincent Moens
committed
[Quality] Warning composite mismatch when we should
ghstack-source-id: 808db56 Pull-Request-resolved: #2964
1 parent f0cda31 commit 50ecb15

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchrl/objectives/ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,6 @@ def _log_weight(
598598

599599
if is_composite:
600600
with set_composite_lp_aggregate(False):
601-
if log_prob.batch_size != adv_shape:
602-
log_prob.batch_size = adv_shape
603601
if not is_tensor_collection(prev_log_prob):
604602
# this isn't great: in general, multi-head actions should have a composite log-prob too
605603
warnings.warn(
@@ -612,6 +610,8 @@ def _log_weight(
612610
if is_tensor_collection(log_prob):
613611
log_prob = _sum_td_features(log_prob)
614612
log_prob.view_as(prev_log_prob)
613+
if log_prob.batch_size != adv_shape:
614+
log_prob.batch_size = adv_shape
615615
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
616616
if is_tensor_collection(log_weight):
617617
log_weight = _sum_td_features(log_weight)

0 commit comments

Comments
 (0)