
Commit d0a9be3

mm support structured output
1 parent a56d64e commit d0a9be3

File tree

15 files changed: +347 additions, −70 deletions


fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ class MoEConfig:
     im_patch_id = (
         100295 # multimodality, TODO(liuyuanle): read from config.json
     )
+    reasoning_parser: Optional[str] = None


 @dataclass

fastdeploy/engine/config.py

Lines changed: 5 additions & 5 deletions
@@ -656,7 +656,8 @@ def postprocess(self):
             self.max_model_len // self.cache_config.block_size)

         if self.guided_decoding_backend == "auto":
-            if self.enable_mm:
+            if current_platform.is_xpu() or self.speculative_config.method is not None:
+                llm_logger.warning("Guided decoding is not supported with speculative decoding or on XPU; setting it off.")
                 self.guided_decoding_backend = "off"
             else:
                 self.guided_decoding_backend = "xgrammar"
@@ -718,10 +719,9 @@ def check(self):
                 f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

             if self.guided_decoding_backend != "off":
-                # TODO: mm support guided_decoding
-                assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"
-
                 # TODO: speculative decoding support guided_decoding
+                assert self.speculative_config.method is None, \
+                    "speculative decoding does not currently support guided_decoding"

                 # TODO: xpu support guided_decoding
                 assert not current_platform.is_xpu(
@@ -749,7 +749,7 @@ def print(self, file=None):
             if k == "generation_config" and v is not None:
                 for gck, gcv in v.to_dict().items():
                     llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
-            elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
+            elif k in ["cache_config", "model_config", "scheduler_config", "parallel_config", "speculative_config"]:
                 v.print()
             else:
                 llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
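For orientation, the effect of the new "auto" resolution in postprocess() can be summarized as a small pure function. This is an illustrative sketch only; resolve_guided_backend is a made-up helper, not part of the commit:

# Sketch of how "auto" now resolves after this change (hypothetical helper).
def resolve_guided_backend(requested: str, is_xpu: bool, speculative_method) -> str:
    if requested != "auto":
        return requested  # an explicit "xgrammar" / "off" is honored as-is
    if is_xpu or speculative_method is not None:
        # XPU and speculative decoding still cannot use guided decoding.
        return "off"
    # Multimodal models no longer force the backend off.
    return "xgrammar"

assert resolve_guided_backend("auto", is_xpu=False, speculative_method=None) == "xgrammar"
assert resolve_guided_backend("auto", is_xpu=True, speculative_method=None) == "off"
assert resolve_guided_backend("auto", is_xpu=False, speculative_method="mtp") == "off"  # "mtp" is illustrative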

fastdeploy/engine/engine.py

Lines changed: 19 additions & 3 deletions
@@ -385,6 +385,13 @@ def _insert_zmq_task_to_scheduler(self):
                 llm_logger.debug(f"Receive request: {request}")

                 err_msg = None
+                if ((request.guided_json is not None
+                        or request.guided_regex is not None
+                        or request.structural_tag is not None
+                        or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+                    err_msg = "guided_backend is None, use --guided-decoding-backend to " \
+                              "specify the backend at server startup."
+
                 if self.guided_decoding_checker is not None:
                     request, err_msg = self.guided_decoding_checker.schema_format(
                         request)
@@ -473,6 +480,14 @@ def add_requests(self, task, sampling_params=None, **kwargs):
             llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=400)

+        if ((request.guided_json is not None
+                or request.guided_regex is not None
+                or request.structural_tag is not None
+                or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+            err_msg = "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
+            llm_logger.error(err_msg)
+            raise EngineError(err_msg, error_code=400)
+
         if self.guided_decoding_checker is not None:
             request, err_msg = self.guided_decoding_checker.schema_format(
                 request)
@@ -1021,8 +1036,8 @@ def _start_worker_service(self):
         py_script = os.path.join(current_dir_path, worker_path)

         ori_vocab_size = (
-            len(self.data_processor.tokenizer.sp_model)
-            if hasattr(self.data_processor.tokenizer, 'sp_model')
+            len(self.data_processor.tokenizer.sp_model)
+            if hasattr(self.data_processor.tokenizer, 'sp_model')
             else len(self.data_processor.tokenizer.vocab)
         )

@@ -1053,7 +1068,8 @@ def _start_worker_service(self):
             f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
             f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-            f" --load_strategy {self.cfg.model_config.load_strategy}")
+            f" --load_strategy {self.cfg.model_config.load_strategy}"
+            f" --reasoning_parser {self.cfg.reasoning_parser}")

         worker_append_flag = {
             "enable_expert_parallel":

fastdeploy/engine/sampling_params.py

Lines changed: 50 additions & 2 deletions
@@ -90,6 +90,7 @@ class SamplingParams:
     min_tokens: int = 1
     logprobs: Optional[int] = None
     bad_words: Optional[List[str]] = None
+    guided_decoding: Optional[GuidedDecodingParams] = None

     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
@@ -118,7 +119,8 @@ def from_optional(cls,
                       reasoning_max_tokens=None,
                       min_tokens=1,
                       logprobs=None,
-                      bad_words=None) -> "SamplingParams":
+                      bad_words=None,
+                      guided_decoding=None) -> "SamplingParams":
         """Create instance from command line arguments"""
         return cls(n=1 if n is None else n,
                    best_of=best_of,
@@ -137,7 +139,8 @@ def from_optional(cls,
                    reasoning_max_tokens=reasoning_max_tokens,
                    min_tokens=min_tokens,
                    logprobs=logprobs,
-                   bad_words=bad_words)
+                   bad_words=bad_words,
+                   guided_decoding=guided_decoding)

     def __post_init__(self):
         if self.seed is None:
@@ -193,6 +196,9 @@ def _verify_args(self) -> None:
             raise ValueError("seed must be in [0, 922337203685477580], got "
                              f"{self.seed}.")

+        if self.guided_decoding is not None:
+            self.guided_decoding._verify_args()
+
     def update_from_tokenizer(self, tokenizer):
         """
         # TODO: Implement stop tokens and bad words support
@@ -210,3 +216,45 @@ class BeamSearchParams:
     temperature: float = 0.0
     length_penalty: float = 1.0
     include_stop_str_in_output: bool = False
+
+
+@dataclass
+class GuidedDecodingParams:
+    """Guided decoding parameters for text generation."""
+    json: Optional[Union[str, dict]] = None
+    regex: Optional[str] = None
+    choice: Optional[List[str]] = None
+    grammar: Optional[str] = None
+    json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None
+
+    def to_dict(self):
+        """convert to dict"""
+        key_dict = {
+            "guided_json": self.json,
+            "guided_regex": self.regex,
+            "guided_choice": self.choice,
+            "guided_grammar": self.grammar,
+            "structural_tag": self.structural_tag,
+            "guided_json_object": self.json_object,
+        }
+
+        guided_dict = {}
+        for key, value in key_dict.items():
+            if value is not None:
+                guided_dict[key] = value
+        return guided_dict
+
+    def _verify_args(self):
+        """Verify the arguments."""
+        guided_count = sum([
+            self.json is not None, self.regex is not None,
+            self.choice is not None, self.grammar is not None,
+            self.json_object is not None, self.structural_tag is not None
+        ])
+
+        if guided_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
+            )
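Taken together, GuidedDecodingParams hangs off SamplingParams.guided_decoding, enforces that at most one constraint kind is set, and serializes to the guided_* keys the engine consumes. A minimal usage sketch (SamplingParams fields other than those shown are assumed to have defaults):

from fastdeploy.engine.sampling_params import GuidedDecodingParams, SamplingParams

# Constrain output to a small JSON object.
guided = GuidedDecodingParams(json={
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
})
params = SamplingParams(temperature=0.2, guided_decoding=guided)

print(guided.to_dict())  # {'guided_json': {...}} -- only non-None keys survive

# Mixing constraint kinds is rejected by _verify_args():
# GuidedDecodingParams(json={"type": "object"}, regex=r"\d+")._verify_args()  ->  ValueError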

fastdeploy/entrypoints/llm.py

Lines changed: 3 additions & 0 deletions
@@ -258,6 +258,9 @@ def _add_request(
             if chat_template_kwargs is not None:
                 enable_thinking = chat_template_kwargs.get(
                     "enable_thinking", None)
+            if current_sampling_params.guided_decoding is not None:
+                guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
+                tasks.update(guided_decoding_dict)
             self.llm_engine.add_requests(tasks,
                                          current_sampling_params,
                                          enable_thinking=enable_thinking)
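With _add_request merging the guided_* keys into the task dict, offline use can drive structured output purely through SamplingParams. A hedged end-to-end sketch; the model name, reasoning parser name, and the exact LLM constructor/generate signature are assumptions, not taken from this commit:

from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import GuidedDecodingParams, SamplingParams

# Assumed multimodal checkpoint and parser name, for illustration only.
llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle",
          guided_decoding_backend="xgrammar",
          reasoning_parser="ernie-45-vl")

params = SamplingParams(
    guided_decoding=GuidedDecodingParams(choice=["cat", "dog", "bird"]))

outputs = llm.generate("What animal is in the picture?", sampling_params=params)
for output in outputs:
    print(output)  # each result is constrained to one of the three choices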

fastdeploy/model_executor/guided_decoding/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -15,8 +15,10 @@
 """

 # from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+    BackendBase, BaseChecker, LogitsProcessorBase)

-__all__ = ['get_guided_backend', 'schema_checker']
+__all__ = ['get_guided_backend', 'schema_checker', 'LogitsProcessorBase', 'BackendBase', 'BaseChecker']


 def get_guided_backend(

fastdeploy/model_executor/guided_decoding/base_guided_decoding.py

Lines changed: 45 additions & 16 deletions
@@ -19,6 +19,7 @@

 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
+from fastdeploy.reasoning import ReasoningParserManager
 from fastdeploy.utils import llm_logger


@@ -34,8 +35,9 @@ class LogitsProcessorBase:
         None (all state should be managed by subclasses)
     """

-    def __init__(self):
-        pass
+    def __init__(self, enable_reasoning):
+        self.reasoning_ended = False
+        self.enable_reasoning = enable_reasoning

     def fill_token_bitmask(self, token_bitmask, idx):
         """
@@ -136,8 +138,13 @@ def __init__(self, fd_config: FDConfig):
         self.fd_config = fd_config
         self.executor = ThreadPoolExecutor()
         self.max_cache_size = 2048
+        self.reasoning_parser = None

         self.hf_tokenizer = self._get_tokenizer_hf()
+        if self.fd_config.model_config.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
+                self.fd_config.model_config.reasoning_parser)
+            self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)

     def _create_processor(self):
         """
@@ -148,71 +155,89 @@ def _create_processor(self):
         """
         raise NotImplementedError()

-    def _json_processor(self, schemata):
+    def _json_processor(self, schemata, enable_thinking=False):
         """
         Process JSON schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _regex_processor(self, schemata):
+    def _regex_processor(self, schemata, enable_thinking=False):
         """
         Process regular expression schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _grammar_processor(self, schemata):
+    def _grammar_processor(self, schemata, enable_thinking=False):
         """
         Process grammar schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _structural_tag_processor(self, schemata):
+    def _structural_tag_processor(self, schemata, enable_thinking=False):
         """
         Process structural tag schemata.

         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError()

-    def _unsupported_processor_type(self, key_type, schemata):
+    def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
         """
         Process unsupported type.

         Args:
             key_type (str): The key type string.
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
         """
         raise Exception(f"Unsupported processor type {key_type}.")

+    def get_reasoning_parser(self):
+        """
+        Get reasoning parser object.
+
+        Returns:
+            ReasoningParser: Reasoning parser object or None
+        """
+        return self.reasoning_parser
+
     def _init_logits_processor(
-            self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> LogitsProcessorBase:
         """
         init logits processor by type and schemata.

         Args:
             schemata_key (tuple[str, str]): Tuple containing processor type and schema string
+            enable_thinking (bool): Whether to enable thinking step

         Returns:
             LogitsProcessorBase: Initialized logits processor instance
@@ -222,20 +247,21 @@ def _init_logits_processor(
         """
         key_type, schemata = schemata_key
         if key_type == "json":
-            return self._json_processor(schemata)
+            return self._json_processor(schemata, enable_thinking)
         elif key_type == "regex":
-            return self._regex_processor(schemata)
+            return self._regex_processor(schemata, enable_thinking)
         elif key_type == "grammar":
-            return self._grammar_processor(schemata)
+            return self._grammar_processor(schemata, enable_thinking)
         elif key_type == "structural_tag":
-            return self._structural_tag_processor(schemata)
+            return self._structural_tag_processor(schemata, enable_thinking)
         else:
             llm_logger.error(f"Unsupported processor type {key_type}.")
             return None

     def get_logits_processor(
             self,
-            schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+            schemata_key: tuple[str, str],
+            enable_thinking: bool = False) -> tuple[LogitsProcessorBase, bool]:
         """
         get logits processor by key from cache or create new one.

@@ -249,8 +275,10 @@ def get_logits_processor(
         """
         value = self.cache.get(schemata_key, None)
         if value:
-            return value.copy(), True
-        value = self.executor.submit(self._init_logits_processor, schemata_key)
+            value_copy = value.copy()
+            value_copy.enable_reasoning = enable_thinking
+            return value_copy, True
+        value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
         return value, False

     def _get_tokenizer_hf(self):
@@ -269,7 +297,8 @@ def _get_tokenizer_hf(self):
         try:
             architectures = self.fd_config.model_config.architectures
             if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+                and "Ernie4_5_ForCausalLM" not in architectures \
+                and "Ernie4_5_VLMoeForConditionalGeneration" not in architectures:

                 from transformers import AutoTokenizer, PreTrainedTokenizerFast
                 tokenizer = AutoTokenizer.from_pretrained(
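The enable_reasoning / reasoning_ended pair added to LogitsProcessorBase lets a backend keep the grammar mask disengaged while the model is still inside its reasoning ("thinking") span and only start constraining tokens once the reasoning parser reports that the span has closed. A schematic subclass, purely illustrative (the class, the matcher object, and the is_reasoning_end call are assumptions; the real implementation lives in the xgrammar backend):

from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase

class SketchLogitsProcessor(LogitsProcessorBase):
    """Illustrative only: gate a grammar matcher behind the reasoning span."""

    def __init__(self, matcher, reasoning_parser, enable_reasoning=False):
        super().__init__(enable_reasoning=enable_reasoning)
        self.matcher = matcher                    # stands in for an xgrammar-style matcher
        self.reasoning_parser = reasoning_parser  # assumed to expose is_reasoning_end()

    def accept_token(self, token_id):
        # While thinking is enabled and not finished, pass tokens through unconstrained.
        if self.enable_reasoning and not self.reasoning_ended:
            self.reasoning_ended = self.reasoning_parser.is_reasoning_end([token_id])
            return
        self.matcher.accept_token(token_id)

    def fill_token_bitmask(self, token_bitmask, idx):
        # Apply the grammar mask only once reasoning has ended (or was never enabled).
        if not self.enable_reasoning or self.reasoning_ended:
            self.matcher.fill_next_token_bitmask(token_bitmask, idx)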

0 commit comments
