
[Feature] mm and thinking model support structured output #2749

Open · wants to merge 6 commits into base: develop
62 changes: 62 additions & 0 deletions docs/features/structured_outputs.md
@@ -330,3 +330,65 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong
Address: No.1 Century Avenue, Pudong New Area, Shanghai
Height: 468
```

### Offline Inference

Offline inference restricts the model's output format through constraints specified in advance. In `FastDeploy`, constraints are passed via the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, which are used in the same way as in online inference:

```python
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
structural_tag: Optional[str] = None
```

The following example demonstrates how to use offline inference to generate structured JSON:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams
from pydantic import BaseModel
from enum import Enum

class BookType(str, Enum):
    romance = "Romance"
    historical = "Historical"
    adventure = "Adventure"
    mystery = "Mystery"
    dystopian = "Dystopian"

class BookDescription(BaseModel):
    author: str
    title: str
    genre: BookType

# Constrained decoding parameters
guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())

# Sampling parameters
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=6400,
    guided_decoding=guided_decoding_params,
)

# Load model
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Generate a JSON describing a literary work, including author, title and book type.",
    sampling_params=sampling_params,
)

# Output results
for output in outputs:
    print(output.outputs.text)
```

Output:

```
{"author": "George Orwell", "title": "1984", "genre": "Dystopian"}
```
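
Other constraint types follow the same pattern. As a minimal sketch, a `choice` constraint limits the output to one of a fixed set of strings; the prompt and label set below are only illustrative:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Restrict the generated text to one of the listed strings
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative", "Neutral"])

sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=64,
    guided_decoding=guided_decoding_params,
)

llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Classify the sentiment of this review: 'The plot was engaging from start to finish.'",
    sampling_params=sampling_params,
)

for output in outputs:
    print(output.outputs.text)  # expected to be one of "Positive", "Negative", "Neutral"
```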
64 changes: 64 additions & 0 deletions docs/zh/features/structured_outputs.md
@@ -330,3 +330,67 @@ ParsedChatCompletionMessage[Info](content='{"addr": "上海市浦东新区世纪
地址: 上海市浦东新区世纪大道1号
高度: 468
```

### Offline Inference

Offline inference restricts the model's output format through constraints specified in advance. In `FastDeploy`, constraints are passed via the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, which are used in the same way as in online inference:

```python
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
structural_tag: Optional[str] = None
```

The following example demonstrates how to use offline inference to generate structured JSON:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams
from pydantic import BaseModel
from enum import Enum

class BookType(str, Enum):
    romance = "Romance"
    historical = "Historical"
    adventure = "Adventure"
    mystery = "Mystery"
    dystopian = "Dystopian"

class BookDescription(BaseModel):
    author: str
    title: str
    genre: BookType

# Constrained decoding parameters
guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())

# Sampling parameters
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=6400,
    guided_decoding=guided_decoding_params,
)

# Load model
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。",
    sampling_params=sampling_params,
)

# Output results
for output in outputs:
    print(output.outputs.text)
```

Output:

```
{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"}
```
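
The remaining constraint types are used the same way; for example, a `regex` constraint forces the output to match a regular expression. A minimal sketch, where the pattern and prompt are only illustrative:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Force the generated text to match a YYYY-MM-DD date pattern
guided_decoding_params = GuidedDecodingParams(regex=r"\d{4}-\d{2}-\d{2}")

sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=32,
    guided_decoding=guided_decoding_params,
)

llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="On what date was the People's Republic of China founded? Answer with a date only.",
    sampling_params=sampling_params,
)

for output in outputs:
    print(output.outputs.text)  # e.g. "1949-10-01"
```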
1 change: 1 addition & 0 deletions fastdeploy/config.py
@@ -146,6 +146,7 @@ class MoEConfig:
im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
)
reasoning_parser: Optional[str] = None


@dataclass
14 changes: 5 additions & 9 deletions fastdeploy/engine/config.py
@@ -733,7 +733,8 @@ def postprocess(self):
self.max_model_len // self.cache_config.block_size)

if self.guided_decoding_backend == "auto":
if self.enable_mm:
if current_platform.is_xpu() or self.speculative_config.method is not None:
llm_logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
self.guided_decoding_backend = "off"
else:
self.guided_decoding_backend = "xgrammar"
@@ -795,10 +796,9 @@ def check(self):
f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

if self.guided_decoding_backend != "off":
# TODO: mm support guided_decoding
assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"

# TODO: speculative decoding support guided_decoding
assert self.speculative_config.method is None, \
"speculative decoding currently do not support guided_decoding"

# TODO: xpu support guided_decoding
assert not current_platform.is_xpu(
@@ -826,11 +826,7 @@ def print(self, file=None):
if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items():
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif (k == "cache_config" or
k == "model_config" or
k == "scheduler_config" or
k == "parallel_config" or
k == "commit_config"):
elif k in ["cache_config", "model_config", "commit_config", "scheduler_config", "parallel_config", "speculative_config"]:
v.print()
else:
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
18 changes: 17 additions & 1 deletion fastdeploy/engine/engine.py
@@ -359,6 +359,13 @@ def _insert_zmq_task_to_scheduler(self):
llm_logger.debug(f"Receive request: {request}")

err_msg = None
if ((request.guided_json is not None
or request.guided_regex is not None
or request.structural_tag is not None
or request.guided_grammar is not None) and self.guided_decoding_checker is None):
err_msg = "guided_backend is None, use --guided-decoding-backend to " \
"specify the backend at server startup."

if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(
request)
@@ -447,6 +454,14 @@ def add_requests(self, task, sampling_params=None, **kwargs):
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)

if ((request.guided_json is not None
or request.guided_regex is not None
or request.structural_tag is not None
or request.guided_grammar is not None) and self.guided_decoding_checker is None):
err_msg = "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
llm_logger.error(err_msg)
raise EngineError(err_msg, error_code=400)

if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(
request)
@@ -1030,7 +1045,8 @@ def _start_worker_service(self):
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --reasoning_parser {self.cfg.reasoning_parser}")

worker_append_flag = {
"enable_expert_parallel":
49 changes: 47 additions & 2 deletions fastdeploy/engine/sampling_params.py
@@ -92,6 +92,7 @@ class SamplingParams:
min_tokens: int = 1
logprobs: Optional[int] = None
bad_words: Optional[List[str]] = None
guided_decoding: Optional[GuidedDecodingParams] = None

@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
@@ -121,7 +122,8 @@ def from_optional(cls,
reasoning_max_tokens=None,
min_tokens=1,
logprobs=None,
bad_words=None) -> "SamplingParams":
bad_words=None,
guided_decoding=None) -> "SamplingParams":
"""Create instance from command line arguments"""
return cls(n=1 if n is None else n,
best_of=best_of,
@@ -141,7 +143,8 @@ def from_optional(cls,
reasoning_max_tokens=reasoning_max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words)
bad_words=bad_words,
guided_decoding=guided_decoding)

def __post_init__(self):
if self.seed is None:
@@ -224,3 +227,45 @@ class BeamSearchParams:
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False


@dataclass
class GuidedDecodingParams:
"""Guided decoding parameters for text generation."""
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
structural_tag: Optional[str] = None

def to_dict(self):
"""convert to dict"""
key_dict = {
"guided_json": self.json,
"guided_regex": self.regex,
"guided_choice": self.choice,
"guided_grammar": self.grammar,
"structural_tag": self.structural_tag,
"guided_json_object": self.json_object,
}

guided_dict = {}
for key, value in key_dict.items():
if value is not None:
guided_dict[key] = value
return guided_dict

def __post_init__(self):
"""Verify the arguments."""
guided_count = sum([
self.json is not None, self.regex is not None, self.choice
is not None, self.grammar is not None, self.json_object
is not None, self.structural_tag is not None
])

if guided_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
)
3 changes: 3 additions & 0 deletions fastdeploy/entrypoints/llm.py
@@ -273,6 +273,9 @@ def _add_request(
if chat_template_kwargs is not None:
enable_thinking = chat_template_kwargs.get(
"enable_thinking", None)
if current_sampling_params.guided_decoding is not None:
guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
tasks.update(guided_decoding_dict)
self.llm_engine.add_requests(tasks,
current_sampling_params,
enable_thinking=enable_thinking)
6 changes: 3 additions & 3 deletions fastdeploy/input/ernie_processor.py
@@ -100,7 +100,7 @@ def process_request(self, request, max_model_len=None, **kwargs):

if request.prompt_token_ids is None or len(
request.prompt_token_ids) == 0:
system = request.get("system")
# system = request.get("system")
Copilot AI Jul 12, 2025

[nitpick] This commented-out dead code can be removed to clean up the implementation and avoid confusion.

Suggested change: remove the line `# system = request.get("system")`

if request.prompt is None and request.messages is None:
raise ValueError(
f"The request should have `input_ids`, `text` or `messages`: {request}.")
Expand Down Expand Up @@ -149,7 +149,7 @@ def process_request_dict(self, request, max_model_len=None):
request['stop_token_ids'] = stop_seqs
request['stop_seqs_len'] = stop_seqs_len

system = request.get("system")
# system = request.get("system")
# 处理prompt_token_ids
if not request.get('prompt_token_ids'):
if request.get('prompt') is None and request.get(
Expand Down Expand Up @@ -213,7 +213,7 @@ def process_response(self, response_dict, **kwargs):
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
data_processor_logger.info(f"req_id:{req_id}, token ids: {token_ids}")
if response_dict.outputs.text == "" and \
response_dict.outputs.reasoning_content == "":
return None
4 changes: 3 additions & 1 deletion fastdeploy/model_executor/guided_decoding/__init__.py
@@ -15,8 +15,10 @@
"""

# from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
BackendBase, BaseChecker, LogitsProcessorBase)

__all__ = ['get_guided_backend', 'schema_checker']
__all__ = ['get_guided_backend', 'schema_checker', 'LogitsProcessorBase', 'BackendBase', 'BaseChecker']


def get_guided_backend(