Skip to content

Commit 0cf893c

Browse files
MoyanZitto, wangzhengtao
authored
Add kimi-k2 tool parser (#20789)
Signed-off-by: wangzhengtao <wangzhengtao@moonshot.cn> Co-authored-by: wangzhengtao <wangzhengtao@moonshot.cn> Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
1 parent cf75cd2 commit 0cf893c

File tree

4 files changed

+576
-2
lines changed

4 files changed

+576
-2
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# ruff: noqa: E501
4+
5+
import json
6+
7+
import pytest
8+
9+
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
10+
from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser
11+
from vllm.transformers_utils.tokenizer import get_tokenizer
12+
13+
pytest.skip("skip kimi_k2 parser test", allow_module_level=True)
14+
15+
# Use a common model that is likely to be available
16+
MODEL = "moonshotai/Kimi-K2-Instruct"
17+
18+
19+
@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
    """Tokenizer for the Kimi-K2 model, built once per test module.

    Module scope avoids re-downloading/re-loading the tokenizer for every
    test; ``trust_remote_code=True`` is required for moonshotai repos.
    """
    return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
22+
23+
24+
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
    """Fresh KimiK2ToolParser per test, so streaming state never leaks
    between tests (function scope, unlike the module-scoped tokenizer)."""
    return KimiK2ToolParser(kimi_k2_tokenizer)
27+
28+
29+
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Compare parsed tool calls against expectations pairwise.

    Besides function equality, this validates the kimi-k2 id convention
    ``functions.<name>:<index>`` — a numeric trailing index and a name
    segment matching the expected function name.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected.function

        # Validate the id format: functions.<name>:<index>
        assert actual.id.startswith("functions.")
        name_part, _, index_part = actual.id.rpartition(':')
        assert index_part.isdigit()
        assert name_part.split('.')[1] == expected.function.name
44+
45+
46+
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    """Plain text without any tool-call markers is passed through untouched."""
    model_output = "This is a test"
    result = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    # No tool section present: nothing parsed, content returned verbatim.
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
53+
54+
55+
def _expected_weather_call(index: int, city: str) -> ToolCall:
    # Build the expected ToolCall for one get_weather invocation; the id
    # follows the kimi-k2 convention ``functions.<name>:<index>``.
    return ToolCall(
        id=f'functions.get_weather:{index}',
        type='function',
        function=FunctionCall(name="get_weather",
                              arguments=json.dumps({"city": city})),
    )


# Content emitted before the tool-call section; expected back as `content`.
_CONTENT_BEFORE = "I'll help you check the weather. "


@pytest.mark.parametrize(
    ids=[
        "tool_call_with_content_before",
        "multi_tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [_expected_weather_call(0, "Beijing")],
            _CONTENT_BEFORE,
        ),
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                _expected_weather_call(0, "Beijing"),
                _expected_weather_call(1, "Shanghai"),
            ],
            _CONTENT_BEFORE,
        ),
    ],
)
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Non-streaming extraction: text before the tool-call section is kept
    as content and every <|tool_call_begin|> block yields one ToolCall."""
    extracted = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted.tools_called
    assert_tool_calls(extracted.tool_calls, expected_tool_calls)
    assert extracted.content == expected_content
112+
113+
114+
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """The parser returns every tool call, even ones whose argument payload
    is not valid JSON (the first call below is missing its closing brace)."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Both calls are returned: argument JSON is not validated here, so the
    # malformed-arguments call is kept alongside the well-formed one.
    assert len(extracted_tool_calls.tool_calls) == 2
    assert extracted_tool_calls.tool_calls[
        0].function.name == "invalid_get_weather"
    assert extracted_tool_calls.tool_calls[
        1].function.name == "valid_get_weather"
130+
131+
132+
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """Tool calls whose id does not match ``functions.<name>:<index>`` are
    dropped; here the first call uses a dot (``.0``) instead of a colon."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Only the call with a well-formed `functions.<name>:<index>` id is
    # parsed; the malformed-id call is silently discarded.
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[
        0].function.name == "valid_get_weather"
146+
147+
148+
def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Smoke-test the streaming entry point on a completed tool-call section.

    The parser is stateful across streaming calls, so its per-stream state
    is reset explicitly before the call.
    """
    # Reset streaming state so earlier tests cannot leak into this one.
    kimi_k2_tool_parser.current_tool_name_sent = False
    kimi_k2_tool_parser.prev_tool_call_arr = []
    kimi_k2_tool_parser.current_tool_id = -1
    kimi_k2_tool_parser.streamed_args_for_tool = []

    # A full tool-call section; the delta is the section-closing token.
    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The result may be None depending on internal state management, but
    # when a delta carrying tool calls is produced it must be non-empty.
    # (The original `>= 0` assertion was vacuous — always true.)
    if result is not None and hasattr(result,
                                      'tool_calls') and result.tool_calls:
        assert len(result.tool_calls) > 0
176+
177+
178+
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """Plain text with no tool-call markers streams through as content."""
    delta = " without any tool calls."
    previous = "This is just regular text"

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=previous,
        current_text=previous + delta,
        delta_text=delta,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The whole delta must be surfaced unchanged as content.
    assert result is not None
    assert hasattr(result, 'content')
    assert result.content == delta

vllm/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def is_deepseek_mla(self) -> bool:
11431143
if not hasattr(self.hf_text_config, "model_type"):
11441144
return False
11451145
elif self.hf_text_config.model_type in \
1146-
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
1146+
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
11471147
return self.hf_text_config.kv_lora_rank is not None
11481148
elif self.hf_text_config.model_type == 'eagle':
11491149
# if the model is an EAGLE module, check for the

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .hermes_tool_parser import Hermes2ProToolParser
99
from .internlm2_tool_parser import Internlm2ToolParser
1010
from .jamba_tool_parser import JambaToolParser
11+
from .kimi_k2_tool_parser import KimiK2ToolParser
1112
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
1213
from .llama_tool_parser import Llama3JsonToolParser
1314
from .minimax_tool_parser import MinimaxToolParser
@@ -21,5 +22,6 @@
2122
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
2223
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
2324
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
24-
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
25+
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
26+
"KimiK2ToolParser"
2527
]

0 commit comments

Comments
 (0)