@@ -44,7 +44,8 @@ def __init__(self, tokenizer: AnyTokenizer):
44
44
self .tool_call_end_token : str = "</tool_call>"
45
45
46
46
self .tool_call_regex = re .compile (
47
- r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)" , re .DOTALL )
47
+ r"<tool_call>\s*(.*?)\s*(?:<tool_call>\s*|</tool_call>\s*|$)" ,
48
+ re .DOTALL )
48
49
self .scratch_pad_regex = re .compile (
49
50
r"<scratch_pad>(.*?)</scratch_pad>" , re .DOTALL )
50
51
@@ -80,15 +81,18 @@ def extract_tool_calls(
80
81
# tag and end-of-string so the result of
81
82
# findall is an array of tuples where one is a function call and
82
83
# the other is None
83
- function_call_tuples = (
84
- self .tool_call_regex .findall (model_output ))
85
-
86
- # load the JSON, and then use it to build the Function and
87
- # Tool Call
88
- raw_function_calls = [
89
- json .loads (match [0 ] if match [0 ] else match [1 ])
90
- for match in function_call_tuples
91
- ]
84
+ matches = self .tool_call_regex .findall (model_output )
85
+ raw_function_calls = []
86
+ for match in matches :
87
+ if not match :
88
+ continue
89
+ try :
90
+ parsed = json .loads (match .strip ())
91
+ raw_function_calls .append (parsed )
92
+ except json .JSONDecodeError as e :
93
+ logger .warning (
94
+ "Skipping malformed tool_call block: %s" , e )
95
+
92
96
tool_calls = [
93
97
ToolCall (
94
98
type = "function" ,
@@ -99,9 +103,9 @@ def extract_tool_calls(
99
103
ensure_ascii = False )))
100
104
for function_call in raw_function_calls
101
105
]
102
-
103
- content = model_output [:model_output .
104
- find ( self . tool_call_start_token )]
106
+ tool_call_start = model_output . find ( self . tool_call_start_token )
107
+ content = ( model_output [:tool_call_start ]
108
+ if tool_call_start >= 0 else None )
105
109
return ExtractedToolCallInformation (
106
110
tools_called = True ,
107
111
tool_calls = tool_calls ,
0 commit comments