From 27a001ce51f164b6e20137246c64248db5084477 Mon Sep 17 00:00:00 2001 From: kevin Date: Mon, 7 Jul 2025 11:27:40 +0800 Subject: [PATCH 1/5] mm support structured output --- fastdeploy/config.py | 1 + fastdeploy/engine/config.py | 14 +-- fastdeploy/engine/engine.py | 18 ++- fastdeploy/engine/sampling_params.py | 52 ++++++++- fastdeploy/entrypoints/llm.py | 3 + .../guided_decoding/__init__.py | 4 +- .../guided_decoding/base_guided_decoding.py | 61 +++++++--- .../guided_decoding/xgrammar_backend.py | 38 +++--- .../model_executor/layers/sample/sampler.py | 51 +++++++- .../reasoning/ernie_vl_reasoning_parsers.py | 21 +++- .../reasoning/qwen3_reasoning_parsers.py | 17 +++ fastdeploy/worker/gpu_model_runner.py | 5 +- fastdeploy/worker/vl_gpu_model_runner.py | 109 ++++++++++++++++-- fastdeploy/worker/vl_worker_process.py | 7 +- fastdeploy/worker/worker_process.py | 14 +++ 15 files changed, 347 insertions(+), 68 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 3d7b0caadb..cd7d630d78 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -146,6 +146,7 @@ class MoEConfig: im_patch_id = ( 100295 # multimodality, TODO(liuyuanle): read from config.json ) + reasoning_parser: Optional[str] = None @dataclass diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index ee7a8c3670..8bb170977b 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -733,7 +733,8 @@ def postprocess(self): self.max_model_len // self.cache_config.block_size) if self.guided_decoding_backend == "auto": - if self.enable_mm: + if current_platform.is_xpu() or self.speculative_config.method is not None: + llm_logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.") self.guided_decoding_backend = "off" else: self.guided_decoding_backend = "xgrammar" @@ -795,10 +796,9 @@ def check(self): f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}." if self.guided_decoding_backend != "off": - # TODO: mm support guided_decoding - assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding" - # TODO: speculative decoding support guided_decoding + assert self.speculative_config.method is None, \ + "speculative decoding currently do not support guided_decoding" # TODO: xpu support guided_decoding assert not current_platform.is_xpu( @@ -826,11 +826,7 @@ def print(self, file=None): if k == "generation_config" and v is not None: for gck, gcv in v.to_dict().items(): llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv)) - elif (k == "cache_config" or - k == "model_config" or - k == "scheduler_config" or - k == "parallel_config" or - k == "commit_config"): + elif k in ["cache_config", "model_config", "commit_config", "scheduler_config", "parallel_config", "speculative_config"]: v.print() else: llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e50f7d6702..d9fbc7a75e 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -359,6 +359,13 @@ def _insert_zmq_task_to_scheduler(self): llm_logger.debug(f"Receive request: {request}") err_msg = None + if ((request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None) and self.guided_decoding_checker is None): + err_msg = "guided_backend is None, use --guided-decoding-backend to " \ + "specify the backend at server startup." 
+ if self.guided_decoding_checker is not None: request, err_msg = self.guided_decoding_checker.schema_format( request) @@ -447,6 +454,14 @@ def add_requests(self, task, sampling_params=None, **kwargs): llm_logger.error(error_msg) raise EngineError(error_msg, error_code=400) + if ((request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None) and self.guided_decoding_checker is None): + err_msg = "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup." + llm_logger.error(err_msg) + raise EngineError(err_msg, error_code=400) + if self.guided_decoding_checker is not None: request, err_msg = self.guided_decoding_checker.schema_format( request) @@ -1030,7 +1045,8 @@ def _start_worker_service(self): f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}" f" --max_capture_batch_size {self.cfg.max_capture_batch_size}" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" - f" --load_strategy {self.cfg.model_config.load_strategy}") + f" --load_strategy {self.cfg.model_config.load_strategy}" + f" --reasoning_parser {self.cfg.reasoning_parser}") worker_append_flag = { "enable_expert_parallel": diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index a7912407a8..bcbb5ec55e 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -92,6 +92,7 @@ class SamplingParams: min_tokens: int = 1 logprobs: Optional[int] = None bad_words: Optional[List[str]] = None + guided_decoding: Optional[GuidedDecodingParams] = None @classmethod def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams": @@ -121,7 +122,8 @@ def from_optional(cls, reasoning_max_tokens=None, min_tokens=1, logprobs=None, - bad_words=None) -> "SamplingParams": + bad_words=None, + guided_decoding=None) -> "SamplingParams": """Create instance from command line arguments""" return cls(n=1 if n is None else n, best_of=best_of, @@ -141,7 +143,8 @@ def from_optional(cls, reasoning_max_tokens=reasoning_max_tokens, min_tokens=min_tokens, logprobs=logprobs, - bad_words=bad_words) + bad_words=bad_words, + guided_decoding=guided_decoding) def __post_init__(self): if self.seed is None: @@ -207,6 +210,9 @@ def _verify_args(self) -> None: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") + if self.guided_decoding is not None: + self.guided_decoding._verify_args() + def update_from_tokenizer(self, tokenizer): """ # TODO: Implement stop tokens and bad words support @@ -224,3 +230,45 @@ class BeamSearchParams: temperature: float = 0.0 length_penalty: float = 1.0 include_stop_str_in_output: bool = False + + +@dataclass +class GuidedDecodingParams: + """Guided decoding parameters for text generation.""" + json: Optional[Union[str, dict]] = None + regex: Optional[str] = None + choice: Optional[List[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + structural_tag: Optional[str] = None + + def to_dict(self): + """convert to dict""" + key_dict = { + "guided_json": self.json, + "guided_regex": self.regex, + "guided_choice": self.choice, + "guided_grammar": self.grammar, + "structural_tag": self.structural_tag, + "guided_json_object": self.json_object, + } + + guided_dict = {} + for key, value in key_dict.items(): + if value is not None: + guided_dict[key] = value + return guided_dict + + def _verify_args(self): + """Verify the arguments.""" + guided_count = sum([ 
+ self.json is not None, self.regex is not None, self.choice + is not None, self.grammar is not None, self.json_object + is not None, self.structural_tag is not None + ]) + + if guided_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')." + ) diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 6c0ce4997d..51f8aa1ff3 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -273,6 +273,9 @@ def _add_request( if chat_template_kwargs is not None: enable_thinking = chat_template_kwargs.get( "enable_thinking", None) + if current_sampling_params.guided_decoding is not None: + guided_decoding_dict = current_sampling_params.guided_decoding.to_dict() + tasks.update(guided_decoding_dict) self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking) diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py index 53163f2c22..91f951b15b 100644 --- a/fastdeploy/model_executor/guided_decoding/__init__.py +++ b/fastdeploy/model_executor/guided_decoding/__init__.py @@ -15,8 +15,10 @@ """ # from fastdeploy.config import FDConfig +from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( + BackendBase, BaseChecker, LogitsProcessorBase) -__all__ = ['get_guided_backend', 'schema_checker'] +__all__ = ['get_guided_backend', 'schema_checker', 'LogitsProcessorBase', 'BackendBase', 'BaseChecker'] def get_guided_backend( diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py index 0449943f9d..9704eded5e 100644 --- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py +++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py @@ -19,6 +19,7 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request +from fastdeploy.reasoning import ReasoningParserManager from fastdeploy.utils import llm_logger @@ -34,8 +35,9 @@ class LogitsProcessorBase: None (all state should be managed by subclasses) """ - def __init__(self): - pass + def __init__(self, enable_reasoning): + self.reasoning_ended = False + self.enable_reasoning = enable_reasoning def fill_token_bitmask(self, token_bitmask, idx): """ @@ -136,8 +138,13 @@ def __init__(self, fd_config: FDConfig): self.fd_config = fd_config self.executor = ThreadPoolExecutor() self.max_cache_size = 2048 + self.reasoning_parser = None self.hf_tokenizer = self._get_tokenizer_hf() + if self.fd_config.model_config.reasoning_parser: + reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser( + self.fd_config.model_config.reasoning_parser) + self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer) def _create_processor(self): """ @@ -148,71 +155,89 @@ def _create_processor(self): """ raise NotImplementedError() - def _json_processor(self, schemata): + def _json_processor(self, schemata, enable_thinking=False): """ Process JSON schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError() - def _regex_processor(self, schemata): + def _regex_processor(self, schemata, enable_thinking=False): """ Process regular expression schemata. Args: schemata (str): The schemata string. 
+ enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError() - def _grammar_processor(self, schemata): + def _grammar_processor(self, schemata, enable_thinking=False): """ Process grammar schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError() - def _structural_tag_processor(self, schemata): + def _structural_tag_processor(self, schemata, enable_thinking=False): """ Process structural tag schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError() - def _unsupported_processor_type(self, key_type, schemata): + def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False): """ Process unsupported type. Args: key_type (str): The key type string. schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. """ raise Exception(f"Unsupported processor type {key_type}.") + def get_reasoning_parser(self): + """ + Get reasoning parser object. + + Returns: + ReasoningParser: Reasoning parser object or None + """ + return self.reasoning_parser + def _init_logits_processor( - self, schemata_key: tuple[str, str]) -> LogitsProcessorBase: + self, + schemata_key: tuple[str, str], + enable_thinking: bool = False, + ) -> LogitsProcessorBase: """ init logits processor by type and schemata. Args: schemata_key (tuple[str, str]): Tuple containing processor type and schema string + enable_thinking (bool): Whether to enable thinking step Returns: LogitsProcessorBase: Initialized logits processor instance @@ -222,20 +247,21 @@ def _init_logits_processor( """ key_type, schemata = schemata_key if key_type == "json": - return self._json_processor(schemata) + return self._json_processor(schemata, enable_thinking) elif key_type == "regex": - return self._regex_processor(schemata) + return self._regex_processor(schemata, enable_thinking) elif key_type == "grammar": - return self._grammar_processor(schemata) + return self._grammar_processor(schemata, enable_thinking) elif key_type == "structural_tag": - return self._structural_tag_processor(schemata) + return self._structural_tag_processor(schemata, enable_thinking) else: llm_logger.error(f"Unsupported processor type {key_type}.") return None def get_logits_processor( self, - schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]: + schemata_key: tuple[str, str], + enable_thinking: bool = False) -> tuple[LogitsProcessorBase, bool]: """ get logits processor by key from cache or create new one. 
@@ -249,8 +275,10 @@ def get_logits_processor( """ value = self.cache.get(schemata_key, None) if value: - return value.copy(), True - value = self.executor.submit(self._init_logits_processor, schemata_key) + value_copy = value.copy() + value_copy.enable_reasoning = enable_thinking + return value_copy, True + value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking) return value, False def _get_tokenizer_hf(self): @@ -269,7 +297,8 @@ def _get_tokenizer_hf(self): try: architectures = self.fd_config.model_config.architectures if "Ernie4_5_MoeForCausalLM" not in architectures \ - and "Ernie4_5_ForCausalLM" not in architectures: + and "Ernie4_5_ForCausalLM" not in architectures \ + and "Ernie4_5_VLMoeForConditionalGeneration" not in architectures: from transformers import AutoTokenizer, PreTrainedTokenizerFast tokenizer = AutoTokenizer.from_pretrained( diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py index 74b1c29528..fc356a506c 100644 --- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py +++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py @@ -23,8 +23,9 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( - BackendBase, BaseChecker, LogitsProcessorBase) +from fastdeploy.model_executor.guided_decoding import (BackendBase, + BaseChecker, + LogitsProcessorBase) from fastdeploy.utils import llm_logger try: @@ -47,7 +48,6 @@ class XGrammarProcessor(LogitsProcessorBase): max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch vocab_size (int): Size of the vocabulary batch_size (int): Batch size for processing - splitwise_role (str): Role for splitwise processing compiled_grammar (CompiledGrammar): Compiled grammar rules terminate_without_stop_token (bool): Whether to terminate without stop token override_stop_tokens (Optional[List[int]]): Custom stop tokens @@ -61,13 +61,12 @@ def __init__( override_stop_tokens: Optional[List[int]] = None, vocab_size: Optional[int] = None, batch_size: Optional[int] = None, - splitwise_role: str = "mixed", + enable_thinking: bool = False, ): - super().__init__() + super().__init__(enable_reasoning=enable_thinking) self.max_rollback_tokens = 200 self.vocab_size = vocab_size self.batch_size = batch_size - self.splitwise_role = splitwise_role self.compiled_grammar = compiled_grammar self.terminate_without_stop_token = terminate_without_stop_token self.override_stop_tokens = override_stop_tokens @@ -180,7 +179,6 @@ def copy(self) -> "XGrammarProcessor": override_stop_tokens=self.override_stop_tokens, vocab_size=self.vocab_size, batch_size=self.batch_size, - splitwise_role=self.splitwise_role, ) @@ -195,7 +193,6 @@ class XGrammarBackend(BackendBase): vocab_size (int): Size of the vocabulary from config batch_size (int): Maximum batch size from config any_whitespace (bool): Whether to allow any whitespace in JSON - splitwise_role (str): Role for splitwise processing grammar_compiler (GrammarCompiler): Grammar compilation engine """ @@ -209,7 +206,6 @@ def __init__( self.batch_size = fd_config.parallel_config.max_num_seqs self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace - self.splitwise_role = fd_config.parallel_config.splitwise_role try: tokenizer_info = TokenizerInfo.from_huggingface( @@ -224,6 +220,7 @@ def _create_processor( compiled_grammar: CompiledGrammar, 
terminate_without_stop_token: bool = False, override_stop_tokens: Optional[List[int]] = None, + enable_thinking: bool = False, ) -> XGrammarProcessor: """ Create a logits processor instance for the given compiled grammar. @@ -232,6 +229,7 @@ def _create_processor( compiled_grammar (CompiledGrammar): Compiled grammar rules terminate_without_stop_token (bool): Whether to terminate without stop token override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults + enable_thinking (bool): Whether to enable thinking mode Returns: XGrammarProcessor: Configured grammar processor instance @@ -242,15 +240,16 @@ def _create_processor( override_stop_tokens=override_stop_tokens, vocab_size=self.vocab_size, batch_size=self.batch_size, - splitwise_role=self.splitwise_role, + enable_thinking=enable_thinking, ) - def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _json_processor(self, schemata: str, enable_thinking=False) -> Optional[XGrammarProcessor]: """ Compile JSON schema into a grammar processor. Args: schemata (str): JSON schema string to compile + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -261,14 +260,15 @@ def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]: except Exception as e: llm_logger.error(f"Failed to compile json schema: {e}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) - def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _regex_processor(self, schemata: str, enable_thinking=False) -> Optional[XGrammarProcessor]: """ Compile regex pattern into a grammar processor. Args: schemata (str): Regex pattern string to compile + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -278,14 +278,15 @@ def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]: except Exception as e: llm_logger.error(f"Failed to compile regex schema: {e}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) - def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _grammar_processor(self, schemata: str, enable_thinking=False) -> Optional[XGrammarProcessor]: """ Compile grammar (EBNF) into a grammar processor. Args: schemata (str): Grammar string in EBNF format + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -295,15 +296,16 @@ def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]: except Exception as e: llm_logger.error(f"Failed to compile ebnf schema: {e}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) def _structural_tag_processor( - self, schemata: str) -> Optional[XGrammarProcessor]: + self, schemata: str, enable_thinking=False) -> Optional[XGrammarProcessor]: """ Compile structural tags into a grammar processor. 
Args: schemata (str): JSON string containing structural tag definitions + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -323,7 +325,7 @@ def _structural_tag_processor( except Exception as e: llm_logger.error(f"Failed to compile structural tags schema: {e}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) class XGrammarChecker(BaseChecker): diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 162bbc347f..e48966b585 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -22,14 +22,14 @@ import paddle.nn.functional as F from fastdeploy.config import FDConfig -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \ - LogitsProcessorBase +from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, top_k_top_p_sampling) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput +from fastdeploy.reasoning import ReasoningParser class SamplerProcessor: @@ -43,6 +43,10 @@ def __init__(self): self.logits_processor: Dict[int, Optional[Any]] = dict() self.executor = ThreadPoolExecutor() self.logits_lock = threading.Lock() + self.reasoning_parser = None + + def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + self.reasoning_parser = reasoning_parser def add_logits_processor(self, ids: int, @@ -119,7 +123,13 @@ def apply_token_mask(self, if available_processors is None: return logits - indices = list(self.logits_processor.keys()) + indices = [] + for idx, processor in self.logits_processor.items(): + if processor is None: + continue + if not processor.enable_reasoning or processor.reasoning_ended: + indices.append(idx) + mask_idx = [i for i in indices if i not in skip_idx_list] return available_processors.apply_token_mask(logits, self.token_bitmask, @@ -135,6 +145,15 @@ def _accept_token(self, idx: int, token: int): if self.logits_processor[idx].is_terminated(): return + if ( + self.reasoning_parser is not None + and self.logits_processor[idx].enable_reasoning + and not self.logits_processor[idx].reasoning_ended + ): + reasoning_ended = self.reasoning_parser.is_reasoning_end([token]) + self.logits_processor[idx].reasoning_ended = reasoning_ended + return + self.logits_processor[idx].accept_token(token) def update_output_tokens(self, @@ -179,6 +198,10 @@ def __init__(self): self.processor = SamplerProcessor() + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """ set reasoning parser """ + self.processor.apply_reasoning_parser(reasoning_parser) + def apply_logits_processor(self, ids: int, future: Optional[Any] = None, @@ -237,6 +260,10 @@ def gather_logprobs( return LogprobsTensors(indices, top_logprobs, token_ranks) + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """ post process after running """ + self.processor.update_output_tokens(next_tokens, skip_idx_list) + def forward_cuda( self, logits: paddle.Tensor, @@ -271,8 +298,6 @@ def forward_cuda( logprobs_tensors = 
None if num_logprobs is None else \ self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) - self.processor.update_output_tokens(next_tokens, skip_idx_list) - sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape # [num_requests, 1], where each row represents one generated @@ -300,6 +325,10 @@ def __init__(self, fd_config: FDConfig): self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """ set reasoning parser """ + pass + def pre_process(self, skip_idx_list: List[int] = []): """ pre process before running """ pass @@ -311,6 +340,10 @@ def apply_logits_processor(self, """ apply logits processor to sampler """ pass + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """ post process after running """ + pass + def forward_cuda( self, logits: paddle.Tensor, @@ -392,6 +425,14 @@ def __init__(self, fd_config: FDConfig): else: raise NotImplementedError() + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """ set reasoning parser """ + pass + + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """ post process after running """ + pass + def pre_process(self, skip_idx_list: List[int] = []): """ pre process before running """ pass diff --git a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py index c1814e20b8..52efe1b97b 100644 --- a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py +++ b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py @@ -17,7 +17,7 @@ from typing import Optional, Union from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) + DeltaMessage) from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager @@ -48,6 +48,23 @@ def __init__(self, tokenizer): "Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!") + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + """ + Check if the reasoning content ends in the input_ids. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output. + + Parameters: + input_ids: list[int] + The input_ids of the model output. + + Returns: + bool + True if the reasoning content ends in the input_ids. + """ + return self.think_end_token_id in input_ids + def extract_reasoning_content_streaming( self, previous_text: str, @@ -103,4 +120,4 @@ def extract_reasoning_content( self.think_end_token) final_content = content or "" - return reasoning_content, final_content \ No newline at end of file + return reasoning_content, final_content diff --git a/fastdeploy/reasoning/qwen3_reasoning_parsers.py b/fastdeploy/reasoning/qwen3_reasoning_parsers.py index 9e3aae5924..7f86d71ec9 100644 --- a/fastdeploy/reasoning/qwen3_reasoning_parsers.py +++ b/fastdeploy/reasoning/qwen3_reasoning_parsers.py @@ -50,6 +50,23 @@ def __init__(self, tokenizer): "Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!") + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + """ + Check if the reasoning content ends in the input_ids. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output. 
+ + Parameters: + input_ids: list[int] + The input_ids of the model output. + + Returns: + bool + True if the reasoning content ends in the input_ids. + """ + return self.think_end_token_id in input_ids + def extract_reasoning_content_streaming( self, previous_text: str, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bb1080f75e..80e19162f7 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -24,9 +24,8 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.model_executor.guided_decoding import get_guided_backend -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \ - LogitsProcessorBase +from fastdeploy.model_executor.guided_decoding import (LogitsProcessorBase, + get_guided_backend) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import \ AttentionBackend diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index 5ad5c0f724..f09eeb4725 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -17,7 +17,7 @@ import json import os import random -from typing import Optional +from typing import List, Optional import numpy as np import paddle @@ -29,8 +29,11 @@ KVCacheConfig, LoadConfig, ModelConfig, MoEConfig, MoEPhase, ParallelConfig, SpeculativeConfig) +from fastdeploy.engine.request import Request from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer from fastdeploy.input.mm_processor import DataProcessor +from fastdeploy.model_executor.guided_decoding import (LogitsProcessorBase, + get_guided_backend) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata @@ -123,6 +126,32 @@ def __init__( self.sampler = Sampler() + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser()) + + def _init_logits_processor(self, request): + """ + init logits processor for guided decoding + """ + assert self.guided_backend is not None, "guided_backend is None, use "\ + "--guided-decoding-backend to specify the backend at server startup." 
+ + if request.guided_json is not None: + schemata_key = ("json", request.guided_json) + elif request.guided_regex is not None: + schemata_key = ("regex", request.guided_regex) + elif request.guided_grammar is not None: + schemata_key = ("grammar", request.guided_grammar) + elif request.structural_tag is not None: + schemata_key = ("structural_tag", request.structural_tag) + + return self.guided_backend.get_logits_processor( + schemata_key=schemata_key, + enable_thinking=request.get("enable_thinking", True), + ), schemata_key + def _reset_paddle_env(self): pass @@ -290,6 +319,13 @@ def _load_model( fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len self.fd_config = fd_config + self.fd_config.model_config.model_name_or_path = self.args.model_name_or_path + self.fd_config.model_config.reasoning_parser = self.args.reasoning_parser + self.fd_config.parallel_config.model_name_or_path = self.args.model_name_or_path + self.fd_config.parallel_config.guided_decoding_backend = self.args.guided_decoding_backend + self.fd_config.parallel_config.disable_any_whitespace = self.args.disable_any_whitespace + self.fd_config.parallel_config.splitwise_role = self.args.splitwise_role + attn_backend_cls = get_attention_backend() num_heads = self.fd_config.model_config.num_attention_heads // \ self.fd_config.parallel_config.tensor_parallel_degree @@ -670,6 +706,14 @@ def get_numeric_value(task, key, default_value): task = tasks[i] idx = task.idx + if (task.guided_json is not None + or task.guided_regex is not None + or task.structural_tag is not None + or task.guided_grammar is not None): + logits_info, schemata_key = self._init_logits_processor(task) + task.logits_processor, task.logits_cached = logits_info + task.schemata_key = schemata_key + kwargs = { "max_length": get_numeric_value(task, "max_tokens", 2048), @@ -778,11 +822,39 @@ def get_numeric_value(task, key, default_value): self.share_inputs["block_tables"][ idx:idx + 1, :encoder_block_num] = np.array(task.block_tables, dtype="int32") + self.sampler.apply_logits_processor(idx, task.get("logits_processor")) + + def _get_skip_idx(self, model_forward_batch): + """ + Get the index of the request that needs to be skipped during execution. + Args: + model_forward_batch: A list of requests to be executed by this runner. + Returns: + A list of indices corresponding to the requests that need to be skipped. + """ + skip_idx_list = [] + if ( + not self.args.enable_chunked_prefill + or self.guided_backend is None + or model_forward_batch is None + ): + return skip_idx_list + + for task in model_forward_batch: + if task.get("prefill_chunk_info", + None) is None or task.chunk_idx >= len( + task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + return skip_idx_list - def pre_process(self) -> None: + def pre_process(self, skip_idx_list: list = []) -> None: """ pre_process """ + self.sampler.pre_process(skip_idx_list) + if current_platform.is_cuda(): if self.args.speculative_method is not None: ( @@ -851,11 +923,32 @@ def pre_process(self) -> None: max_num_logprobs=20 if self.enable_logprob else None, ) - def generate(self) -> None: + def _add_cache(self, model_forward_batch) -> None: + """ + Add cache for guided decoding. 
+ """ + if self.guided_backend is None or model_forward_batch is None: + return + + for request in model_forward_batch: + logits_cached = request.get("logits_cached", None) + if logits_cached is None or logits_cached: + continue + + request.logits_cached = True + if isinstance(request.logits_processor, LogitsProcessorBase): + self.guided_backend.add_cache(request.schemata_key, + request.logits_processor) + else: + self.guided_backend.add_cache( + request.schemata_key, request.logits_processor.result()) + + def generate(self, req_dicts: Optional[List[Request]] = None) -> None: """ generate """ - self.pre_process() + skip_idx_list = self._get_skip_idx(req_dicts) + self.pre_process(skip_idx_list) hiddden_states = self.model(self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta) @@ -870,12 +963,13 @@ def generate(self) -> None: self.share_inputs["stop_flags"], ) # sampler & save_output - sampler_output = self.sampler(logits, self.sampling_metadata) + sampler_output = self.sampler(logits, self.sampling_metadata, skip_idx_list) if self.fd_config.parallel_config.tensor_parallel_degree > 1: paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) - self.post_process(sampler_output) + self.post_process(sampler_output, skip_idx_list) + self._add_cache(req_dicts) - def post_process(self, sampler_output: SamplerOutput) -> None: + def post_process(self, sampler_output: SamplerOutput, skip_idx_list: List[int] = []) -> None: """ post_process """ @@ -943,6 +1037,7 @@ def post_process(self, sampler_output: SamplerOutput) -> None: sampler_output.sampled_token_ids, self.share_inputs["is_block_step"], ) + self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) if sampler_output.logprobs_tensors is None: save_output( sampler_output.sampled_token_ids, diff --git a/fastdeploy/worker/vl_worker_process.py b/fastdeploy/worker/vl_worker_process.py index 1efc725110..96f3d570e8 100644 --- a/fastdeploy/worker/vl_worker_process.py +++ b/fastdeploy/worker/vl_worker_process.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" -import argparse import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor @@ -25,8 +24,8 @@ from fastdeploy.engine.config import ModelConfig from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal -from fastdeploy.utils import get_logger, none_or_str -from fastdeploy.worker.worker_process import initialize_fd_config, parse_args +from fastdeploy.utils import get_logger +from fastdeploy.worker.worker_process import parse_args logger = get_logger("worker", "worker.log") @@ -385,7 +384,7 @@ def run(self): time.sleep(0.001) continue - self.infer_engine.generate() + self.infer_engine.generate(req_dicts) self.infer_engine.share_inputs["infer_seed"].add_( infer_seed_increment) self.infer_engine.share_inputs[ diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index e308002602..f9585920d2 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -576,6 +576,11 @@ def parse_args(): parser.add_argument("--enable_logprob", action='store_true', help="Enable output of token-level log probabilities.") + parser.add_argument("--reasoning_parser", + type=str, + default=None, + help="Flag specifies the reasoning parser to use for " \ + "extracting reasoning content from the model output") args = parser.parse_args() return args @@ -590,6 +595,14 @@ def initialize_fd_config(config_or_args) -> FDConfig: Returns: FDConfig: Initialized FastDeploy configuration object """ + + def getattr_without_none(obj, attr_name, default=None): + if hasattr(obj, attr_name): + if getattr(obj, attr_name) == "None": + return default + return getattr(obj, attr_name) + return default + # Get model config from model directory model_config_dict, _ = ModelConfig.get_config_dict(config_or_args.model_name_or_path) @@ -634,6 +647,7 @@ def initialize_fd_config(config_or_args) -> FDConfig: # Handle quantization (check for attribute existence) model_config.quantization = getattr(config_or_args, 'quantization', None) + model_config.reasoning_parser = getattr_without_none(config_or_args, 'reasoning_parser', None) # Update speculative config_or_args speculative_config.method = getattr(config_or_args, 'speculative_method', None) From a0293e959f614f9d53b0c3e45696f7531b07bef2 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 8 Jul 2025 15:47:05 +0800 Subject: [PATCH 2/5] update code --- fastdeploy/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 80e19162f7..054dfec0e9 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1023,6 +1023,7 @@ class at the server level, which is too granular for ModelRunner. if self.parallel_config.tensor_parallel_degree > 1: paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) + self.sampler.post_process(sampled_token_ids, skip_idx_list) else: self.sampler(logits, self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs) @@ -1092,7 +1093,6 @@ class at the server level, which is too granular for ModelRunner. 
self.speculative_config, self.parallel_config.enable_prefix_caching, ) - self._update_chunked_prefill(model_forward_batch) self._add_cache(model_forward_batch) return None From c7479b15908961fdfc24dff02c19c7ef179b33db Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 8 Jul 2025 17:00:57 +0800 Subject: [PATCH 3/5] update docs --- docs/features/structured_outputs.md | 63 +++++++++++++++++++++++++ docs/zh/features/structured_outputs.md | 65 ++++++++++++++++++++++++++ fastdeploy/worker/gpu_model_runner.py | 9 ++-- 3 files changed, 133 insertions(+), 4 deletions(-) diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 40e177c1ce..e8d9b9b72a 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -330,3 +330,66 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong Address: No.1 Century Avenue, Pudong New Area, Shanghai Height: 468 ``` + +### Offline Inference + +Offline inference allows restricting the model's output format by pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, with usage similar to online inference: + +```python +json: Optional[Union[str, dict]] = None +regex: Optional[str] = None +choice: Optional[List[str]] = None +grammar: Optional[str] = None +json_object: Optional[bool] = None +structural_tag: Optional[str] = None +``` + +The following example demonstrates how to use offline inference to generate a structured json: + +```python +from fastdeploy import LLM, SamplingParams +from fastdeploy.sampling_params import GuidedDecodingParams +from pydantic import BaseModel +from enum import Enum + +class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + +class BookDescription(BaseModel): + author: str + title: str + genre: BookType + +# Constrained decoding parameters +guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema()) + +# Sampling parameters +sampling_params = SamplingParams( + top_p=0.95, + max_tokens=6400, + guided_decoding=guided_decoding_params, +) + +# Load model +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192) + +outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, +) + +# Output results +for output in outputs: + prompt = output.prompt + generated_text = output.outputs.text +``` + +Output: + +``` +{"author": "Cao Xueqin", "title": "Dream of the Red Chamber", "genre": "Historical"} +``` diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md index ce33f1232d..8dd8badcdf 100644 --- a/docs/zh/features/structured_outputs.md +++ b/docs/zh/features/structured_outputs.md @@ -330,3 +330,68 @@ ParsedChatCompletionMessage[Info](content='{"addr": "上海市浦东新区世纪 地址: 上海市浦东新区世纪大道1号 高度: 468 ``` + +### 离线推理 + +离线推理允许通过预先指定约束条件,限制模型输出格式。在 `FastDeploy` 中,支持通过 `SamplingParams` 中的 `GuidedDecodingParams` 类指定相关约束条件。`GuidedDecodingParams` 支持以下几种约束条件,使用方式可以参考在线推理: + +```python +json: Optional[Union[str, dict]] = None +regex: Optional[str] = None +choice: Optional[List[str]] = None +grammar: Optional[str] = None +json_object: Optional[bool] = None +structural_tag: Optional[str] = None +``` + +以下示例展示了如何使用离线推理生成一个结构化的 json : + +```python + +from fastdeploy import LLM, SamplingParams +from 
fastdeploy.sampling_params import GuidedDecodingParams +from pydantic import BaseModel +from enum import Enum + +class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + +class BookDescription(BaseModel): + author: str + title: str + genre: BookType + +# 受限解码参数 +guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema()) + +# 采样参数 +sampling_params = SamplingParams( + top_p=0.95, + max_tokens=6400, + guided_decoding=guided_decoding_params, +) + +# 加载模型 +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192) + +outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, +) + +# 输出结果 +for output in outputs: + prompt = output.prompt + generated_text = output.outputs.text + +``` + +输出 + +``` +{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"} +``` diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 054dfec0e9..ffd6c3e8d2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -67,16 +67,17 @@ def __init__( self.speculative_decoding = self.speculative_method is not None self.enable_logprob = fd_config.model_config.enable_logprob - self.guided_backend = None - if self.fd_config.parallel_config.guided_decoding_backend != "off": - self.guided_backend = get_guided_backend(fd_config=self.fd_config) - # Sampler if not self.speculative_decoding: self.sampler = Sampler() else: self.sampler = SpeculativeSampler(fd_config) + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser()) + # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] From 72de4a38c3c71757f5f48c9d465e68cc8e442e53 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 8 Jul 2025 21:14:50 +0800 Subject: [PATCH 4/5] update code --- docs/features/structured_outputs.md | 11 +++++------ docs/zh/features/structured_outputs.md | 17 ++++++++--------- fastdeploy/engine/sampling_params.py | 5 +---- fastdeploy/input/ernie_processor.py | 6 +++--- .../model_executor/layers/sample/sampler.py | 5 ++--- fastdeploy/worker/gpu_model_runner.py | 2 +- fastdeploy/worker/worker_process.py | 3 +++ 7 files changed, 23 insertions(+), 26 deletions(-) diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index e8d9b9b72a..f7ee424cb6 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -348,7 +348,7 @@ The following example demonstrates how to use offline inference to generate a st ```python from fastdeploy import LLM, SamplingParams -from fastdeploy.sampling_params import GuidedDecodingParams +from fastdeploy.engine.sampling_params import GuidedDecodingParams from pydantic import BaseModel from enum import Enum @@ -375,21 +375,20 @@ sampling_params = SamplingParams( ) # Load model -llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192) +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto") outputs = llm.generate( - prompts="Classify this sentiment: vLLM is wonderful!", + prompts="Generate a JSON describing a literary work, including author, title and book type.", sampling_params=sampling_params, ) # Output 
results for output in outputs: - prompt = output.prompt - generated_text = output.outputs.text + print(output.outputs.text) ``` Output: ``` -{"author": "Cao Xueqin", "title": "Dream of the Red Chamber", "genre": "Historical"} +{"author": "George Orwell", "title": "1984", "genre": "Dystopian"} ``` diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md index 8dd8badcdf..cafda804c6 100644 --- a/docs/zh/features/structured_outputs.md +++ b/docs/zh/features/structured_outputs.md @@ -349,7 +349,7 @@ structural_tag: Optional[str] = None ```python from fastdeploy import LLM, SamplingParams -from fastdeploy.sampling_params import GuidedDecodingParams +from fastdeploy.engine.sampling_params import GuidedDecodingParams from pydantic import BaseModel from enum import Enum @@ -365,28 +365,27 @@ class BookDescription(BaseModel): title: str genre: BookType -# 受限解码参数 +# Constrained decoding parameters guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema()) -# 采样参数 +# Sampling parameters sampling_params = SamplingParams( top_p=0.95, max_tokens=6400, guided_decoding=guided_decoding_params, ) -# 加载模型 -llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192) +# Load model +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto") outputs = llm.generate( - prompts="Classify this sentiment: vLLM is wonderful!", + prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。", sampling_params=sampling_params, ) -# 输出结果 +# Output results for output in outputs: - prompt = output.prompt - generated_text = output.outputs.text + print(output.outputs.text) ``` diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index bcbb5ec55e..5bc4178b65 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -210,9 +210,6 @@ def _verify_args(self) -> None: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") - if self.guided_decoding is not None: - self.guided_decoding._verify_args() - def update_from_tokenizer(self, tokenizer): """ # TODO: Implement stop tokens and bad words support @@ -259,7 +256,7 @@ def to_dict(self): guided_dict[key] = value return guided_dict - def _verify_args(self): + def __post_init__(self): """Verify the arguments.""" guided_count = sum([ self.json is not None, self.regex is not None, self.choice diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index d4e45712bc..2de4ce2917 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -100,7 +100,7 @@ def process_request(self, request, max_model_len=None, **kwargs): if request.prompt_token_ids is None or len( request.prompt_token_ids) == 0: - system = request.get("system") + # system = request.get("system") if request.prompt is None and request.messages is None: raise ValueError( f"The request should have `input_ids`, `text` or `messages`: {request}.") @@ -149,7 +149,7 @@ def process_request_dict(self, request, max_model_len=None): request['stop_token_ids'] = stop_seqs request['stop_seqs_len'] = stop_seqs_len - system = request.get("system") + # system = request.get("system") # 处理prompt_token_ids if not request.get('prompt_token_ids'): if request.get('prompt') is None and request.get( @@ -213,7 +213,7 @@ def process_response(self, response_dict, **kwargs): response_dict.outputs.reasoning_content = reasoning_content else: response_dict.outputs.text = full_text - 
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}") + data_processor_logger.info(f"req_id:{req_id}, token ids: {token_ids}") if response_dict.outputs.text == "" and \ response_dict.outputs.reasoning_content == "": return None diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index e48966b585..f716c1b4e7 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -125,15 +125,14 @@ def apply_token_mask(self, indices = [] for idx, processor in self.logits_processor.items(): - if processor is None: + if processor is None or idx in skip_idx_list: continue if not processor.enable_reasoning or processor.reasoning_ended: indices.append(idx) - mask_idx = [i for i in indices if i not in skip_idx_list] return available_processors.apply_token_mask(logits, self.token_bitmask, - indices=mask_idx) + indices=indices) def _accept_token(self, idx: int, token: int): """ accept token """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index ffd6c3e8d2..2bc9a7378f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1024,7 +1024,7 @@ class at the server level, which is too granular for ModelRunner. if self.parallel_config.tensor_parallel_degree > 1: paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) - self.sampler.post_process(sampled_token_ids, skip_idx_list) + self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) else: self.sampler(logits, self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f9585920d2..c43f21e53c 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -606,7 +606,10 @@ def getattr_without_none(obj, attr_name, default=None): # Get model config from model directory model_config_dict, _ = ModelConfig.get_config_dict(config_or_args.model_name_or_path) +<<<<<<< HEAD +======= +>>>>>>> 75f51506 (update code) # Handle MoE related configs if 'num_experts' in model_config_dict: model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts') From 205b6d84dd676ac371160a3f6ff4c42367579fe8 Mon Sep 17 00:00:00 2001 From: kevin Date: Fri, 11 Jul 2025 14:48:11 +0800 Subject: [PATCH 5/5] update code --- fastdeploy/worker/vl_gpu_model_runner.py | 6 +++--- fastdeploy/worker/worker_process.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index f09eeb4725..04cb59faa8 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -32,6 +32,7 @@ from fastdeploy.engine.request import Request from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer from fastdeploy.input.mm_processor import DataProcessor +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.guided_decoding import (LogitsProcessorBase, get_guided_backend) from fastdeploy.model_executor.layers.attention import get_attention_backend @@ -49,7 +50,6 @@ from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( ScatterOp, VariableResolutionResamplerModel) from fastdeploy.platforms import current_platform -from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.worker.output import 
SamplerOutput from fastdeploy.worker.utils import check_safetensors_model from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase @@ -949,10 +949,10 @@ def generate(self, req_dicts: Optional[List[Request]] = None) -> None: """ skip_idx_list = self._get_skip_idx(req_dicts) self.pre_process(skip_idx_list) - hiddden_states = self.model(self.share_inputs["ids_remove_padding"], + hidden_states = self.model(self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta) - logits = self.model.compute_logits(hiddden_states) + logits = self.model.compute_logits(hidden_states) set_value_by_flags_and_idx( self.share_inputs["pre_ids"], self.share_inputs["input_ids"], diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index c43f21e53c..4411a17fec 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -606,10 +606,6 @@ def getattr_without_none(obj, attr_name, default=None): # Get model config from model directory model_config_dict, _ = ModelConfig.get_config_dict(config_or_args.model_name_or_path) -<<<<<<< HEAD - -======= ->>>>>>> 75f51506 (update code) # Handle MoE related configs if 'num_experts' in model_config_dict: model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts')
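A minimal, self-contained sketch of the reasoning-aware gating this series adds to `SamplerProcessor`: while a request's "thinking" phase is still open, the grammar-constrained logits processor neither masks logits nor consumes tokens, and constrained decoding only takes over once the reasoning parser reports the think-end token. The attribute and method names below (`enable_reasoning`, `reasoning_ended`, `is_reasoning_end`, `accept_token`) mirror the patch, but the classes are simplified stand-ins rather than the FastDeploy/xgrammar implementations, and the token ids are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class StubReasoningParser:
    """Stand-in for a reasoning parser: reasoning ends once think_end_token_id appears."""
    think_end_token_id: int

    def is_reasoning_end(self, input_ids: List[int]) -> bool:
        return self.think_end_token_id in input_ids


@dataclass
class StubGrammarProcessor:
    """Stand-in for a grammar-constrained logits processor; it only records accepted tokens."""
    allowed_tokens: Set[int]
    enable_reasoning: bool = True
    reasoning_ended: bool = False
    accepted: List[int] = field(default_factory=list)

    def should_apply_mask(self) -> bool:
        # Mirrors apply_token_mask(): the token bitmask is applied only for requests
        # whose reasoning is disabled or already finished.
        return not self.enable_reasoning or self.reasoning_ended

    def accept_token(self, token: int) -> None:
        assert token in self.allowed_tokens, f"token {token} violates the grammar"
        self.accepted.append(token)


def accept(processor: StubGrammarProcessor, parser: StubReasoningParser, token: int) -> None:
    # Mirrors _accept_token(): while the thinking phase is open, tokens only
    # update the reasoning state and are not fed to the grammar matcher.
    if processor.enable_reasoning and not processor.reasoning_ended:
        processor.reasoning_ended = parser.is_reasoning_end([token])
        return
    processor.accept_token(token)


if __name__ == "__main__":
    THINK_END = 100  # hypothetical </think> token id
    parser = StubReasoningParser(think_end_token_id=THINK_END)
    processor = StubGrammarProcessor(allowed_tokens={1, 2, 3})

    # Free-form reasoning tokens first, then schema-constrained answer tokens.
    for tok in [7, 8, 9, THINK_END, 1, 2]:
        constrained = processor.should_apply_mask()
        accept(processor, parser, tok)
        print(f"token={tok:<3} constrained={constrained} reasoning_ended={processor.reasoning_ended}")
```

Under these assumptions, the grammar only constrains the answer that follows the think-end token, which is the behavior the new `is_reasoning_end` hook on the reasoning parsers enables.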