Skip to content

Commit 668ea42

Browse files
committed
Fix IndexError caused by invalid token IDs in CFGGuide
1 parent 5f39ded commit 668ea42

File tree

2 files changed: +57 additions, −35 deletions

outlines/fsm/guide.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ def __init__(self, cfg_string: str, tokenizer):
116116

117117
self.cfg_string = cfg_string
118118
self.tokenizer = tokenizer
119+
120+
# Set eos_token_id if available
119121
self.eos_token_id = self.tokenizer.eos_token_id
122+
120123
self.parser = PartialLark(
121124
cfg_string,
122125
parser="lalr",
@@ -149,14 +152,20 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
149152
"""
150153

151154
if state.parser_state is None:
152-
return Write(torch.tensor([self.eos_token_id]))
155+
if self.eos_token_id is not None:
156+
return Write(torch.tensor([self.eos_token_id]))
157+
else:
158+
return None # No instruction if eos_token_id is not set
153159

154160
valid_tokens = list(
155-
self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
161+
self.iter_valid_token_ids(state, list(self.tokenizer.vocabulary.values()))
156162
)
157-
if len(valid_tokens) == 1:
163+
if not valid_tokens:
164+
return None # No valid tokens to generate
165+
elif len(valid_tokens) == 1:
158166
return Write(torch.tensor(valid_tokens))
159-
return Generate(torch.tensor(valid_tokens))
167+
else:
168+
return Generate(torch.tensor(valid_tokens))
160169

161170
def iter_valid_token_ids(
162171
self, state: CFGState, candidate_token_ids: list
@@ -177,11 +186,12 @@ def iter_valid_token_ids(
177186
Valid token ids.
178187
"""
179188
if state.parser_state is None:
180-
yield self.eos_token_id
189+
if self.eos_token_id is not None:
190+
yield self.eos_token_id
181191
return
182192

183193
for token_id in candidate_token_ids:
184-
if token_id == self.eos_token_id:
194+
if token_id == self.eos_token_id and self.eos_token_id is not None:
185195
if self.can_terminate_state(state):
186196
yield token_id
187197
else:
@@ -234,20 +244,14 @@ def _get_parser_state_token_applied(
234244
"""
235245
parser_state = copy.copy(state.parser_state) # prevent side effects
236246

237-
# normalize
238-
if state.prev_token is None:
239-
new_token_str = self.tokenizer.decode([token_id])[0]
240-
else:
241-
prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
242-
combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
243-
0
244-
]
245-
new_token_str = combined_token_str[len(prev_token_str) :]
246-
247-
if new_token_str == "":
247+
# Decode the token
248+
token_str = self.tokenizer.decode([token_id])
249+
if not token_str:
248250
raise ValueError("empty next token")
249251

250-
# update parser with new token
252+
new_token_str = token_str[0] # Assuming decode returns a list
253+
254+
# Update parser with new token
251255
parser_state.lexer.state.text += new_token_str
252256
self.parser.parse_from_state(parser_state, is_end=False)
253257

outlines/processors/structured.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,17 @@ def process_logits(
8989
if self._seq_start_idx is None:
9090
self._seq_start_idx = len(input_ids[0])
9191

92-
sequence_states: List[int] = [] # vector of states corresponding to `input_ids`
92+
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`
9393

9494
for seq_ids in input_ids:
9595
gen_ids = seq_ids[self._seq_start_idx :]
9696
curr_state_key = hash(tuple(gen_ids.tolist()))
9797

9898
if curr_state_key not in self._guide_states:
99-
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
99+
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
100+
prev_state = self._guide_states.get(
101+
prev_state_key, self.guide.initial_state
102+
)
100103
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
101104
self._guide_states[curr_state_key] = curr_state
102105

@@ -107,19 +110,26 @@ def process_logits(
107110
allowed_tokens_batch = []
108111
batch_indices = []
109112
for i, guide_state in enumerate(sequence_states):
110-
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
111-
mask.device, non_blocking=True
112-
)
113+
instruction = self.guide.get_next_instruction(guide_state)
114+
if instruction is None:
115+
continue # Skip if no instruction is available
116+
allowed_tokens = instruction.tokens
117+
if allowed_tokens is None:
118+
continue # Skip if no tokens are allowed
119+
allowed_tokens = allowed_tokens.to(mask.device, non_blocking=True)
120+
121+
# Filter out invalid token IDs
122+
allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
113123
allowed_tokens_batch.append(allowed_tokens)
114-
batch_indices.append(
115-
torch.full_like(allowed_tokens, i)
116-
) # Store batch index for each allowed token
124+
batch_indices.append(torch.full_like(allowed_tokens, i))
117125

118-
allowed_tokens_concat = torch.cat(allowed_tokens_batch)
119-
batch_indices_concat = torch.cat(batch_indices)
126+
if allowed_tokens_batch:
127+
allowed_tokens_concat = torch.cat(allowed_tokens_batch)
128+
batch_indices_concat = torch.cat(batch_indices)
120129

121-
mask[batch_indices_concat, allowed_tokens_concat] = False
122-
logits.masked_fill_(mask, float("-inf"))
130+
mask[batch_indices_concat, allowed_tokens_concat] = False
131+
132+
logits = logits.masked_fill(mask, float("-inf"))
123133

124134
return logits
125135

@@ -221,26 +231,34 @@ def process_logits(
221231
if self._seq_start_idx is None:
222232
self._seq_start_idx = len(input_ids[0])
223233

224-
sequence_states: List = [] # vector of states corresponding to `input_ids`
234+
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`
225235

226236
for seq_ids in input_ids:
227237
gen_ids = seq_ids[self._seq_start_idx :]
228238
curr_state_key = hash(tuple(gen_ids.tolist()))
229239

230240
if curr_state_key not in self._guide_states:
231-
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
241+
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
242+
prev_state = self._guide_states.get(
243+
prev_state_key, self.guide.initial_state
244+
)
232245
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
233246
self._guide_states[curr_state_key] = curr_state
234247

235248
sequence_states.append(self._guide_states[curr_state_key])
236249

237250
mask = torch.full_like(logits, -math.inf)
238251
for i, guide_state in enumerate(sequence_states):
239-
first_legal_token = next(
252+
valid_tokens = list(
240253
self.guide.iter_valid_token_ids(
241-
guide_state, torch.argsort(logits[i], descending=True)
254+
guide_state, torch.arange(logits.size(1), device=logits.device)
242255
)
243256
)
244-
mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
257+
if valid_tokens:
258+
# Keep only valid tokens
259+
mask[i, valid_tokens] = logits[i, valid_tokens]
260+
else:
261+
# No valid tokens; generation should stop
262+
mask[i] = logits[i]
245263

246264
return mask

Commit comments: 0