Skip to content

Commit c3426c1

Browse files
committed
[Bugfix] Enhance tool call regex and improve error handling in Hermes tool parser
- Updated regex to better capture tool calls, ensuring proper whitespace handling. - Refactored function call extraction to include error handling for malformed JSON, logging warnings as necessary. Signed-off-by: Adrien <adrien@huggingface.co>
1 parent 545b848 commit c3426c1

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, tokenizer: AnyTokenizer):
4444
self.tool_call_end_token: str = "</tool_call>"
4545

4646
self.tool_call_regex = re.compile(
47-
r"(?:<tool_call>\s*)+(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
47+
r"<tool_call>\s*(.*?)\s*(?:<tool_call>\s*|</tool_call>\s*|$)", re.DOTALL)
4848
self.scratch_pad_regex = re.compile(
4949
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
5050

@@ -80,15 +80,17 @@ def extract_tool_calls(
8080
# tag and end-of-string so the result of
8181
# findall is an array of tuples where one is a function call and
8282
# the other is None
83-
function_call_tuples = (
84-
self.tool_call_regex.findall(model_output))
85-
86-
# load the JSON, and then use it to build the Function and
87-
# Tool Call
88-
raw_function_calls = [
89-
json.loads(match[0] if match[0] else match[1])
90-
for match in function_call_tuples
91-
]
83+
matches = self.tool_call_regex.findall(model_output)
84+
raw_function_calls = []
85+
for match in matches:
86+
if not match:
87+
continue
88+
try:
89+
parsed = json.loads(match.strip())
90+
raw_function_calls.append(parsed)
91+
except json.JSONDecodeError as e:
92+
logger.warning("Skipping malformed tool_call block: %s", e)
93+
9294
tool_calls = [
9395
ToolCall(
9496
type="function",
@@ -99,9 +101,8 @@ def extract_tool_calls(
99101
ensure_ascii=False)))
100102
for function_call in raw_function_calls
101103
]
102-
103-
content = model_output[:model_output.
104-
find(self.tool_call_start_token)]
104+
tool_call_start = model_output.find(self.tool_call_start_token)
105+
content = model_output[:tool_call_start] if tool_call_start >= 0 else None
105106
return ExtractedToolCallInformation(
106107
tools_called=True,
107108
tool_calls=tool_calls,

0 commit comments

Comments
 (0)