@@ -223,7 +223,7 @@ def __init__(
         if from_text:
             self.out_keys += [self.text_response_key, self.token_key]
         if self.return_log_probs:
-            self.out_keys += ["log_probs"]
+            self.out_keys += [self.log_prob_key]

     def forward(
         self,
@@ -303,7 +303,7 @@ def _from_vllm_generate_text(self, td):
             ),
         )
         in_keys = [
-            "log_probs",
+            self.log_prob_key,
             self.token_response_key,
             self.text_response_key,
             self.token_key,
@@ -394,7 +394,7 @@ def _from_vllm_logprobs_text(self, td):
         if isinstance(input_ids_response, list):
             input_ids_response = torch.nested.nested_tensor(input_ids_response)
         out["tokens_response"] = input_ids_response
-        out["log_probs"] = lps
+        out[self.log_prob_key] = lps
         inputs = td.select(*self.in_keys, strict=False)
         if inputs.ndim < out.ndim:
             # This happens when n > 1
@@ -423,18 +423,19 @@ def _from_vllm_generate_tokens(self, td):
         ).to_padded_tensor(padding=self.padding_value)
         tokens_response_td.rename_key_("token_ids", "tokens_response")
         if self.return_log_probs:
-            tokens_response_td.rename_key_("logprobs", "log_probs")
+            tokens_response_td.rename_key_("logprobs", self.log_prob_key)
             if self.pad_output:
                 padded_values = (
                     tokens_response_td["tokens_response"] == self.padding_value
                 )
                 if padded_values.any():
-                    lps = tokens_response_td["log_probs"]
+                    lps = tokens_response_td[self.log_prob_key]
                     lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
-                    tokens_response_td["log_probs"] = lps
+                    tokens_response_td[self.log_prob_key] = lps
         out = tokens_response_td.empty(recurse=True)
         out.update(
-            tokens_response_td, keys_to_update=(self.token_response_key, "log_probs")
+            tokens_response_td,
+            keys_to_update=(self.token_response_key, self.log_prob_key),
         )
         inputs = td.select(*self.in_keys, strict=False)
         if inputs.ndim < out.ndim:
@@ -467,7 +468,7 @@ def _from_vllm_logprobs_tokens(self, td):
         padded = tokens_response == self.padding_value
         prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0)
         out = tokens_out._tensordict.empty(recurse=True)
-        out.set("log_probs", prompt_logprobs)
+        out.set(self.log_prob_key, prompt_logprobs)
         out.set(self.token_response_key, tokens_response)
         inputs = td.select(*self.in_keys, strict=False)
         if inputs.ndim < out.ndim:
@@ -501,13 +502,13 @@ def _get_output_tokens_and_log_probs(self, tokens_out):
         )

         if self.return_log_probs or "logprobs" in tokens_response_td:
-            tokens_response_td.rename_key_("logprobs", "log_probs")
+            tokens_response_td.rename_key_("logprobs", self.log_prob_key)
             if self.pad_output:
                 padded_values = tokens_response_td["tokens_response"] == padding_value
                 if padded_values.any():
-                    lps = tokens_response_td["log_probs"]
+                    lps = tokens_response_td[self.log_prob_key]
                     lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
-                    tokens_response_td["log_probs"] = lps
+                    tokens_response_td[self.log_prob_key] = lps
         return tokens_response_td

     def _to_list(self, tokens, attention_mask):
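
Taken together, the hunks above replace every hardcoded `"log_probs"` string with the configurable `self.log_prob_key` attribute, so callers can choose which output key the wrapper writes log-probabilities under. Below is a minimal sketch of that pattern, assuming a constructor argument `log_prob_key` that defaults to the old literal; the `PolicySketch` class and its signature are illustrative assumptions, not the actual wrapper API:

```python
import torch
from tensordict import TensorDict


class PolicySketch:
    """Illustrative stand-in for a wrapper that emits log-probs under a configurable key."""

    def __init__(self, log_prob_key: str = "log_probs"):
        # Defaulting to the old hardcoded literal keeps existing callers working.
        self.log_prob_key = log_prob_key
        self.out_keys = [self.log_prob_key]

    def forward(self, lps: torch.Tensor) -> TensorDict:
        # Write under self.log_prob_key rather than a hardcoded "log_probs" string,
        # mirroring the substitutions made throughout the diff.
        out = TensorDict({}, batch_size=[])
        out.set(self.log_prob_key, lps)
        return out


# Usage: downstream code that expects a different key (e.g. "sample_log_prob")
# can get it directly, without renaming keys after the fact.
td = PolicySketch(log_prob_key="sample_log_prob").forward(torch.zeros(3))
assert "sample_log_prob" in td.keys()
```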