@@ -89,14 +89,17 @@ def process_logits(
89
89
if self ._seq_start_idx is None :
90
90
self ._seq_start_idx = len (input_ids [0 ])
91
91
92
- sequence_states : List [int ] = [] # vector of states corresponding to `input_ids`
92
+ sequence_states : List [Any ] = [] # vector of states corresponding to `input_ids`
93
93
94
94
for seq_ids in input_ids :
95
95
gen_ids = seq_ids [self ._seq_start_idx :]
96
96
curr_state_key = hash (tuple (gen_ids .tolist ()))
97
97
98
98
if curr_state_key not in self ._guide_states :
99
- prev_state = self ._guide_states [hash (tuple (gen_ids [:- 1 ].tolist ()))]
99
+ prev_state_key = hash (tuple (gen_ids [:- 1 ].tolist ()))
100
+ prev_state = self ._guide_states .get (
101
+ prev_state_key , self .guide .initial_state
102
+ )
100
103
curr_state = self .guide .get_next_state (prev_state , gen_ids [- 1 ].item ())
101
104
self ._guide_states [curr_state_key ] = curr_state
102
105
@@ -107,19 +110,26 @@ def process_logits(
107
110
allowed_tokens_batch = []
108
111
batch_indices = []
109
112
for i , guide_state in enumerate (sequence_states ):
110
- allowed_tokens = self .guide .get_next_instruction (guide_state ).tokens .to (
111
- mask .device , non_blocking = True
112
- )
113
+ instruction = self .guide .get_next_instruction (guide_state )
114
+ if instruction is None :
115
+ continue # Skip if no instruction is available
116
+ allowed_tokens = instruction .tokens
117
+ if allowed_tokens is None :
118
+ continue # Skip if no tokens are allowed
119
+ allowed_tokens = allowed_tokens .to (mask .device , non_blocking = True )
120
+
121
+ # Filter out invalid token IDs
122
+ allowed_tokens = allowed_tokens [allowed_tokens < logits .size (1 )]
113
123
allowed_tokens_batch .append (allowed_tokens )
114
- batch_indices .append (
115
- torch .full_like (allowed_tokens , i )
116
- ) # Store batch index for each allowed token
124
+ batch_indices .append (torch .full_like (allowed_tokens , i ))
117
125
118
- allowed_tokens_concat = torch .cat (allowed_tokens_batch )
119
- batch_indices_concat = torch .cat (batch_indices )
126
+ if allowed_tokens_batch :
127
+ allowed_tokens_concat = torch .cat (allowed_tokens_batch )
128
+ batch_indices_concat = torch .cat (batch_indices )
120
129
121
- mask [batch_indices_concat , allowed_tokens_concat ] = False
122
- logits .masked_fill_ (mask , float ("-inf" ))
130
+ mask [batch_indices_concat , allowed_tokens_concat ] = False
131
+
132
+ logits = logits .masked_fill (mask , float ("-inf" ))
123
133
124
134
return logits
125
135
@@ -221,26 +231,34 @@ def process_logits(
221
231
if self ._seq_start_idx is None :
222
232
self ._seq_start_idx = len (input_ids [0 ])
223
233
224
- sequence_states : List = [] # vector of states corresponding to `input_ids`
234
+ sequence_states : List [ Any ] = [] # vector of states corresponding to `input_ids`
225
235
226
236
for seq_ids in input_ids :
227
237
gen_ids = seq_ids [self ._seq_start_idx :]
228
238
curr_state_key = hash (tuple (gen_ids .tolist ()))
229
239
230
240
if curr_state_key not in self ._guide_states :
231
- prev_state = self ._guide_states [hash (tuple (gen_ids [:- 1 ].tolist ()))]
241
+ prev_state_key = hash (tuple (gen_ids [:- 1 ].tolist ()))
242
+ prev_state = self ._guide_states .get (
243
+ prev_state_key , self .guide .initial_state
244
+ )
232
245
curr_state = self .guide .get_next_state (prev_state , gen_ids [- 1 ].item ())
233
246
self ._guide_states [curr_state_key ] = curr_state
234
247
235
248
sequence_states .append (self ._guide_states [curr_state_key ])
236
249
237
250
mask = torch .full_like (logits , - math .inf )
238
251
for i , guide_state in enumerate (sequence_states ):
239
- first_legal_token = next (
252
+ valid_tokens = list (
240
253
self .guide .iter_valid_token_ids (
241
- guide_state , torch .argsort (logits [ i ], descending = True )
254
+ guide_state , torch .arange (logits . size ( 1 ), device = logits . device )
242
255
)
243
256
)
244
- mask [i , [first_legal_token ]] = logits [i , [first_legal_token ]]
257
+ if valid_tokens :
258
+ # Keep only valid tokens
259
+ mask [i , valid_tokens ] = logits [i , valid_tokens ]
260
+ else :
261
+ # No valid tokens; generation should stop
262
+ mask [i ] = logits [i ]
245
263
246
264
return mask
0 commit comments