@@ -63,25 +63,24 @@ def get_logits_processors(
         try:
             # Convert token_id to integer
             # Clamp the bias between -100 and 100 per OpenAI API spec
-            clamped_logit_bias: Dict[int, float] = {
-                int(token_id): min(100.0, max(-100.0, bias))
-                for token_id, bias in logit_bias.items()
-            }
+            logit_bias_index = [int(token_id) for token_id in logit_bias]
+            logit_bias_value = [
+                min(100.0, max(-100.0, bias)) for bias in logit_bias.values()
+            ]
         except ValueError as exc:
             raise ValueError(
                 "Found token_id in logit_bias that is not "
                 "an integer or string representing an integer") from exc
 
         # Check if token_id is within the vocab size
-        for token_id, bias in clamped_logit_bias.items():
+        for token_id in logit_bias_index:
             if token_id < 0 or token_id >= len(tokenizer):
                 raise ValueError(f"token_id {token_id} in logit_bias contains "
                                  "out-of-vocab token id")
 
-        clamped_logit_bias = {
-            "index": torch.tensor(list(clamped_logit_bias.keys())),
-            "value": torch.tensor(list(clamped_logit_bias.values()),
-                                  dtype=dtype)
+        clamped_logit_bias: Dict[str, torch.Tensor] = {
+            "index": torch.tensor(logit_bias_index),
+            "value": torch.tensor(logit_bias_value, dtype=dtype)
         }
         logits_processors.append(
             partial(logit_bias_logits_processor, clamped_logit_bias))
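For context: the hunk does not show logit_bias_logits_processor itself, so the following is only a minimal sketch of how a processor could consume the new "index"/"value" tensor pair. The signature and the .to(logits.device) moves are assumptions for illustration, not code from this PR. The reason for pre-building the tensors is that the bias can then be applied in one vectorized index_add_ call instead of a Python loop over token ids; this is also why dtype=dtype matters above, since index_add_ expects the source tensor to match the logits dtype.

    from typing import Dict, List

    import torch

    def logit_bias_logits_processor(
            logit_bias: Dict[str, torch.Tensor],
            token_ids: List[int],
            logits: torch.Tensor) -> torch.Tensor:
        # Vectorized equivalent of `logits[token_id] += bias` for every entry.
        # The value tensor was already built with the logits dtype (dtype=dtype),
        # so only a device move may be needed here (an assumption in this sketch).
        index = logit_bias["index"].to(logits.device)
        value = logit_bias["value"].to(logits.device)
        logits.index_add_(0, index, value)
        return logits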