Skip to content

Commit 44e70e7

Browse files
committed
amend
1 parent b823f3d commit 44e70e7

File tree

1 file changed

+1
-10
lines changed
  • torchrl/envs/llm/transforms

1 file changed

+1
-10
lines changed

torchrl/envs/llm/transforms/kl.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,17 +1159,8 @@ def _step(
11591159
for gen_lp, ref_lp in _zip_strict(gen_log_probs, ref_log_probs)
11601160
]
11611161

1162-
# Convert to appropriate format
1163-
if hasattr(gen_log_probs[0], "device"):
1164-
# If it's a tensor, use pad_sequence
1165-
kl = pad_sequence(
1166-
kl, batch_first=True, padding_value=0.0, padding_side=self.padding_side
1167-
)
1168-
else:
1169-
# If it's nested, use nested tensor
1170-
kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
1162+
kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
11711163

1172-
# Store KL
11731164
next_tensordict.set(self.kl_key, kl)
11741165

11751166
# Add to reward if requested

0 commit comments

Comments
 (0)