diff --git a/tests/reasoning/test_hunyuan_reasoning_parser.py b/tests/reasoning/test_hunyuan_reasoning_parser.py
new file mode 100644
index 00000000000..f70cf453f0e
--- /dev/null
+++ b/tests/reasoning/test_hunyuan_reasoning_parser.py
@@ -0,0 +1,162 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+parser_name = "hunyuan_a13b"
+START_REASONING = "<think>\n"
+START_RESPONSE = "\n</think>\n<answer>\n"
+END_RESPONSE = "\n</answer>"
+
+NO_REASONING_QUICK_THOUGHT = {
+    "output":
+    f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}",  #noqa: E501
+    "reasoning_content": None,
+    "content": "This is the rest",
+}
+
+SIMPLE_REASONING = {
+    "output":
+    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}",  #noqa: E501
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+COMPLETE_REASONING = {
+    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+NO_REASONING = {
+    "output": "This is content",
+    "reasoning_content": None,
+    "content": "This is content",
+}
+MULTIPLE_LINES = {
+    "output":
+    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+}
+REASONING_WITH_THINK = {
+    "output":
+    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest",  #noqa: E501
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+COMPLETE_REASONING_WITH_THINK = {
+    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+MULTIPLE_LINES_WITH_THINK = {
+    "output":
+    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+}
+
+TEST_CASES = [
+    pytest.param(
+        False,
+        SIMPLE_REASONING,
+        id="simple_reasoning",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING,
+        id="complete_reasoning",
+    ),
+    pytest.param(
+        False,
+        NO_REASONING,
+        id="no_reasoning",
+    ),
+    pytest.param(False, NO_REASONING_QUICK_THOUGHT, id="no_reasoning_quick"),
+    pytest.param(
+        False,
+        MULTIPLE_LINES,
+        id="multiple_lines",
+    ),
+    pytest.param(
+        False,
+        REASONING_WITH_THINK,
+        id="reasoning_with_think",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING_WITH_THINK,
+        id="complete_reasoning_with_think",
+    ),
+    pytest.param(
+        False,
+        MULTIPLE_LINES_WITH_THINK,
+        id="multiple_lines_with_think",
+    ),
+    pytest.param(
+        True,
+        SIMPLE_REASONING,
+        id="simple_reasoning_streaming",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING,
+        id="complete_reasoning_streaming",
+    ),
+    pytest.param(
+        True,
+        NO_REASONING,
+        id="no_reasoning_streaming",
+    ),
+    pytest.param(True,
+                 NO_REASONING_QUICK_THOUGHT,
+                 id="no_reasoning_quick_stream"),
+    pytest.param(
+        True,
+        MULTIPLE_LINES,
+        id="multiple_lines_streaming",
+    ),
+    pytest.param(
+        True,
+        REASONING_WITH_THINK,
+        id="reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING_WITH_THINK,
+        id="complete_reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        True,
+        MULTIPLE_LINES_WITH_THINK,
+        id="multiple_lines_with_think_streaming",
+    ),
+]
+
+# Global tokenizer initialization to avoid repeated loading
+tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                          trust_remote_code=True)
+
+
+@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
+def test_reasoning(
+    streaming: bool,
+    param_dict: dict,
+):
+    output = tokenizer.tokenize(param_dict["output"])
+    # decode everything to tokens
+    output_tokens: list[str] = [
+        tokenizer.convert_tokens_to_string([token]) for token in output
+    ]
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
+        parser_name)(tokenizer)
+
+    reasoning, content = run_reasoning_extraction(parser,
+                                                  output_tokens,
+                                                  streaming=streaming)
+
+    assert reasoning == param_dict["reasoning_content"]
+    assert content == param_dict["content"]
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index e8cd565519f..3e5485b883f 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -4,6 +4,7 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
+from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser
 
 __all__ = [
@@ -11,5 +12,6 @@
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
+    "HunyuanA13BReasoningParser",
     "Qwen3ReasoningParser",
 ]
diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
new file mode 100644
index 00000000000..598a0e97e51
--- /dev/null
+++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
@@ -0,0 +1,238 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import re
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("hunyuan_a13b")
+class HunyuanA13BReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the Hunyuan A13B model.
+
+    HunyuanA13BReasoningParser
+
+    This class implements a reasoning parser specifically designed
+    for the Hunyuan A13B model. It is responsible for parsing and
+    extracting structured reasoning and answer segments from model
+    outputs that follow a specific pattern.
+
+    Key Features:
+    - For non-streaming output, it recognizes and extracts the reasoning
+      ("think") and answer ("answer") sections using regular expressions.
+    - For streaming, it requires token id sequences to switch between the
+      reasoning state and the response state, so it maintains internal
+      state to manage parsing across multiple tokens.
+
+    think start: "<think>\n": [14023, 771, 397]
+    think ends: "\n</think>\n<answer>\n": [198, 524, 27963, 397, 27, 9399, 397]
+    response ends: "\n</answer>": [524, 9399, 29]
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.think_start_expr = r"<think>\n"
+        self.think_end_expr = r"\n</think>"
+
+        self.response_start_expr = r"\n</think>\n<answer>\n"
+        self.response_end_expr = r"\n</answer>"
+
+        self.full_match_reasoning_regex = re.compile(
+            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}",  # noqa: E501
+            re.DOTALL)
+
+        self.half_match_reasoning_regex = re.compile(
+            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
+            re.DOTALL)
+
+        self.think_start_ids = [14023, 771, 397]
+        self.think_start_ids_fast = [14023, 771, 1363]
+        self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
+        self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
+        self.response_end_ids = [198, 524, 9399, 29]
+        self.fast_think_ids = [
+            14023, 771, 1363, 524, 27963, 397, 27, 9399, 397
+        ]
+
+        # when the state changes, send out all text buffered in the last state
+        self.buffered_text = []
+        self.buffered_ids = []
+
+        self.current_state = "reasoning"
+        self.all_states = ["reasoning", "response"]
+
+        self.current_state = "idle"
+        self.expected_sequence = self.think_start_ids
+        # side sequence: the think start can begin in two different ways.
+        self.expected_sequence_side = self.think_start_ids_fast
+        self.sequence_index = 0
+        self.token_buffer = []
+        self.text_buffer = ""
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.current_state == "response"
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+        """Extract the reasoning content & content sections, respectively.
+        If the sequence doesn't match what we expect, i.e., the model generates
+        something else, all content is considered non-reasoning content.
+
+        Args:
+            model_output (str): Output of the model to be parsed.
+            request (ChatCompletionRequest): Request being processed.
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: Tuple pair containing the
+            reasoning content and non-reasoning content.
+        """
+
+        re_match = self.full_match_reasoning_regex.findall(model_output)
+        if re_match:
+            reasoning_content, response_content = re_match[0]
+            if len(reasoning_content) == 0:
+                reasoning_content = None
+            if len(response_content) == 0:
+                response_content = None
+            return reasoning_content, response_content
+
+        fallback_regex = self.half_match_reasoning_regex
+        fallback_match = fallback_regex.findall(model_output)
+        if fallback_match:
+            reasoning_content, response_content = fallback_match[0]
+
+            # compare against the literal end tag, not the regex pattern
+            if response_content.endswith("\n</answer>"):
+                response_content = response_content[:-len("\n</answer>")]
+
+            if len(reasoning_content) == 0:
+                reasoning_content = None
+            if len(response_content) == 0:
+                response_content = None
+
+            return reasoning_content, response_content
+
+        return None, model_output
+
+    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
+                                          sequence: Sequence[int]) -> bool:
+        if not subsequence:
+            return False
+
+        sub_idx = 0
+        for num in sequence:
+            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
+                sub_idx += 1
+        return sub_idx == len(subsequence)
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """Extract content using a token id sequence state machine."""
+        # Define sequences
+        think_start_sequence = self.think_start_ids
+        response_start_sequence = self.response_start_ids
+        response_end_sequence = self.response_end_ids
+
+        assert len(delta_token_ids) == 1
+        # The delta contains exactly one token.
+        token = delta_token_ids[0]
+
+        def check_token_with_sequence(token):
+            if self.current_state == "idle" or self.current_state == "think":
+                return (token == self.expected_sequence[self.sequence_index]
+                        or token ==
+                        self.expected_sequence_side[self.sequence_index])
+            else:
+                return token == self.expected_sequence[self.sequence_index]
+
+        def check_last_token(token):
+            if self.current_state == "idle" or self.current_state == "think":
+                # only return True if the match came from the side sequence.
+                if (self.sequence_index - 1 < len(self.expected_sequence_side)
+                        and token
+                        == self.expected_sequence_side[self.sequence_index -
+                                                       1]):
+                    return self.sequence_index == len(
+                        self.expected_sequence_side)
+                else:
+                    return self.sequence_index == len(self.expected_sequence)
+            else:
+                return self.sequence_index == len(self.expected_sequence)
+
+        # Check if token matches expected sequence
+        token_in_state_seq = check_token_with_sequence(token)
+
+        if token_in_state_seq:
+            # Store matching token
+            self.token_buffer.append(token)
+            self.text_buffer += delta_text
+            self.sequence_index += 1
+            # state changes follow idle -> think -> response -> idle
+
+            # Check if sequence fully matched
+            if check_last_token(token):
+                # State transition
+                if self.current_state == "idle":
+                    self.current_state = "think"
+                    self.expected_sequence = response_start_sequence
+                    self.expected_sequence_side = self.response_start_ids_fast
+                elif self.current_state == "think":
+                    self.current_state = "response"
+                    self.expected_sequence = response_end_sequence
+                elif self.current_state == "response":
+                    self.current_state = "idle"
+                    self.expected_sequence = think_start_sequence
+                    self.expected_sequence_side = self.think_start_ids_fast
+
+                # Reset matching state
+                self.sequence_index = 0
+                self.token_buffer = []
+                self.text_buffer = ""
+                # Do not send content for state-transition text.
+        else:
+            # Sequence broken - handle buffered content
+            if self.token_buffer:
+                # Flush the buffered text together with the current delta
+                buffered_content = self.text_buffer + delta_text
+                # Reset matching state
+                self.sequence_index = 0
+                self.token_buffer = []
+                self.text_buffer = ""
+
+                # Return content based on current state
+                if self.current_state == "think":
+                    return DeltaMessage(reasoning_content=buffered_content,
+                                        content=None)
+                else:
+                    return DeltaMessage(reasoning_content=None,
+                                        content=buffered_content)
+            else:
+                # No buffered content, send normally
+                if self.current_state == "think":
+                    return DeltaMessage(reasoning_content=delta_text,
+                                        content=None)
+                else:
+                    return DeltaMessage(reasoning_content=None,
+                                        content=delta_text)
+
+        # If no content to send in this delta
+        return None
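As a usage sanity check, a minimal sketch of exercising the new parser outside the test harness; the model name and the <think>/<answer> tag layout follow the assumptions in the test file above, and request=None is only a shortcut because this parser never reads the request object.

    from transformers import AutoTokenizer

    from vllm.reasoning import ReasoningParserManager

    # Assumed model name; any tokenizer with the same vocabulary works here.
    tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                              trust_remote_code=True)
    parser = ReasoningParserManager.get_reasoning_parser("hunyuan_a13b")(tokenizer)

    output = ("<think>\nThe user asks for 2 + 2, which is 4.\n</think>\n"
              "<answer>\nThe answer is 4.\n</answer>")
    # request is unused by extract_reasoning_content, so None is passed here.
    reasoning, content = parser.extract_reasoning_content(output, request=None)
    assert reasoning == "The user asks for 2 + 2, which is 4."
    assert content == "The answer is 4."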