-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
[Feature] Add command tool parser for Command-A model #20800
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,151 @@ | ||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||||||||
|
||||||||||||
import json | ||||||||||||
from collections.abc import Sequence | ||||||||||||
from typing import Union | ||||||||||||
|
||||||||||||
import partial_json_parser | ||||||||||||
import regex as re | ||||||||||||
from partial_json_parser.core.options import Allow | ||||||||||||
from transformers import PreTrainedTokenizerBase | ||||||||||||
|
||||||||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, | ||||||||||||
DeltaFunctionCall, DeltaMessage, | ||||||||||||
DeltaToolCall, | ||||||||||||
ExtractedToolCallInformation, | ||||||||||||
FunctionCall, ToolCall) | ||||||||||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( | ||||||||||||
ToolParser, ToolParserManager) | ||||||||||||
from vllm.logger import init_logger | ||||||||||||
from vllm.utils import random_uuid | ||||||||||||
|
||||||||||||
logger = init_logger(__name__) | ||||||||||||
|
||||||||||||
|
||||||||||||
@ToolParserManager.register_module("command")
class CommandToolParser(ToolParser):
    """Parse tool calls emitted inside ``<|START_ACTION|>`` blocks.

    The text between ``<|START_ACTION|>`` and ``<|END_ACTION|>`` is read as
    JSON: either a single object or a list of objects, each carrying a
    ``tool_name`` and an optional ``parameters`` mapping.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Set up delimiters, the block regex and the streaming state.

        Raises:
            RuntimeError: if either delimiter token is missing from the
                tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        # Streaming state. One entry is appended to each list per tool call
        # that has been streamed; current_tool_id is the index of the most
        # recently emitted call (-1 while nothing has been emitted).
        # NOTE(review): the base class presumably initialises these as well
        # and the serving layer presumably reads them - confirm.
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False

        # Action delimiters
        self.tool_call_start_token = "<|START_ACTION|>"
        self.tool_call_end_token = "<|END_ACTION|>"
        self.tool_call_regex = re.compile(
            r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", re.DOTALL)

        # Precompute token ids
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "CommandToolParser cannot find start/end tokens in vocab")

    @staticmethod
    def _normalize_calls(parsed) -> list[dict]:
        """Return the parsed payload as a list of JSON objects.

        A single object is wrapped in a list so both payload shapes are
        handled the same way.

        Raises:
            ValueError: if any entry is not a JSON object.
        """
        calls = parsed if isinstance(parsed, list) else [parsed]
        if not all(isinstance(call, dict) for call in calls):
            raise ValueError("Tool call entries must be JSON objects")
        return calls

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model response.

        Every complete action block in ``model_output`` is parsed. Text
        preceding the first start token is returned as ``content``. If no
        start token is present, or anything fails while parsing, the whole
        output is returned as plain content with ``tools_called=False``.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            payloads = self.tool_call_regex.findall(model_output)
            if not payloads:
                raise ValueError("No action block found")
            tool_calls = []
            for payload in payloads:
                for entry in self._normalize_calls(json.loads(payload)):
                    arguments = json.dumps(entry.get("parameters", {}),
                                           ensure_ascii=False)
                    tool_calls.append(
                        ToolCall(type="function",
                                 function=FunctionCall(
                                     name=entry.get("tool_name"),
                                     arguments=arguments)))
            # content before action
            prefix = model_output.split(self.tool_call_start_token, 1)[0]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=prefix or None)
        except Exception:
            # Deliberate best-effort boundary: a malformed block degrades to
            # plain content instead of failing the request.
            logger.exception("Error extracting sync tool calls")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def _deltas_for_block(self, payload: str) -> list[DeltaToolCall]:
        """Build one streaming delta per tool call in a closed action block.

        Each call receives the next consecutive index, and what was emitted
        is recorded in ``prev_tool_call_arr`` / ``streamed_args_for_tool``.
        A payload that cannot be parsed yields an empty list.
        """
        try:
            calls = self._normalize_calls(
                partial_json_parser.loads(payload.strip() or "[]",
                                          Allow.ALL))
        except (partial_json_parser.core.exceptions.MalformedJSON,
                ValueError):
            # ValueError also covers json.JSONDecodeError.
            logger.debug("Malformed JSON payload: %s", payload)
            return []

        deltas = []
        for entry in calls:
            name = entry.get("tool_name")
            params = entry.get("parameters", {})
            args = json.dumps(params, ensure_ascii=False)

            # Advance before use so indices run 0, 1, 2, ... across all
            # calls of the response, whichever block they came from.
            self.current_tool_id += 1
            self.prev_tool_call_arr.append({
                "name": name,
                "arguments": params
            })
            self.streamed_args_for_tool.append(args)

            deltas.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                    id=f"chatcmpl-tool-{random_uuid()}",
                    function=DeltaFunctionCall(
                        name=name,
                        arguments=args,
                    ).model_dump(exclude_none=True),
                ))
        return deltas

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streaming step into a delta message.

        Text is passed through as content until the first start token is
        seen. While an action block is open nothing is emitted. When one or
        more end tokens arrive in this step, the blocks they close are
        parsed from ``current_text`` and emitted as complete tool calls.

        Returns:
            A ``DeltaMessage`` with content or tool calls, or ``None`` when
            there is nothing to emit for this step.
        """
        # NOTE: counting over the full token list on every step is linear in
        # the output length; cache the counts if this shows up in profiles.
        start_id = self.tool_call_start_token_id
        end_id = self.tool_call_end_token_id
        cur_start = current_token_ids.count(start_id)

        # Case 1: no block started yet -> text as is
        if cur_start == 0:
            return DeltaMessage(content=delta_text)

        prev_end = previous_token_ids.count(end_id)
        cur_end = current_token_ids.count(end_id)

        # Case 2: at least one block was closed by this step. Detection is
        # by token id so a step that carries both the start and the end
        # token is still parsed. current_text already includes delta_text,
        # so it is searched on its own.
        if cur_end > prev_end:
            blocks = self.tool_call_regex.findall(current_text)
            deltas = []
            # Only the blocks closed in this step; earlier ones were
            # emitted already. An empty slice (e.g. the marker text is
            # absent from current_text) simply yields no deltas.
            for payload in blocks[prev_end:cur_end]:
                deltas.extend(self._deltas_for_block(payload))
            return DeltaMessage(tool_calls=deltas) if deltas else None

        # Case 3: inside a block that is not closed yet -> nothing to emit
        if cur_start > cur_end:
            return None

        # Case 4: all blocks closed earlier -> trailing text is content
        return DeltaMessage(content=delta_text)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The state variables `prev_tool_call_arr`, `streamed_args_for_tool`, and `current_tool_name_sent` are re-declared here, but they are already inherited from the `ToolParser` base class and initialized in `super().__init__()`. These variables are also unused within this class. Remove these lines to avoid shadowing the parent class's attributes and improve code clarity.