Commit 8b38a4b

[REFACTOR] Remove dependencies on legacy chat_module (#2424)
This PR removes all dependencies on the legacy chat_module.py so we can prepare to deprecate that module. It also refactors MLCChatConfig and moves it into the protocol package, which helps consolidate the API specs and config schemas under the protocol folder. The protocol folder mainly keeps data schemas and metadata; most of the actions (e.g. gen_config) stay in their current locations.
1 parent 13c0661 commit 8b38a4b
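
As a quick illustration of the consolidation (a sketch; only the new import path is confirmed by the gen_config.py diff below), downstream code now imports the chat-config schema from the protocol package instead of relying on a dataclass defined next to the gen_config action:

    # Before this PR: MLCChatConfig was a dataclass defined inside
    # python/mlc_llm/interface/gen_config.py.
    # After this PR: the schema lives under the protocol package.
    from mlc_llm.protocol.mlc_chat_config import MLCChatConfig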

15 files changed: +198 additions, −129 deletions

android/MLCEngineExample/mlc-package-config.json

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,10 @@
         {
             "model": "HF://mlc-ai/phi-2-q4f16_1-MLC",
             "estimated_vram_bytes": 2036816936,
-            "model_id": "phi-2-q4f16_1"
+            "model_id": "phi-2-q4f16_1",
+            "overrides": {
+                "prefill_chunk_size": 1024
+            }
         }
     ]
 }

python/mlc_llm/chat_module.py

Lines changed: 1 addition & 1 deletion
@@ -783,7 +783,7 @@ def __init__( # pylint: disable=too-many-arguments
 
         self.model_lib = jit.jit(
             model_path=Path(self.model_path),
-            chat_config=asdict(self.chat_config),
+            overrides=asdict(self.chat_config),
            device=self.device,
         ).model_lib_path
         _inspect_model_lib_metadata_memory_usage(self.model_lib, self.config_file_path)

python/mlc_llm/cli/delivery.py

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,7 @@
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from huggingface_hub import HfApi  # pylint: disable=import-error
 from huggingface_hub.utils import HfHubHTTPError  # pylint: disable=import-error
@@ -43,11 +43,11 @@ class ModelInfo: # pylint: disable=too-many-instance-attributes
     source_format: str = "auto"
     # If unspecified in CLI, remains to be None and will not be
     # passed to `gen_config` or `convert_weight`
-    context_window_size: int = None
-    sliding_window_size: int = None
-    prefill_chunk_size: int = None
-    attention_sink_size: int = None
-    tensor_parallel_shards: int = None
+    context_window_size: Optional[int] = None
+    sliding_window_size: Optional[int] = None
+    prefill_chunk_size: Optional[int] = None
+    attention_sink_size: Optional[int] = None
+    tensor_parallel_shards: Optional[int] = None
 
 
 class DeferredScope:
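
Several changes in this commit are the same small typing fix: fields previously declared as "x: int = None" become Optional[int]. A minimal sketch of the pattern (the class name below is a hypothetical stand-in, not code from the repository):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ModelInfoSketch:
        # "context_window_size: int = None" defaults a non-optional int to None,
        # which strict type checkers reject; Optional[int] makes the None default
        # legal and explicit.
        context_window_size: Optional[int] = None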

python/mlc_llm/contrib/embeddings/embeddings.py

Lines changed: 1 addition & 2 deletions
@@ -11,7 +11,6 @@
 from tvm.runtime import Device, Module
 from tvm.runtime.relax_vm import VirtualMachine
 
-from mlc_llm.chat_module import _get_model_path
 from mlc_llm.serve import engine_utils
 from mlc_llm.support.auto_device import detect_device
 from mlc_llm.tokenizer import Tokenizer
@@ -143,7 +142,7 @@ def __init__( # pylint: disable=too-many-arguments
         self.mod, self.params, self.metadata = _get_tvm_module(
             model, model_lib_path, self.device, instrument
         )
-        self.model_path, _ = _get_model_path(model)
+        self.model_path = model
         self.tokenizer = Tokenizer(self.model_path)
         self.prefill_func = self.mod["prefill"]

python/mlc_llm/interface/gen_config.py

Lines changed: 10 additions & 60 deletions
@@ -1,15 +1,15 @@
 """Generator of mlc-chat-config.json and tokenizer configuration."""
-
-import dataclasses
+# pylint: disable=E1101
 import json
 import re
 import shutil
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Optional
 
 from mlc_llm.conversation_template import ConvTemplateRegistry
 from mlc_llm.model import Model
+from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
 from mlc_llm.quantization import Quantization
 from mlc_llm.support import convert_tiktoken, logging
 from mlc_llm.support.style import bold, green, red
@@ -22,63 +22,13 @@
 FOUND = green("Found")
 NOT_FOUND = red("Not found")
 FAILED = red("Failed")
-VERSION = "0.1.0"
-
-
-@dataclasses.dataclass
-class MLCChatConfig:  # pylint: disable=too-many-instance-attributes
-    """Fields in the dumped `mlc-chat-config.json` file."""
 
-    model_type: str
-    quantization: str
-    model_config: Dict[str, Any]
-    vocab_size: int
-    context_window_size: int
-    sliding_window_size: int
-    prefill_chunk_size: int
-    attention_sink_size: int
-    tensor_parallel_shards: int
-    # Control the behavior of the runtime
-    mean_gen_len: int = None
-    max_gen_len: int = None
-    shift_fill_factor: float = None
-    # Configuration of text generation
-    temperature: float = None
-    presence_penalty: float = None
-    frequency_penalty: float = None
-    repetition_penalty: float = None
-    top_p: float = None
-    # Conversation template
-    conv_template: Union[str, Dict[str, Any]] = None
-    pad_token_id: int = None
-    bos_token_id: int = None
-    eos_token_id: int = None
-    # Tokenizer configuration
-    tokenizer_files: List[str] = dataclasses.field(default_factory=list)
-    # The content of tokenizer.TokenizerInfo
-    tokenizer_info: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    # Version control
-    version: str = VERSION
 
-    def apply_defaults(self) -> None:
-        """Apply system default value."""
-        defaults = {
-            "pad_token_id": 0,
-            "bos_token_id": 1,
-            "eos_token_id": 2,
-            "temperature": 0.7,
-            "presence_penalty": 0.0,
-            "frequency_penalty": 0.0,
-            "repetition_penalty": 1.0,
-            "top_p": 0.95,
-            "mean_gen_len": 128,
-            "max_gen_len": 512,
-            "shift_fill_factor": 0.3,
-        }
-        for key, value in defaults.items():
-            if getattr(self, key) is None:
-                setattr(self, key, value)
-                logger.info("[System default] Setting %s: %s", bold(key), value)
+def apply_system_defaults_for_missing_fields(mlc_chat_config: MLCChatConfig) -> None:
+    """Apply system default value."""
+    for key, value in mlc_chat_config.get_system_defaults_for_missing_fields().items():
+        setattr(mlc_chat_config, key, value)
+        logger.info("[System default] Setting %s: %s", bold(key), value)
 
 
 def check_string(s: str) -> bool:
@@ -265,10 +215,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     logger.info("Detected tokenizer info: %s", mlc_chat_config.tokenizer_info)
 
     # Step 4. Load system default value
-    mlc_chat_config.apply_defaults()
+    apply_system_defaults_for_missing_fields(mlc_chat_config)
     # Step 5. Dump the configuration file to output directory
     with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file:
-        json.dump(dataclasses.asdict(mlc_chat_config), out_file, indent=2)
+        json.dump(mlc_chat_config.model_dump(), out_file, indent=2)
         logger.info("Dumping configuration file to: %s", bold(out_file.name))
 
 
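
The relocated mlc_llm.protocol.mlc_chat_config.MLCChatConfig itself is not part of this commit view. Judging from the dataclass removed above and from the model_dump() and get_system_defaults_for_missing_fields() calls introduced here, it is presumably a pydantic model roughly along the lines of the sketch below; the real field list, defaults, and method bodies may differ.

    from typing import Optional

    from pydantic import BaseModel


    class MLCChatConfigSketch(BaseModel):
        """Hypothetical subset of the relocated config. The real class also carries
        the model/tokenizer metadata fields of the removed dataclass (quantization,
        vocab_size, tokenizer_files, ...)."""

        temperature: Optional[float] = None
        repetition_penalty: Optional[float] = None
        top_p: Optional[float] = None
        pad_token_id: Optional[int] = None
        bos_token_id: Optional[int] = None
        eos_token_id: Optional[int] = None

        def get_system_defaults_for_missing_fields(self) -> dict:
            """Return system defaults for every field the user left unset (None)."""
            defaults = {
                "temperature": 0.7,
                "repetition_penalty": 1.0,
                "top_p": 0.95,
                "pad_token_id": 0,
                "bos_token_id": 1,
                "eos_token_id": 2,
            }
            return {k: v for k, v in defaults.items() if getattr(self, k) is None}

With a schema like this, gen_config.py applies defaults via the new module-level apply_system_defaults_for_missing_fields helper and serializes the result with model_dump() instead of dataclasses.asdict().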

python/mlc_llm/interface/jit.py

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,7 @@ def log_jit_policy():
 
 def jit( # pylint: disable=too-many-locals,too-many-statements
     model_path: Path,
-    chat_config: Dict[str, Any],
+    overrides: Dict[str, Any],
     device: Union[Device, str],
     system_lib_prefix: Optional[str] = None,
     *,
@@ -70,7 +70,7 @@ def jit( # pylint: disable=too-many-locals,too-many-statements
     lib_suffix = MLC_DSO_SUFFIX if device not in ["iphone", "android"] else "tar"
 
     def _get_optimization_flags() -> str:
-        opt = chat_config.pop("opt", None)
+        opt = overrides.pop("opt", None)
         if opt is None:
             opt = "O2"
         return repr(OptimizationFlags.from_str(opt))
@@ -79,7 +79,7 @@ def _get_overrides() -> str:
         forbid_list = ["context_window_size", "sliding_window_size", "attention_sink_size"]
         result = []
         for field in dataclasses.fields(ModelConfigOverride):
-            value = chat_config.get(field.name, None)
+            value = overrides.get(field.name, None)
            if value is not None:
                if field.name in forbid_list and value == -1:
                    continue
@@ -92,7 +92,7 @@ def _get_model_config() -> Dict[str, Any]:
         model_config = mlc_chat_config.pop("model_config")
         model_config.update(mlc_chat_config)
         for field in dataclasses.fields(ModelConfigOverride):
-            value = chat_config.get(field.name, None)
+            value = overrides.get(field.name, None)
            if value is not None:
                model_config[field.name] = value
         return MODELS[model_type].config.from_dict(model_config).asdict()
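
Both call sites updated in this commit (chat_module.py above and package.py below) now pass the keyword overrides instead of chat_config. A rough usage sketch based only on what is visible in this diff (the model directory and device are hypothetical, and the full jit.jit signature is not shown here):

    from pathlib import Path

    from mlc_llm.interface import jit

    result = jit.jit(
        model_path=Path("dist/phi-2-q4f16_1-MLC"),  # hypothetical local model directory
        overrides={"opt": "O2", "prefill_chunk_size": 1024},  # keys read by _get_optimization_flags / _get_overrides
        device="cuda",  # hypothetical target device
    )
    model_lib_path = result.model_lib_path  # as used by chat_module.py above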

python/mlc_llm/interface/package.py

Lines changed: 5 additions & 10 deletions
@@ -6,13 +6,11 @@
 import shutil
 import subprocess
 import sys
-from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Dict, List, Literal
 
-from mlc_llm.chat_module import ChatConfig, _get_chat_config, _get_model_path
 from mlc_llm.interface import jit
-from mlc_llm.support import logging, style
+from mlc_llm.support import download, logging, style
 
 logging.enable_logging()
 logger = logging.getLogger(__name__)
@@ -56,6 +54,7 @@ def build_model_library( # pylint: disable=too-many-branches,too-many-locals,to
         bundle_weight = model_entry.get("bundle_weight", False)
         overrides = model_entry.get("overrides", {})
         model_lib = model_entry.get("model_lib", None)
+
         estimated_vram_bytes = model_entry["estimated_vram_bytes"]
         if not isinstance(model, str):
             raise ValueError('The value of "model" in "model_list" is expected to be a string.')
@@ -71,12 +70,8 @@ def build_model_library( # pylint: disable=too-many-branches,too-many-locals,to
             raise ValueError('The value of "model_lib" in "model_list" is expected to be string.')
 
         # - Load model config. Download happens when needed.
-        model_path_and_config_file_path = _get_model_path(model)
-        model_path = Path(model_path_and_config_file_path[0])
-        config_file_path = model_path_and_config_file_path[1]
-        chat_config = _get_chat_config(
-            config_file_path, user_chat_config=ChatConfig.from_dict(overrides)
-        )
+        model_path = download.get_or_download_model(model)
+
         # - Jit compile if the model lib path is not specified.
         model_lib_path = (
             model_lib_path_for_prepare_libs.get(model_lib, None) if model_lib is not None else None
@@ -96,7 +91,7 @@ def build_model_library( # pylint: disable=too-many-branches,too-many-locals,to
         model_lib_path, model_lib = dataclasses.astuple(
             jit.jit(
                 model_path=model_path,
-                chat_config=asdict(chat_config),
+                overrides=overrides,
                 device=device,
                 system_lib_prefix=model_lib,
                 skip_log_jit_policy=True,
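
With this change, build_model_library passes each entry's "overrides" mapping from mlc-package-config.json straight through to jit.jit, and model download is handled by download.get_or_download_model rather than chat_module._get_model_path. A sketch of that flow under the same assumptions as above (not a verbatim excerpt from package.py):

    import dataclasses

    from mlc_llm.interface import jit
    from mlc_llm.support import download

    model_entry = {  # one element of "model_list" in mlc-package-config.json
        "model": "HF://mlc-ai/phi-2-q4f16_1-MLC",
        "model_id": "phi-2-q4f16_1",
        "overrides": {"prefill_chunk_size": 1024},
    }
    model_path = download.get_or_download_model(model_entry["model"])
    model_lib_path, model_lib = dataclasses.astuple(
        jit.jit(
            model_path=model_path,
            overrides=model_entry.get("overrides", {}),
            device="android",  # hypothetical packaging target
            skip_log_jit_policy=True,
        )
    )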

python/mlc_llm/op/position_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
     num_q_heads: int,
     num_kv_heads: int,
     dtype: str,
-    rotary_dim: int = None,
+    rotary_dim: Optional[int] = None,
 ):
     """Return the TIR function that computes Llama-style RoPE with q position map.

python/mlc_llm/protocol/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""The protocols for MLC LLM server"""
+"""Definitions of pydantic models for API entry points and configurations"""
 from . import openai_api_protocol
 
 RequestProtocol = openai_api_protocol.CompletionRequest

python/mlc_llm/protocol/error_protocol.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 """Error protocols in MLC LLM"""
 
 from http import HTTPStatus
+from typing import Optional
 
 import fastapi
 from pydantic import BaseModel
@@ -18,7 +19,7 @@ class ErrorResponse(BaseModel):
 
     object: str = "error"
     message: str
-    code: int = None
+    code: Optional[int] = None
 
 
 def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse:

0 commit comments
