We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b823f3d, commit 44e70e7 (Copy full SHA for 44e70e7)
torchrl/envs/llm/transforms/kl.py
@@ -1159,17 +1159,8 @@ def _step(
1159
for gen_lp, ref_lp in _zip_strict(gen_log_probs, ref_log_probs)
1160
]
1161
1162
- # Convert to appropriate format
1163
- if hasattr(gen_log_probs[0], "device"):
1164
- # If it's a tensor, use pad_sequence
1165
- kl = pad_sequence(
1166
- kl, batch_first=True, padding_value=0.0, padding_side=self.padding_side
1167
- )
1168
- else:
1169
- # If it's nested, use nested tensor
1170
- kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
+ kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
1171
1172
- # Store KL
1173
next_tensordict.set(self.kl_key, kl)
1174
1175
# Add to reward if requested
0 commit comments