File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -220,7 +220,7 @@ def _get_cur_log_prob(self, tensordict):
220
220
action = tensordict .get (
221
221
self .tensor_keys .action ,
222
222
as_padded_tensor = True ,
223
- padding_side = dist . padding_side ,
223
+ padding_side = "left" ,
224
224
padding_value = - 100 ,
225
225
)
226
226
log_prob = dist .log_prob (action )
@@ -308,7 +308,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
308
308
ref_log_prob = tensordict .get (
309
309
self .tensor_keys .ref_log_probs ,
310
310
as_padded_tensor = True ,
311
- padding_side = dist . padding_side ,
311
+ padding_side = "left" ,
312
312
padding_value = dist .padding_value ,
313
313
),
314
314
)
@@ -343,7 +343,7 @@ def _kl_to_ref(
343
343
ref_log_prob = tensordict .get (
344
344
key ,
345
345
as_padded_tensor = True ,
346
- padding_side = dist . padding_side ,
346
+ padding_side = "left" ,
347
347
padding_value = dist .padding_value ,
348
348
)
349
349
if ref_log_prob is None :
@@ -377,7 +377,7 @@ def _log_weight(
377
377
prev_log_prob = tensordict .get (
378
378
self .tensor_keys .sample_log_prob ,
379
379
as_padded_tensor = True ,
380
- padding_side = dist . padding_side ,
380
+ padding_side = "left" ,
381
381
padding_value = dist .padding_value ,
382
382
)
383
383
You can’t perform that action at this time.
0 commit comments