Commit 39772aa

amend

1 parent 44e70e7 commit 39772aa

2 files changed: +6, -1 lines changed

torchrl/envs/llm/transforms/kl.py

Lines changed: 2 additions & 1 deletion

@@ -1177,14 +1177,15 @@ def _step(
                     r - self.coeff * k.unsqueeze(-1)
                     for r, k in _zip_strict(reward, kl)
                 ]
+                next_tensordict.set("reward", torch.nested.as_nested_tensor(reward, layout=torch.strided))
             else:
                 if reward.ndim != kl.ndim + 1:
                     raise ValueError(
                         f"The rewards have shape {reward.shape} but the kl has shape {kl.shape}. "
                         f"The rewards should have one more dimension than the KL."
                     )
                 reward = reward - self.coeff * kl.unsqueeze(-1)
-            next_tensordict.set("reward", reward)
+                next_tensordict.set("reward", reward)
 
         return next_tensordict
 
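
The added line re-packs the per-sample, KL-shaped rewards into a single nested tensor before writing them back under the "reward" key. Below is a minimal sketch of that packing step, not the transform itself: the coefficient and shapes are made up for illustration, and plain zip stands in for torchrl's internal _zip_strict.

import torch

# Illustrative ragged batch: two samples with different sequence lengths.
coeff = 0.1  # stand-in for self.coeff
reward = [torch.randn(5, 1), torch.randn(3, 1)]  # per-sample (T_i, 1) rewards
kl = [torch.randn(5), torch.randn(3)]            # per-sample (T_i,) KL estimates

# Subtract the scaled KL from each sample's reward, as in the list branch above.
shaped = [r - coeff * k.unsqueeze(-1) for r, k in zip(reward, kl)]

# Re-pack the ragged list into one nested tensor (strided layout, as in the diff)
# so it can be stored under a single "reward" entry.
nested_reward = torch.nested.as_nested_tensor(shaped, layout=torch.strided)
print(nested_reward.is_nested)  # True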

torchrl/modules/distributions/discrete.py

Lines changed: 4 additions & 0 deletions

@@ -369,6 +369,10 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         if logits.ndim > 2:
             # Bring channels in 2nd dim
             logits = logits.transpose(-1, 1)
+            if logits.ndim <= idx.ndim:
+                logits = logits.expand(idx.shape + logits.shape)
+            print(f"logits: {logits.shape}")
+            print(f"idx: {idx.shape}")
             ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
         else:
             ret = super().log_prob(idx)
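
For context, torch.nn.functional.cross_entropy expects the class dimension in position 1, which is why the logits are transposed (and, in the new branch, broadcast to match the index shape) before the call. Below is a minimal sketch with illustrative shapes only, using the non-deprecated reduction="none" where the diff uses the older reduce=False.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, time, classes) logits and (batch, time) class indices.
batch, time, num_classes = 4, 6, 10
logits = torch.randn(batch, time, num_classes)
idx = torch.randint(num_classes, (batch, time))

# Bring channels into the 2nd dim, as in the diff: (B, T, C) -> (B, C, T).
logits_ct = logits.transpose(-1, 1)

# Per-element log-probability of each index (negative cross entropy, no reduction).
log_prob = -F.cross_entropy(logits_ct, idx, reduction="none")
print(log_prob.shape)  # torch.Size([4, 6])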
