File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -598,8 +598,6 @@ def _log_weight(
598
598
599
599
if is_composite :
600
600
with set_composite_lp_aggregate (False ):
601
- if log_prob .batch_size != adv_shape :
602
- log_prob .batch_size = adv_shape
603
601
if not is_tensor_collection (prev_log_prob ):
604
602
# this isn't great: in general, multi-head actions should have a composite log-prob too
605
603
warnings .warn (
@@ -612,6 +610,8 @@ def _log_weight(
612
610
if is_tensor_collection (log_prob ):
613
611
log_prob = _sum_td_features (log_prob )
614
612
log_prob .view_as (prev_log_prob )
613
+ if log_prob .batch_size != adv_shape :
614
+ log_prob .batch_size = adv_shape
615
615
log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
616
616
if is_tensor_collection (log_weight ):
617
617
log_weight = _sum_td_features (log_weight )
You can’t perform that action at this time.
0 commit comments