Skip to content

Commit 5aa8969

Browse files
committed
amend
1 parent 15717b1 commit 5aa8969

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchrl/objectives/llm/grpo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def _get_cur_log_prob(self, tensordict):
220220
action = tensordict.get(
221221
self.tensor_keys.action,
222222
as_padded_tensor=True,
223-
padding_side=dist.padding_side,
223+
padding_side="left",
224224
padding_value=-100,
225225
)
226226
log_prob = dist.log_prob(action)
@@ -308,7 +308,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
308308
ref_log_prob=tensordict.get(
309309
self.tensor_keys.ref_log_probs,
310310
as_padded_tensor=True,
311-
padding_side=dist.padding_side,
311+
padding_side="left",
312312
padding_value=dist.padding_value,
313313
),
314314
)
@@ -343,7 +343,7 @@ def _kl_to_ref(
343343
ref_log_prob = tensordict.get(
344344
key,
345345
as_padded_tensor=True,
346-
padding_side=dist.padding_side,
346+
padding_side="left",
347347
padding_value=dist.padding_value,
348348
)
349349
if ref_log_prob is None:
@@ -377,7 +377,7 @@ def _log_weight(
377377
prev_log_prob = tensordict.get(
378378
self.tensor_keys.sample_log_prob,
379379
as_padded_tensor=True,
380-
padding_side=dist.padding_side,
380+
padding_side="left",
381381
padding_value=dist.padding_value,
382382
)
383383

0 commit comments

Comments
 (0)