Skip to content

Commit d6d17c1

Browse files
committed
repair ruff pre-commit
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
1 parent 92601d9 commit d6d17c1

File tree

2 files changed

+151
-107
lines changed

2 files changed

+151
-107
lines changed

tests/tool_use/test_mistral_tool_parser.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
3737
assert len(actual_tool_call.id) == 9
3838

3939
assert actual_tool_call.type == "function"
40-
assert actual_tool_call.function == expected_tool_call.function, f'got ${actual_tool_call.function}'
40+
assert actual_tool_call.function == expected_tool_call.function, (
41+
f'got ${actual_tool_call.function}')
4142

4243

4344
def stream_delta_message_generator(
44-
mistral_tool_parser: MistralToolParser, mistral_tokenizer: AnyTokenizer,
45+
mistral_tool_parser: MistralToolParser,
46+
mistral_tokenizer: AnyTokenizer,
4547
model_output: str) -> Generator[DeltaMessage, None, None]:
4648
all_token_ids = mistral_tokenizer.encode(model_output,
47-
add_special_tokens=False)
49+
add_special_tokens=False)
4850

4951
previous_text = ""
5052
previous_tokens = None
@@ -98,9 +100,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
98100

99101
@pytest.mark.parametrize(
100102
ids=[
101-
"single_tool_add",
102-
"single_tool_weather",
103-
"argument_before_name",
103+
"single_tool_add", "single_tool_weather", "argument_before_name",
104104
"argument_before_name_and_name_in_argument"
105105
],
106106
argnames=["model_output", "expected_tool_calls", "expected_content"],
@@ -109,11 +109,10 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
109109
'''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501
110110
[
111111
ToolCall(function=FunctionCall(name="add",
112-
arguments=json.dumps(
113-
{
114-
"a": 3.5,
115-
"b": 4
116-
})))
112+
arguments=json.dumps({
113+
"a": 3.5,
114+
"b": 4
115+
})))
117116
],
118117
None),
119118
(
@@ -125,7 +124,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
125124
"city": "San Francisco",
126125
"state": "CA",
127126
"unit": "celsius"
128-
})))
127+
})))
129128
],
130129
None),
131130
(
@@ -137,17 +136,17 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser):
137136
"city": "San Francisco",
138137
"state": "CA",
139138
"unit": "celsius"
140-
})))
139+
})))
141140
],
142141
None),
143142
(
144143
'''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501
145144
[
146145
ToolCall(function=FunctionCall(name="get_age",
147-
arguments=json.dumps(
148-
{
149-
"name": "John Doe",
150-
})))
146+
arguments=json.dumps({
147+
"name":
148+
"John Doe",
149+
})))
151150
],
152151
None),
153152
],
@@ -180,22 +179,20 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
180179
'''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501
181180
[
182181
ToolCall(function=FunctionCall(name="add",
183-
arguments=json.dumps(
184-
{
185-
"a": 3,
186-
"b": 4
187-
})))
182+
arguments=json.dumps({
183+
"a": 3,
184+
"b": 4
185+
})))
188186
],
189187
""),
190188
(
191189
'''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501
192190
[
193191
ToolCall(function=FunctionCall(name="add",
194-
arguments=json.dumps(
195-
{
196-
"a": "3",
197-
"b": "4"
198-
})))
192+
arguments=json.dumps({
193+
"a": "3",
194+
"b": "4"
195+
})))
199196
],
200197
""),
201198
(
@@ -207,7 +204,7 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
207204
"city": "San Francisco",
208205
"state": "CA",
209206
"unit": "celsius"
210-
})))
207+
})))
211208
],
212209
""),
213210
(
@@ -219,35 +216,34 @@ def test_extract_tool_calls(mistral_tool_parser, model_output,
219216
"city": "San Francisco",
220217
"state": "CA",
221218
"unit": "celsius"
222-
})))
219+
})))
223220
],
224221
''),
225222
(
226223
'''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501
227224
[
228225
ToolCall(function=FunctionCall(name="get_age",
229-
arguments=json.dumps(
230-
{
231-
"name": "John Doe",
232-
})))
226+
arguments=json.dumps({
227+
"name":
228+
"John Doe",
229+
})))
233230
],
234231
''),
235232
(
236233
'''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501
237234
[
238235
ToolCall(function=FunctionCall(name="add",
239-
arguments=json.dumps(
240-
{
241-
"a": 3.5,
242-
"b": 4
243-
}))),
236+
arguments=json.dumps({
237+
"a": 3.5,
238+
"b": 4
239+
}))),
244240
ToolCall(function=FunctionCall(name="get_current_weather",
245241
arguments=json.dumps(
246242
{
247243
"city": "San Francisco",
248244
"state": "CA",
249245
"unit": "celsius"
250-
})))
246+
})))
251247
],
252248
''),
253249
],

0 commit comments

Comments
 (0)