@@ -37,14 +37,16 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
37
37
assert len (actual_tool_call .id ) == 9
38
38
39
39
assert actual_tool_call .type == "function"
40
- assert actual_tool_call .function == expected_tool_call .function , f'got ${ actual_tool_call .function } '
40
+ assert actual_tool_call .function == expected_tool_call .function , (
41
+ f'got ${ actual_tool_call .function } ' )
41
42
42
43
43
44
def stream_delta_message_generator (
44
- mistral_tool_parser : MistralToolParser , mistral_tokenizer : AnyTokenizer ,
45
+ mistral_tool_parser : MistralToolParser ,
46
+ mistral_tokenizer : AnyTokenizer ,
45
47
model_output : str ) -> Generator [DeltaMessage , None , None ]:
46
48
all_token_ids = mistral_tokenizer .encode (model_output ,
47
- add_special_tokens = False )
49
+ add_special_tokens = False )
48
50
49
51
previous_text = ""
50
52
previous_tokens = None
@@ -98,9 +100,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
98
100
99
101
@pytest .mark .parametrize (
100
102
ids = [
101
- "single_tool_add" ,
102
- "single_tool_weather" ,
103
- "argument_before_name" ,
103
+ "single_tool_add" , "single_tool_weather" , "argument_before_name" ,
104
104
"argument_before_name_and_name_in_argument"
105
105
],
106
106
argnames = ["model_output" , "expected_tool_calls" , "expected_content" ],
@@ -109,11 +109,10 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
109
109
'''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''' , # noqa: E501
110
110
[
111
111
ToolCall (function = FunctionCall (name = "add" ,
112
- arguments = json .dumps (
113
- {
114
- "a" : 3.5 ,
115
- "b" : 4
116
- })))
112
+ arguments = json .dumps ({
113
+ "a" : 3.5 ,
114
+ "b" : 4
115
+ })))
117
116
],
118
117
None ),
119
118
(
@@ -125,7 +124,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
125
124
"city" : "San Francisco" ,
126
125
"state" : "CA" ,
127
126
"unit" : "celsius"
128
- })))
127
+ })))
129
128
],
130
129
None ),
131
130
(
@@ -137,17 +136,17 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
137
136
"city" : "San Francisco" ,
138
137
"state" : "CA" ,
139
138
"unit" : "celsius"
140
- })))
139
+ })))
141
140
],
142
141
None ),
143
142
(
144
143
'''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''' , # noqa: E501
145
144
[
146
145
ToolCall (function = FunctionCall (name = "get_age" ,
147
- arguments = json .dumps (
148
- {
149
- "name" : "John Doe" ,
150
- })))
146
+ arguments = json .dumps ({
147
+ "name" :
148
+ "John Doe" ,
149
+ })))
151
150
],
152
151
None ),
153
152
],
@@ -180,22 +179,20 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
180
179
'''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''' , # noqa: E501
181
180
[
182
181
ToolCall (function = FunctionCall (name = "add" ,
183
- arguments = json .dumps (
184
- {
185
- "a" : 3 ,
186
- "b" : 4
187
- })))
182
+ arguments = json .dumps ({
183
+ "a" : 3 ,
184
+ "b" : 4
185
+ })))
188
186
],
189
187
"" ),
190
188
(
191
189
'''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''' , # noqa: E501
192
190
[
193
191
ToolCall (function = FunctionCall (name = "add" ,
194
- arguments = json .dumps (
195
- {
196
- "a" : "3" ,
197
- "b" : "4"
198
- })))
192
+ arguments = json .dumps ({
193
+ "a" : "3" ,
194
+ "b" : "4"
195
+ })))
199
196
],
200
197
"" ),
201
198
(
@@ -207,7 +204,7 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
207
204
"city" : "San Francisco" ,
208
205
"state" : "CA" ,
209
206
"unit" : "celsius"
210
- })))
207
+ })))
211
208
],
212
209
"" ),
213
210
(
@@ -219,35 +216,34 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
219
216
"city" : "San Francisco" ,
220
217
"state" : "CA" ,
221
218
"unit" : "celsius"
222
- })))
219
+ })))
223
220
],
224
221
'' ),
225
222
(
226
223
'''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''' , # noqa: E501
227
224
[
228
225
ToolCall (function = FunctionCall (name = "get_age" ,
229
- arguments = json .dumps (
230
- {
231
- "name" : "John Doe" ,
232
- })))
226
+ arguments = json .dumps ({
227
+ "name" :
228
+ "John Doe" ,
229
+ })))
233
230
],
234
231
'' ),
235
232
(
236
233
'''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''' , # noqa: E501
237
234
[
238
235
ToolCall (function = FunctionCall (name = "add" ,
239
- arguments = json .dumps (
240
- {
241
- "a" : 3.5 ,
242
- "b" : 4
243
- }))),
236
+ arguments = json .dumps ({
237
+ "a" : 3.5 ,
238
+ "b" : 4
239
+ }))),
244
240
ToolCall (function = FunctionCall (name = "get_current_weather" ,
245
241
arguments = json .dumps (
246
242
{
247
243
"city" : "San Francisco" ,
248
244
"state" : "CA" ,
249
245
"unit" : "celsius"
250
- })))
246
+ })))
251
247
],
252
248
'' ),
253
249
],
0 commit comments