
Commit 27a001c

mm support structured output
1 parent 2c36074 commit 27a001c

15 files changed: +347, -68 lines

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ class MoEConfig:
     im_patch_id = (
         100295  # multimodality, TODO(liuyuanle): read from config.json
     )
+    reasoning_parser: Optional[str] = None


 @dataclass

fastdeploy/engine/config.py

Lines changed: 5 additions & 9 deletions
@@ -733,7 +733,8 @@ def postprocess(self):
             self.max_model_len // self.cache_config.block_size)

         if self.guided_decoding_backend == "auto":
-            if self.enable_mm:
+            if current_platform.is_xpu() or self.speculative_config.method is not None:
+                llm_logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
                 self.guided_decoding_backend = "off"
             else:
                 self.guided_decoding_backend = "xgrammar"
@@ -795,10 +796,9 @@ def check(self):
                 f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

         if self.guided_decoding_backend != "off":
-            # TODO: mm support guided_decoding
-            assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"
-
             # TODO: speculative decoding support guided_decoding
+            assert self.speculative_config.method is None, \
+                "speculative decoding currently do not support guided_decoding"

             # TODO: xpu support guided_decoding
             assert not current_platform.is_xpu(
@@ -826,11 +826,7 @@ def print(self, file=None):
             if k == "generation_config" and v is not None:
                 for gck, gcv in v.to_dict().items():
                     llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
-            elif (k == "cache_config" or
-                  k == "model_config" or
-                  k == "scheduler_config" or
-                  k == "parallel_config" or
-                  k == "commit_config"):
+            elif k in ["cache_config", "model_config", "commit_config", "scheduler_config", "parallel_config", "speculative_config"]:
                 v.print()
             else:
                 llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))

fastdeploy/engine/engine.py

Lines changed: 17 additions & 1 deletion
@@ -359,6 +359,13 @@ def _insert_zmq_task_to_scheduler(self):
                 llm_logger.debug(f"Receive request: {request}")

                 err_msg = None
+                if ((request.guided_json is not None
+                        or request.guided_regex is not None
+                        or request.structural_tag is not None
+                        or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+                    err_msg = "guided_backend is None, use --guided-decoding-backend to " \
+                              "specify the backend at server startup."
+
                 if self.guided_decoding_checker is not None:
                     request, err_msg = self.guided_decoding_checker.schema_format(
                         request)
@@ -447,6 +454,14 @@ def add_requests(self, task, sampling_params=None, **kwargs):
             llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=400)

+        if ((request.guided_json is not None
+                or request.guided_regex is not None
+                or request.structural_tag is not None
+                or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+            err_msg = "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
+            llm_logger.error(err_msg)
+            raise EngineError(err_msg, error_code=400)
+
         if self.guided_decoding_checker is not None:
             request, err_msg = self.guided_decoding_checker.schema_format(
                 request)
@@ -1030,7 +1045,8 @@ def _start_worker_service(self):
             f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
             f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-            f" --load_strategy {self.cfg.model_config.load_strategy}")
+            f" --load_strategy {self.cfg.model_config.load_strategy}"
+            f" --reasoning_parser {self.cfg.reasoning_parser}")

         worker_append_flag = {
             "enable_expert_parallel":

fastdeploy/engine/sampling_params.py

Lines changed: 50 additions & 2 deletions
@@ -92,6 +92,7 @@ class SamplingParams:
     min_tokens: int = 1
     logprobs: Optional[int] = None
     bad_words: Optional[List[str]] = None
+    guided_decoding: Optional[GuidedDecodingParams] = None

     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
@@ -121,7 +122,8 @@ def from_optional(cls,
                       reasoning_max_tokens=None,
                       min_tokens=1,
                       logprobs=None,
-                      bad_words=None) -> "SamplingParams":
+                      bad_words=None,
+                      guided_decoding=None) -> "SamplingParams":
         """Create instance from command line arguments"""
         return cls(n=1 if n is None else n,
                    best_of=best_of,
@@ -141,7 +143,8 @@ def from_optional(cls,
                    reasoning_max_tokens=reasoning_max_tokens,
                    min_tokens=min_tokens,
                    logprobs=logprobs,
-                   bad_words=bad_words)
+                   bad_words=bad_words,
+                   guided_decoding=guided_decoding)

     def __post_init__(self):
         if self.seed is None:
@@ -207,6 +210,9 @@ def _verify_args(self) -> None:
             raise ValueError("seed must be in [0, 922337203685477580], got "
                              f"{self.seed}.")

+        if self.guided_decoding is not None:
+            self.guided_decoding._verify_args()
+
     def update_from_tokenizer(self, tokenizer):
         """
         # TODO: Implement stop tokens and bad words support
@@ -224,3 +230,45 @@ class BeamSearchParams:
     temperature: float = 0.0
     length_penalty: float = 1.0
     include_stop_str_in_output: bool = False
+
+
+@dataclass
+class GuidedDecodingParams:
+    """Guided decoding parameters for text generation."""
+    json: Optional[Union[str, dict]] = None
+    regex: Optional[str] = None
+    choice: Optional[List[str]] = None
+    grammar: Optional[str] = None
+    json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None
+
+    def to_dict(self):
+        """convert to dict"""
+        key_dict = {
+            "guided_json": self.json,
+            "guided_regex": self.regex,
+            "guided_choice": self.choice,
+            "guided_grammar": self.grammar,
+            "structural_tag": self.structural_tag,
+            "guided_json_object": self.json_object,
+        }
+
+        guided_dict = {}
+        for key, value in key_dict.items():
+            if value is not None:
+                guided_dict[key] = value
+        return guided_dict
+
+    def _verify_args(self):
+        """Verify the arguments."""
+        guided_count = sum([
+            self.json is not None, self.regex is not None, self.choice
+            is not None, self.grammar is not None, self.json_object
+            is not None, self.structural_tag is not None
+        ])
+
+        if guided_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
+            )
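A minimal sketch of the new GuidedDecodingParams in use (the JSON schema is illustrative, not from this commit). Exactly one guided mode may be set; to_dict() maps the populated field onto the request-level guided_* keys and drops everything that is None:

    from fastdeploy.engine.sampling_params import GuidedDecodingParams

    guided = GuidedDecodingParams(json={
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    })
    print(guided.to_dict())
    # {'guided_json': {'type': 'object', ...}}  -- only non-None fields survive

    # Setting two modes at once is rejected by _verify_args():
    bad = GuidedDecodingParams(regex=r"\d+", choice=["yes", "no"])
    bad._verify_args()  # ValueError: only one kind of guided decoding may be used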

fastdeploy/entrypoints/llm.py

Lines changed: 3 additions & 0 deletions
@@ -273,6 +273,9 @@ def _add_request(
             if chat_template_kwargs is not None:
                 enable_thinking = chat_template_kwargs.get(
                     "enable_thinking", None)
+            if current_sampling_params.guided_decoding is not None:
+                guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
+                tasks.update(guided_decoding_dict)
             self.llm_engine.add_requests(tasks,
                                          current_sampling_params,
                                          enable_thinking=enable_thinking)
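For the offline entrypoint, this hunk means a guided_decoding field on SamplingParams is folded into the task dict before add_requests() runs. A hedged end-to-end sketch follows; the model path, the LLM constructor arguments (including the reasoning parser name), and the generate() call shape are assumptions, only the guided_decoding plumbing comes from this commit:

    from fastdeploy.engine.sampling_params import GuidedDecodingParams, SamplingParams
    from fastdeploy.entrypoints.llm import LLM

    # Constructor arguments are assumptions for illustration only.
    llm = LLM(model="path/to/ERNIE-4.5-VL",
              guided_decoding_backend="xgrammar",
              reasoning_parser="ernie-45-vl")

    params = SamplingParams(
        temperature=0.0,
        guided_decoding=GuidedDecodingParams(choice=["cat", "dog", "bird"]))

    # _add_request() calls params.guided_decoding.to_dict() and merges
    # {'guided_choice': ['cat', 'dog', 'bird']} into the task before scheduling.
    outputs = llm.generate(prompts=["What animal is in the image?"],
                           sampling_params=params)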

fastdeploy/model_executor/guided_decoding/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -15,8 +15,10 @@
 """

 # from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+    BackendBase, BaseChecker, LogitsProcessorBase)

-__all__ = ['get_guided_backend', 'schema_checker']
+__all__ = ['get_guided_backend', 'schema_checker', 'LogitsProcessorBase', 'BackendBase', 'BaseChecker']


 def get_guided_backend(

fastdeploy/model_executor/guided_decoding/base_guided_decoding.py

Lines changed: 45 additions & 16 deletions
@@ -19,6 +19,7 @@

 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
+from fastdeploy.reasoning import ReasoningParserManager
 from fastdeploy.utils import llm_logger


@@ -34,8 +35,9 @@ class LogitsProcessorBase:
         None (all state should be managed by subclasses)
     """

-    def __init__(self):
-        pass
+    def __init__(self, enable_reasoning):
+        self.reasoning_ended = False
+        self.enable_reasoning = enable_reasoning

     def fill_token_bitmask(self, token_bitmask, idx):
         """
@@ -136,8 +138,13 @@ def __init__(self, fd_config: FDConfig):
         self.fd_config = fd_config
         self.executor = ThreadPoolExecutor()
         self.max_cache_size = 2048
+        self.reasoning_parser = None

         self.hf_tokenizer = self._get_tokenizer_hf()
+        if self.fd_config.model_config.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
+                self.fd_config.model_config.reasoning_parser)
+            self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)

     def _create_processor(self):
         """
@@ -148,71 +155,89 @@ def _create_processor(self):
         """
         raise NotImplementedError()

-    def _json_processor(self, schemata):
+    def _json_processor(self, schemata, enable_thinking=False):
         """
         Process JSON schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _regex_processor(self, schemata):
+    def _regex_processor(self, schemata, enable_thinking=False):
         """
         Process regular expression schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _grammar_processor(self, schemata):
+    def _grammar_processor(self, schemata, enable_thinking=False):
         """
         Process grammar schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _structural_tag_processor(self, schemata):
+    def _structural_tag_processor(self, schemata, enable_thinking=False):
         """
         Process structural tag schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _unsupported_processor_type(self, key_type, schemata):
+    def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
         """
         Process unsupported type.

         Args:
             key_type (str): The key type string.
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
         """
         raise Exception(f"Unsupported processor type {key_type}.")

+    def get_reasoning_parser(self):
+        """
+        Get reasoning parser object.
+
+        Returns:
+            ReasoningParser: Reasoning parser object or None
+        """
+        return self.reasoning_parser
+
     def _init_logits_processor(
-            self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> LogitsProcessorBase:
         """
         init logits processor by type and schemata.

         Args:
             schemata_key (tuple[str, str]): Tuple containing processor type and schema string
+            enable_thinking (bool): Whether to enable thinking step

         Returns:
             LogitsProcessorBase: Initialized logits processor instance
@@ -222,20 +247,21 @@ def _init_logits_processor(
         """
         key_type, schemata = schemata_key
         if key_type == "json":
-            return self._json_processor(schemata)
+            return self._json_processor(schemata, enable_thinking)
         elif key_type == "regex":
-            return self._regex_processor(schemata)
+            return self._regex_processor(schemata, enable_thinking)
         elif key_type == "grammar":
-            return self._grammar_processor(schemata)
+            return self._grammar_processor(schemata, enable_thinking)
         elif key_type == "structural_tag":
-            return self._structural_tag_processor(schemata)
+            return self._structural_tag_processor(schemata, enable_thinking)
         else:
             llm_logger.error(f"Unsupported processor type {key_type}.")
             return None

     def get_logits_processor(
             self,
-            schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+            schemata_key: tuple[str, str],
+            enable_thinking: bool = False) -> tuple[LogitsProcessorBase, bool]:
         """
         get logits processor by key from cache or create new one.
@@ -249,8 +275,10 @@ def get_logits_processor(
         """
         value = self.cache.get(schemata_key, None)
         if value:
-            return value.copy(), True
-        value = self.executor.submit(self._init_logits_processor, schemata_key)
+            value_copy = value.copy()
+            value_copy.enable_reasoning = enable_thinking
+            return value_copy, True
+        value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
         return value, False

     def _get_tokenizer_hf(self):
@@ -269,7 +297,8 @@ def _get_tokenizer_hf(self):
         try:
             architectures = self.fd_config.model_config.architectures
             if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+                and "Ernie4_5_ForCausalLM" not in architectures \
+                and "Ernie4_5_VLMoeForConditionalGeneration" not in architectures:

                 from transformers import AutoTokenizer, PreTrainedTokenizerFast
                 tokenizer = AutoTokenizer.from_pretrained(
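The base-class changes define the contract a concrete backend has to follow: each _*_processor hook now receives enable_thinking, and cached processors get their enable_reasoning flag overwritten when reused from get_logits_processor(). A schematic subclass, not part of this commit (compile_json_schema is a hypothetical helper), showing how that argument is expected to be threaded through:

    from fastdeploy.model_executor.guided_decoding import BackendBase, LogitsProcessorBase


    class MyJsonProcessor(LogitsProcessorBase):
        """Toy processor: stores a compiled schema and the reasoning flag."""

        def __init__(self, compiled_schema, enable_reasoning=False):
            super().__init__(enable_reasoning=enable_reasoning)
            self.compiled_schema = compiled_schema
            # When a reasoning parser is configured and enable_reasoning is True,
            # constrained decoding should only start once reasoning_ended is set.


    class MyBackend(BackendBase):
        """Toy backend: only the JSON hook is implemented."""

        def _json_processor(self, schemata, enable_thinking=False):
            compiled = compile_json_schema(schemata)  # hypothetical helper
            return MyJsonProcessor(compiled, enable_reasoning=enable_thinking)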
