
Commit 135419e

[Test][Refactor] Update tests to use require_test_model (#2415)
This PR updates tests to use the `require_test_model` testing util, which gives better out-of-the-box testing while avoiding automatic downloads. Tests that require manual model compilation are kept in the old style (e.g., the "llava" and "eagle" models). This PR also fixes some typing issues reported by mypy.
1 parent cfc0597 commit 135419e

11 files changed: +290 −174 lines
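For reference, the test pattern this PR moves to looks roughly like the sketch below. It is illustrative only: the model name and the skip-on-missing behavior are taken from the diffs that follow, while the exact engine calls are simplified.

```python
# Sketch of the test style introduced by this PR (illustrative, not part of the diff).
# `require_test_model` resolves the model name against local candidate paths and
# skips the test instead of downloading when the weights are not found.
from mlc_llm.serve import MLCEngine
from mlc_llm.testing import require_test_model


@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
def test_chat_smoke(model: str):
    engine = MLCEngine(model=model)
    response = engine.chat.completions.create(
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        stream=False,
    )
    assert len(response.choices) > 0
    engine.terminate()
```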

python/mlc_llm/serve/engine.py

Lines changed: 24 additions & 13 deletions
@@ -32,16 +32,21 @@
 logger = logging.getLogger(__name__)
 
 
+# Note: we define both AsyncChat and Chat for Python type analysis.
+class AsyncChat:  # pylint: disable=too-few-public-methods
+    """The proxy class to direct to async chat completions."""
+
+    def __init__(self, engine: weakref.ReferenceType) -> None:
+        assert isinstance(engine(), AsyncMLCEngine)
+        self.completions = AsyncChatCompletion(engine)
+
+
 class Chat:  # pylint: disable=too-few-public-methods
     """The proxy class to direct to chat completions."""
 
     def __init__(self, engine: weakref.ReferenceType) -> None:
-        assert isinstance(engine(), (AsyncMLCEngine, MLCEngine))
-        self.completions = (
-            AsyncChatCompletion(engine)  # type: ignore
-            if isinstance(engine(), AsyncMLCEngine)
-            else ChatCompletion(engine)  # type: ignore
-        )
+        assert isinstance(engine(), MLCEngine)
+        self.completions = ChatCompletion(engine)
 
 
 class AsyncChatCompletion:  # pylint: disable=too-few-public-methods

@@ -151,7 +156,7 @@ async def create(  # pylint: disable=too-many-arguments,too-many-locals
             Extra debug options to pass to the request.
 
         Returns
-        ------
+        -------
         response : ChatCompletionResponse
             The chat completion response conforming to OpenAI API.
             See mlc_llm/protocol/openai_api_protocol.py or

@@ -643,7 +648,7 @@ def create(  # pylint: disable=too-many-arguments,too-many-locals
         response_format: Optional[Dict[str, Any]] = None,
         request_id: Optional[str] = None,
         debug_config: Optional[Dict[str, Any]] = None,
-    ) -> openai_api_protocol.CompletionResponse:
+    ) -> Iterator[openai_api_protocol.CompletionResponse]:
        """Synchronous streaming completion interface with OpenAI API compatibility.
        The method streams back CompletionResponse that conforms to
        OpenAI API one at a time via yield.

@@ -698,7 +703,7 @@ def create(  # pylint: disable=too-many-arguments,too-many-locals
         response_format: Optional[Dict[str, Any]] = None,
         request_id: Optional[str] = None,
         debug_config: Optional[Dict[str, Any]] = None,
-    ) -> Iterator[openai_api_protocol.CompletionResponse]:
+    ) -> openai_api_protocol.CompletionResponse:
        """Synchronous non-streaming completion interface with OpenAI API compatibility.
 
        See https://platform.openai.com/docs/api-reference/completions/create for specification.

@@ -714,7 +719,7 @@ def create(  # pylint: disable=too-many-arguments,too-many-locals
             Extra debug options to pass to the request.
 
         Returns
-        ------
+        -------
         response : CompletionResponse
             The completion response conforming to OpenAI API.
             See mlc_llm/protocol/openai_api_protocol.py or

@@ -750,7 +755,10 @@ def create(  # pylint: disable=too-many-arguments,too-many-locals
         response_format: Optional[Dict[str, Any]] = None,
         request_id: Optional[str] = None,
         debug_config: Optional[Dict[str, Any]] = None,
-    ) -> Iterator[openai_api_protocol.CompletionResponse]:
+    ) -> Union[
+        Iterator[openai_api_protocol.CompletionResponse],
+        openai_api_protocol.CompletionResponse,
+    ]:
        """Synchronous completion interface with OpenAI API compatibility.
 
        See https://platform.openai.com/docs/api-reference/completions/create for specification.

@@ -864,7 +872,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
             engine_config=engine_config,
             enable_tracing=enable_tracing,
         )
-        self.chat = Chat(weakref.ref(self))
+        self.chat = AsyncChat(weakref.ref(self))
         self.completions = AsyncCompletion(weakref.ref(self))
 
     async def abort(self, request_id: str) -> None:

@@ -1568,7 +1576,10 @@ def _completion(  # pylint: disable=too-many-arguments,too-many-locals
         response_format: Optional[Dict[str, Any]] = None,
         request_id: Optional[str] = None,
         debug_config: Optional[Dict[str, Any]] = None,
-    ) -> Iterator[openai_api_protocol.CompletionResponse]:
+    ) -> Union[
+        Iterator[openai_api_protocol.CompletionResponse],
+        openai_api_protocol.CompletionResponse,
+    ]:
        """Synchronous completion internal interface with OpenAI API compatibility.
 
        See https://platform.openai.com/docs/api-reference/completions/create for specification.
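The AsyncChat/Chat split above is what lets a type checker tell, from the static type of the engine, whether `chat.completions.create` must be awaited; previously `Chat` picked `AsyncChatCompletion` or `ChatCompletion` at runtime, so mypy only saw a union and `# type: ignore` comments were needed. A minimal sketch of the distinction (the call details below are simplified, not copied from the module):

```python
# Minimal sketch: after the split, the completion proxy type follows the engine type,
# so no runtime isinstance check (and no `# type: ignore`) is needed.
from mlc_llm.serve import AsyncMLCEngine, MLCEngine


def sync_usage(engine: MLCEngine) -> None:
    # engine.chat is a Chat, so completions.create is a plain call.
    response = engine.chat.completions.create(
        messages=[{"role": "user", "content": "hi"}], stream=False
    )
    print(response.choices[0].message.content)


async def async_usage(engine: AsyncMLCEngine) -> None:
    # engine.chat is an AsyncChat, so completions.create must be awaited.
    response = await engine.chat.completions.create(
        messages=[{"role": "user", "content": "hi"}], stream=False
    )
    print(response.choices[0].message.content)
```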

python/mlc_llm/support/constants.py

Lines changed: 4 additions & 1 deletion
@@ -53,7 +53,10 @@ def _get_test_model_path() -> List[Path]:
     # by default, we reuse the cache dir via mlc_llm chat
     # note that we do not auto download for testcase
     # to avoid networking dependencies
-    return [_get_cache_dir() / "model_weights" / "mlc-ai"]
+    return [
+        _get_cache_dir() / "model_weights" / "mlc-ai",
+        Path(os.path.abspath(os.path.curdir)),
+    ]
 
 
 MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
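With the second entry, a model compiled into the current working directory is also discovered by the test lookup: the `mlc_llm chat` weight cache is tried first, then the current directory, and a directory counts as a model only if it contains `mlc-chat-config.json`. A hedged sketch of that lookup order (the helper name and cache-dir argument below are illustrative, not the module's actual API):

```python
# Illustrative sketch of the lookup order implied by the change above; the real
# logic lives in _get_test_model_path() together with require_test_model().
import os
from pathlib import Path
from typing import Optional


def resolve_test_model(model: str, cache_dir: Path) -> Optional[Path]:
    candidates = [
        cache_dir / "model_weights" / "mlc-ai",  # default: cache populated by `mlc_llm chat`
        Path(os.path.abspath(os.path.curdir)),   # newly added: current working directory
    ]
    for base in candidates:
        if (base / model / "mlc-chat-config.json").is_file():
            return base / model
    return None  # caller skips the test rather than downloading
```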
Lines changed: 43 additions & 18 deletions
@@ -1,14 +1,16 @@
 """Extra utilities to mark tests"""
 
 import functools
+import inspect
+from pathlib import Path
 from typing import Callable
 
 import pytest
 
 from mlc_llm.support.constants import MLC_TEST_MODEL_PATH
 
 
-def require_test_model(model: str):
+def require_test_model(*models: str):
     """Testcase decorator to require a model
 
     Examples

@@ -24,31 +26,54 @@ def test_reload_reset_unload(model):
 
     Parameters
     ----------
-    model : str
-        The model dir name
+    models : List[str]
+        The model directories or URLs.
     """
-    model_path = None
-    for base_path in MLC_TEST_MODEL_PATH:
-        if (base_path / model / "mlc-chat-config.json").is_file():
-            model_path = base_path / model
-    missing_model = model_path is None
+    model_paths = []
+    missing_models = []
+
+    for model in models:
+        model_path = None
+        for base_path in MLC_TEST_MODEL_PATH:
+            if (base_path / model / "mlc-chat-config.json").is_file():
+                model_path = base_path / model
+        if model_path is None and (Path(model) / "mlc-chat-config.json").is_file():
+            model_path = Path(model)
+
+        if model_path is None:
+            missing_models.append(model)
+        else:
+            model_paths.append(str(model_path))
+
     message = (
-        f"Model {model} does not exist in candidate paths {[str(p) for p in MLC_TEST_MODEL_PATH]},"
+        f"Model {', '.join(missing_models)} not found in candidate paths "
+        f"{[str(p) for p in MLC_TEST_MODEL_PATH]},"
         " if you set MLC_TEST_MODEL_PATH, please ensure model paths are in the right location,"
         " by default we reuse cache, try to run mlc_llm chat to download right set of models."
     )
 
-    def _decorator(func: Callable[[str], None]):
-        wrapped = functools.partial(func, str(model_path))
+    def _decorator(func: Callable[..., None]):
+        wrapped = functools.partial(func, *model_paths)
         wrapped.__name__ = func.__name__  # type: ignore
 
-        @functools.wraps(wrapped)
-        def wrapper(*args, **kwargs):
-            if missing_model:
-                print(f"{message} skipping...")
-                return
-            wrapped(*args, **kwargs)
+        if inspect.iscoroutinefunction(wrapped):
+            # The function is a coroutine function ("async def func(...)")
+            @functools.wraps(wrapped)
+            async def wrapper(*args, **kwargs):
+                if len(missing_models) > 0:
+                    print(f"{message} skipping...")
+                    return
+                await wrapped(*args, **kwargs)
+
+        else:
+            # The function is a normal function ("def func(...)")
+            @functools.wraps(wrapped)
+            def wrapper(*args, **kwargs):
+                if len(missing_models) > 0:
+                    print(f"{message} skipping...")
+                    return
+                wrapped(*args, **kwargs)
 
-        return pytest.mark.skipif(missing_model, reason=message)(wrapper)
+        return pytest.mark.skipif(len(missing_models) > 0, reason=message)(wrapper)
 
     return _decorator
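Since the resolved paths are bound positionally through `functools.partial`, a decorated test receives one argument per model in the order they are listed, and coroutine tests get an async wrapper so they are still awaited. A usage sketch (model names match those used in the updated tests; the test body is a placeholder):

```python
# Usage sketch for the multi-model, async case handled by the decorator above.
from mlc_llm.testing import require_test_model


@require_test_model(
    "Llama-2-7b-chat-hf-q0f16-MLC",    # bound to `model`
    "Llama-2-7b-chat-hf-q4f16_1-MLC",  # bound to `small_model`
)
async def test_speculative_decode(model: str, small_model: str):
    # Skipped entirely if either model directory cannot be found locally.
    ...
```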

tests/python/serve/test_serve_async_engine.py

Lines changed: 13 additions & 10 deletions
@@ -4,6 +4,7 @@
 from typing import List
 
 from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig
+from mlc_llm.testing import require_test_model
 
 prompts = [
     "What is the meaning of life?",

@@ -19,9 +20,9 @@
 ]
 
 
-async def test_engine_generate():
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+async def test_engine_generate(model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",

@@ -74,9 +75,9 @@ async def generate_task(
     del async_engine
 
 
-async def test_chat_completion():
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+async def test_chat_completion(model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",

@@ -101,6 +102,7 @@ async def generate_task(prompt: str, request_id: str):
         ):
             for choice in response.choices:
                 assert choice.delta.role == "assistant"
+                assert isinstance(choice.delta.content, str)
                 output_texts[rid][choice.index] += choice.delta.content
 
     tasks = [

@@ -124,9 +126,9 @@ async def generate_task(prompt: str, request_id: str):
     del async_engine
 
 
-async def test_chat_completion_non_stream():
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+async def test_chat_completion_non_stream(model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",

@@ -150,6 +152,7 @@ async def generate_task(prompt: str, request_id: str):
         )
         for choice in response.choices:
             assert choice.message.role == "assistant"
+            assert isinstance(choice.message.content, str)
             output_texts[rid][choice.index] += choice.message.content
 
     tasks = [

@@ -173,9 +176,9 @@ async def generate_task(prompt: str, request_id: str):
     del async_engine
 
 
-async def test_completion():
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+async def test_completion(model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",

@@ -223,9 +226,9 @@ async def generate_task(prompt: str, request_id: str):
     del async_engine
 
 
-async def test_completion_non_stream():
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+async def test_completion_non_stream(model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",
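The added `isinstance` asserts belong to the mypy fixes mentioned in the commit message: the delta/message `content` field is presumably typed as optional in the OpenAI protocol classes, and the assert narrows it to `str` before it is concatenated. The narrowing pattern in isolation:

```python
# Minimal illustration of the narrowing pattern used in the asserts above.
from typing import Optional


def append_chunk(buffer: str, content: Optional[str]) -> str:
    assert isinstance(content, str)  # narrows Optional[str] to str for mypy
    return buffer + content
```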

tests/python/serve/test_serve_async_engine_spec.py

Lines changed: 6 additions & 3 deletions
@@ -4,6 +4,7 @@
 from typing import List
 
 from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig
+from mlc_llm.testing import require_test_model
 
 prompts = [
     "What is the meaning of life?",

@@ -19,10 +20,12 @@
 ]
 
 
-async def test_engine_generate():
+@require_test_model(
+    "Llama-2-7b-chat-hf-q0f16-MLC",
+    "Llama-2-7b-chat-hf-q4f16_1-MLC",
+)
+async def test_engine_generate(model: str, small_model: str):
     # Create engine
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q0f16-MLC"
-    small_model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     async_engine = AsyncMLCEngine(
         model=model,
         mode="server",
