11
11
import json
12
12
import os
13
13
import re
14
- from typing import Dict , List , Optional , Tuple
14
+ from typing import Dict , List , Optional
15
15
16
16
import pytest
17
17
import requests
20
20
21
21
22
22
def check_openai_nonstream_response (
23
- response : Dict ,
24
- * ,
25
- model : str ,
26
- object_str : str ,
27
- num_choices : int ,
28
- finish_reason : List [str ],
29
- completion_tokens : Optional [int ] = None ,
23
+ response : Dict ,
24
+ * ,
25
+ model : str ,
26
+ object_str : str ,
27
+ num_choices : int ,
28
+ finish_reason : List [str ],
29
+ completion_tokens : Optional [int ] = None ,
30
30
):
31
31
assert response ["model" ] == model
32
32
assert response ["object" ] == object_str
@@ -68,16 +68,16 @@ def check_openai_nonstream_response(
68
68
69
69
70
70
def check_openai_stream_response (
71
- responses : List [Dict ],
72
- * ,
73
- model : str ,
74
- object_str : str ,
75
- num_choices : int ,
76
- finish_reason : str ,
77
- echo_prompt : Optional [str ] = None ,
78
- suffix : Optional [str ] = None ,
79
- stop : Optional [List [str ]] = None ,
80
- require_substr : Optional [List [str ]] = None ,
71
+ responses : List [Dict ],
72
+ * ,
73
+ model : str ,
74
+ object_str : str ,
75
+ num_choices : int ,
76
+ finish_reason : str ,
77
+ echo_prompt : Optional [str ] = None ,
78
+ suffix : Optional [str ] = None ,
79
+ stop : Optional [List [str ]] = None ,
80
+ require_substr : Optional [List [str ]] = None ,
81
81
):
82
82
assert len (responses ) > 0
83
83
@@ -219,7 +219,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
219
219
You are a helpful assistant.""" ,
220
220
}
221
221
222
-
223
222
STRUCTURAL_TAGS = {
224
223
"triggers" : ["<CALL--->" , "<call--->" ],
225
224
"tags" : [
@@ -236,14 +235,14 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
236
235
"state" : {
237
236
"type" : "string" ,
238
237
"description" : "the two-letter abbreviation for the state that the city is"
239
- " in, e.g. 'CA' which would mean 'California'" ,
238
+ " in, e.g. 'CA' which would mean 'California'" ,
240
239
},
241
240
"unit" : {
242
241
"type" : "string" ,
243
242
"description" : "The unit to fetch the temperature in" ,
244
243
"enum" : ["celsius" , "fahrenheit" ],
245
244
},
246
- "hash_code" : {"const" : 1234 },
245
+ "hash_code" : {"const" : " 1234" },
247
246
},
248
247
"required" : ["city" , "state" , "unit" , "hash_code" ],
249
248
}
@@ -260,7 +259,7 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
260
259
"type" : "string" ,
261
260
"description" : "The timezone to fetch the current date and time for, e.g. 'America/New_York'" ,
262
261
},
263
- "hash_code" : {"const" : 2345 },
262
+ "hash_code" : {"const" : " 2345" },
264
263
},
265
264
"required" : ["timezone" , "hash_code" ],
266
265
}
@@ -280,14 +279,14 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
280
279
"state" : {
281
280
"type" : "string" ,
282
281
"description" : "the two-letter abbreviation for the state that the city is"
283
- " in, e.g. 'CA' which would mean 'California'" ,
282
+ " in, e.g. 'CA' which would mean 'California'" ,
284
283
},
285
284
"unit" : {
286
285
"type" : "string" ,
287
286
"description" : "The unit to fetch the temperature in" ,
288
287
"enum" : ["celsius" , "fahrenheit" ],
289
288
},
290
- "hash_code" : {"const" : 3456 },
289
+ "hash_code" : {"const" : " 3456" },
291
290
},
292
291
"required" : ["city" , "state" , "unit" , "hash_code" ],
293
292
}
@@ -304,7 +303,7 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
304
303
"type" : "string" ,
305
304
"description" : "The timezone to fetch the current date and time for, e.g. 'America/New_York'" ,
306
305
},
307
- "hash_code" : {"const" : 4567 },
306
+ "hash_code" : {"const" : " 4567" },
308
307
},
309
308
"required" : ["timezone" , "hash_code" ],
310
309
}
@@ -315,29 +314,28 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
315
314
}
316
315
317
316
CHECK_INFO = {
318
- 1234 : {
317
+ " 1234" : {
319
318
"name" : "get_current_weather" ,
320
319
"beg_tag" : "CALL" ,
321
320
"required" : ["city" , "state" , "unit" , "hash_code" ],
322
321
},
323
- 2345 : {
322
+ " 2345" : {
324
323
"name" : "get_current_date" ,
325
324
"beg_tag" : "CALL" ,
326
325
"required" : ["timezone" , "hash_code" ],
327
326
},
328
- 3456 : {
327
+ " 3456" : {
329
328
"name" : "get_current_weather" ,
330
329
"beg_tag" : "call" ,
331
330
"required" : ["city" , "state" , "unit" , "hash_code" ],
332
331
},
333
- 4567 : {
332
+ " 4567" : {
334
333
"name" : "get_current_date" ,
335
334
"beg_tag" : "call" ,
336
335
"required" : ["timezone" , "hash_code" ],
337
336
},
338
337
}
339
338
340
-
341
339
CHAT_COMPLETION_MESSAGES = [
342
340
# messages #0
343
341
[
@@ -369,10 +367,10 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
369
367
@pytest .mark .parametrize ("stream" , [False , True ])
370
368
@pytest .mark .parametrize ("messages" , CHAT_COMPLETION_MESSAGES )
371
369
def test_openai_v1_chat_completion_structural_tag (
372
- served_model : str ,
373
- launch_server , # pylint: disable=unused-argument
374
- stream : bool ,
375
- messages : List [Dict [str , str ]],
370
+ served_model : str ,
371
+ launch_server , # pylint: disable=unused-argument
372
+ stream : bool ,
373
+ messages : List [Dict [str , str ]],
376
374
):
377
375
# `served_model` and `launch_server` are pytest fixtures
378
376
# defined in conftest.py.
0 commit comments