
[Feature] Add command tool parser for Command-A model #20800


Status: Open. Wants to merge 3 commits into base: main.
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .abstract_tool_parser import ToolParser, ToolParserManager
from .command_tool_parser import CommandToolParser
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
@@ -23,5 +24,5 @@
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
"KimiK2ToolParser"
"KimiK2ToolParser", "CommandToolParser"
]
151 changes: 151 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/command_tool_parser.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence
from typing import Union

import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("command")
class CommandToolParser(ToolParser):

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Streaming state
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
Comment on lines +32 to +35
Contributor review (medium): The state variables prev_tool_call_arr, streamed_args_for_tool, and current_tool_name_sent are re-declared here, but they are already inherited from the ToolParser base class and initialized in super().__init__(). They are also unused within this class. Remove these lines to avoid shadowing the parent class's attributes and to improve code clarity.

Suggested change:
-        self.prev_tool_call_arr: list[dict] = []
-        self.streamed_args_for_tool: list[str] = []
-        self.current_tool_id: int = -1
-        self.current_tool_name_sent: bool = False
+        self.current_tool_id: int = -1

        # Action delimiters
        self.tool_call_start_token = "<|START_ACTION|>"
        self.tool_call_end_token = "<|END_ACTION|>"
        self.tool_call_regex = re.compile(
            r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", re.DOTALL)

        # Precompute token ids
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "CommandToolParser cannot find start/end tokens in vocab")

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        # Synchronous parsing: look for a full action block
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            match = self.tool_call_regex.search(model_output)
            if not match:
                raise ValueError("No action block found")
            payload = match.group(1)
            raw_calls = json.loads(payload)
            tool_calls = []
            for entry in raw_calls:
                name = entry.get("tool_name")
                params = entry.get("parameters", {})
                tool_calls.append(
                    ToolCall(type="function",
                             function=FunctionCall(name=name,
                                                   arguments=json.dumps(
                                                       params,
                                                       ensure_ascii=False))))
            # content before the action block
            prefix = model_output.split(self.tool_call_start_token, 1)[0]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=prefix or None)
        except Exception:
Contributor review (medium): Catching a broad Exception can hide specific issues and make debugging harder. It's better to catch the specific exceptions you expect to handle, such as json.JSONDecodeError or ValueError. This provides better error context and avoids accidentally catching unrelated exceptions.

Suggested change:
-        except Exception:
+        except (json.JSONDecodeError, ValueError):
            logger.exception("Error extracting sync tool calls")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:

        prev_start = previous_token_ids.count(self.tool_call_start_token_id)
        cur_start = current_token_ids.count(self.tool_call_start_token_id)
        cur_end = current_token_ids.count(self.tool_call_end_token_id)
Comment on lines +98 to +100
Contributor review (high): The count() method is called on previous_token_ids and current_token_ids on every invocation of this streaming method. Since these lists can grow very large for long generations, this O(N) work per streaming chunk can become a performance bottleneck. Consider maintaining the counts as parser state and updating them incrementally from each new delta_token_ids.

        # Case 1: Block not started → pass the text through as-is
        if cur_start == 0:
            return DeltaMessage(content=delta_text)

        # Case 2: Starting a new block
        if cur_start > prev_start:
            self.current_tool_id += 1
Contributor review (high): This increment, combined with the one in the parsing loop (line 147), causes non-contiguous tool call indices (e.g., 0, 1, 3, ...) when the model output contains multiple action blocks. Tool call indices for a single response must be contiguous, starting from 0. To fix this, current_tool_id should only be incremented inside the parsing loop (lines 132-148), once per parsed tool call.
            return None

        # Case 3: Inside a block that is not yet closed → ignore
        if cur_start > cur_end:
            return None

        # Case 4: Block end point
        if cur_start == cur_end and self.tool_call_end_token in delta_text:
Comment on lines +102 to +116
Contributor review (high): The streaming logic has a flaw: when a single delta_text contains both the end of one action block and the start of another (e.g., ...<|END_ACTION|><|START_ACTION|>...), cur_start will be greater than cur_end, so the completed-block check at line 116 is skipped. The if cur_start > prev_start: check at line 107 will be true and the method will return None, silently dropping the block that just ended. The logic should handle a block ending before checking for a new block starting within the same chunk.

Contributor review (high): Checking for the end token as a substring (self.tool_call_end_token in delta_text) is not robust. If the tokenizer splits the end token string across multiple tokens, this check will fail and the tool call block will never be processed. A more reliable approach is to check for the end token's ID in delta_token_ids.

Suggested change:
-        if cur_start == cur_end and self.tool_call_end_token in delta_text:
+        if cur_start == cur_end and self.tool_call_end_token_id in delta_token_ids:
            full = current_text + delta_text
Contributor review (high): The current_text parameter already contains the full accumulated text of the stream, including delta_text. Concatenating current_text + delta_text duplicates the last chunk, which is a bug and can cause the payload extraction to fail. full should just be current_text.

Suggested change:
-            full = current_text + delta_text
+            full = current_text

            payload = full.split(self.tool_call_start_token, 1)[1] \
                .split(self.tool_call_end_token, 1)[0].strip()
            try:
                calls = partial_json_parser.loads(payload or "[]", Allow.ALL)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("Waiting for complete JSON")
                return None
            except json.JSONDecodeError:
                logger.debug("Malformed JSON payload: %s", payload)
                return None

            calls_list = calls if isinstance(calls, list) else [calls]
            deltas = []
            for entry in calls_list:
                name = entry.get("tool_name")
                params = entry.get("parameters", {})
                args = json.dumps(params, ensure_ascii=False)
                deltas.append(
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=f"chatcmpl-tool-{random_uuid()}",
                        function=DeltaFunctionCall(
                            name=name,
                            arguments=args,
                        ).model_dump(exclude_none=True),
                    ))

                self.current_tool_id += 1

            return DeltaMessage(tool_calls=deltas)

        return DeltaMessage(content=delta_text)
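For reference, the action format this parser targets can be exercised standalone. The sketch below uses plain json and the same delimiters the PR defines; the tool name and parameters are made-up examples, and this is not the PR's code path:

```python
import json
import re

# Example Command-A style output: free text followed by one action block.
out = ('I will call a tool.\n'
       '<|START_ACTION|>[{"tool_name": "get_weather",'
       ' "parameters": {"city": "Paris"}}]<|END_ACTION|>')

# Same pattern as the parser's tool_call_regex.
match = re.search(r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", out, re.DOTALL)
calls = json.loads(match.group(1))

name = calls[0]["tool_name"]                    # tool to invoke
args = json.dumps(calls[0]["parameters"])       # serialized arguments
prefix = out.split("<|START_ACTION|>", 1)[0]    # content before the action
```

This mirrors what extract_tool_calls does synchronously: find the delimited block, json-decode the list of {tool_name, parameters} entries, and return any leading text as the message content.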