Skip to content

Commit 8b70dd7

Browse files
committed
[fix] type annotation in test scripts
1 parent 4f980f3 commit 8b70dd7

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

python/mlc_llm/protocol/openai_api_protocol.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def check_request_response_format(self) -> "RequestResponseFormat":
106106
for tag in self.tags:
107107
if set(tag.keys()) != {"begin", "schema", "end"}:
108108
raise ValueError(
109-
f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}."
109+
"Each tag must contain exactly 'begin', 'schema' and 'end' keys."
110+
f"Got keys: {list(tag.keys())}."
110111
)
111112
elif self.tags is not None or self.triggers is not None:
112113
raise Warning(

tests/python/serve/server/test_server_structural_tag.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import json
1212
import os
1313
import re
14-
from typing import Dict, List, Optional, Tuple
14+
from typing import Any, Dict, List, Optional
1515

1616
import pytest
1717
import requests
@@ -136,16 +136,21 @@ def check_openai_stream_response(
136136

137137

138138
def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
139-
schema = json.loads(schema)
139+
try:
140+
paras: Dict[str, Any] = json.loads(schema)
141+
except json.JSONDecodeError as e:
142+
print(f"Invalid JSON format: {e}")
143+
assert False
144+
assert "hash_code" in paras
140145
assert "hash_code" in schema
141-
hash_code = schema["hash_code"]
146+
hash_code = paras["hash_code"]
142147
assert hash_code in CHECK_INFO
143148
info = CHECK_INFO[hash_code]
144149
assert name_beg == info["name"]
145150
assert name_end == info["name"]
146151
assert beg_tag == info["beg_tag"]
147152
for key in info["required"]:
148-
assert key in schema
153+
assert key in paras
149154

150155

151156
# NOTE: the end-tag format and the hash_code number is been hidden in the SYSTEM_PROMPT.
@@ -219,7 +224,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
219224
You are a helpful assistant.""",
220225
}
221226

222-
223227
STRUCTURAL_TAGS = {
224228
"triggers": ["<CALL--->", "<call--->"],
225229
"tags": [
@@ -337,7 +341,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
337341
},
338342
}
339343

340-
341344
CHAT_COMPLETION_MESSAGES = [
342345
# messages #0
343346
[

0 commit comments

Comments
 (0)