File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -1177,14 +1177,15 @@ def _step(
1177
1177
r - self.coeff * k.unsqueeze(-1)
1178
1178
for r, k in _zip_strict(reward, kl)
1179
1179
]
1180
+ next_tensordict.set("reward", torch.nested.as_nested_tensor(reward, layout=torch.strided))
1180
1181
else:
1181
1182
if reward.ndim != kl.ndim + 1:
1182
1183
raise ValueError(
1183
1184
f"The rewards have shape {reward.shape} but the kl has shape {kl.shape}. "
1184
1185
f"The rewards should have one more dimension than the KL."
1185
1186
)
1186
1187
reward = reward - self.coeff * kl.unsqueeze(-1)
1187
- next_tensordict.set("reward", reward)
1188
+ next_tensordict.set("reward", reward)
1188
1189
1189
1190
return next_tensordict
1190
1191
Original file line number Diff line number Diff line change @@ -369,6 +369,10 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
369
369
if logits.ndim > 2:
370
370
# Bring channels in 2nd dim
371
371
logits = logits.transpose(-1, 1)
372
+ if logits.ndim <= idx.ndim:
373
+ logits = logits.expand(idx.shape + logits.shape)
374
+ print(f"logits: {logits.shape}")
375
+ print(f"idx: {idx.shape}")
372
376
ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
373
377
else:
374
378
ret = super().log_prob(idx)
You can’t perform that action at this time.
0 commit comments