
Commit 6bbd49c

[EngineConfig] Add override options (#2550)

This PR introduces override options to the Python-side EngineConfig so that they are reflected in JIT model compilation.

1 parent 26a9cf0 · commit 6bbd49c
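In short: before this PR, compile-time knobs such as tensor_parallel_shards were passed to the engine through a separate ModelConfigOverride object; after it, the same knobs live directly on EngineConfig. A minimal before/after sketch, grounded in the doc diffs below (the model ID and shard count are illustrative):

    # Before: compile-time knobs went through a separate override object, e.g.
    #   engine = MLCEngine(model, model_config_overrides=ModelConfigOverride(tensor_parallel_shards=2))
    # After: the same knobs are fields of EngineConfig and are picked up by JIT compilation.
    from mlc_llm import MLCEngine
    from mlc_llm.serve.config import EngineConfig

    engine = MLCEngine(
        "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
        engine_config=EngineConfig(tensor_parallel_shards=2),
    )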

File tree: 15 files changed, +95 −102 lines


docs/deploy/cli.rst

Lines changed: 1 addition & 1 deletion
@@ -87,5 +87,5 @@ MODEL The model folder after compiling with MLC-LLM build proce
     with the device id set to 0 for default.
 --overrides  Model configuration override. Supports overriding
     ``context_window_size``, ``prefill_chunk_size``, ``sliding_window_size``, ``attention_sink_size``,
-    ``max_batch_size`` and ``tensor_parallel_shards``. The overrides could be explicitly
+    and ``tensor_parallel_shards``. The overrides could be explicitly
     specified via details knobs, e.g. --overrides ``context_window_size=1024;prefill_chunk_size=128``.

docs/deploy/python_engine.rst

Lines changed: 4 additions & 4 deletions
@@ -94,12 +94,12 @@ for the complete chat completion interface.
 .. code:: python
 
    from mlc_llm import MLCEngine
-   from mlc_llm.serve.config import ModelConfigOverride
+   from mlc_llm.serve.config import EngineConfig
 
    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
    engine = MLCEngine(
        model,
-       model_config_overrides=ModelConfigOverride(tensor_parallel_shards=2),
+       engine_config=EngineConfig(tensor_parallel_shards=2),
    )
 
 
@@ -196,12 +196,12 @@ for the complete chat completion interface.
 .. code:: python
 
    from mlc_llm import AsyncMLCEngine
-   from mlc_llm.serve.config import ModelConfigOverride
+   from mlc_llm.serve.config import EngineConfig
 
    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
    engine = AsyncMLCEngine(
        model,
-       model_config_overrides=ModelConfigOverride(tensor_parallel_shards=2),
+       engine_config=EngineConfig(tensor_parallel_shards=2),
    )
 
 
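For context, once constructed this way the engine is driven like an OpenAI client. A hedged sketch continuing the doc example above, assuming the chat.completions.create API shown elsewhere in these docs (the prompt and streaming flag are illustrative):

    # Assumed usage, based on the OpenAI-style API these docs describe elsewhere.
    for response in engine.chat.completions.create(
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        model=model,
        stream=True,
    ):
        for choice in response.choices:
            print(choice.delta.content, end="", flush=True)
    engine.terminate()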

docs/get_started/introduction.rst

Lines changed: 2 additions & 2 deletions
@@ -153,12 +153,12 @@ If you would like to do concurrent asynchronous generation, you can use :class:`
 .. code:: python
 
    from mlc_llm import MLCEngine
-   from mlc_llm.serve.config import ModelConfigOverride
+   from mlc_llm.serve.config import EngineConfig
 
    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
    engine = MLCEngine(
        model,
-       model_config_overrides=ModelConfigOverride(tensor_parallel_shards=2),
+       engine_config=EngineConfig(tensor_parallel_shards=2),
    )
 
 
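The hunk context mentions concurrent asynchronous generation; a hedged sketch of what that would look like with AsyncMLCEngine, assuming its chat API mirrors the synchronous engine above (prompt and override values are illustrative):

    import asyncio

    from mlc_llm import AsyncMLCEngine
    from mlc_llm.serve.config import EngineConfig

    async def main() -> None:
        # Assumes AsyncMLCEngine mirrors MLCEngine's OpenAI-style chat API.
        engine = AsyncMLCEngine(
            "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
            engine_config=EngineConfig(tensor_parallel_shards=2),
        )
        response = await engine.chat.completions.create(
            messages=[{"role": "user", "content": "Hello!"}],
            model="HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
            stream=False,
        )
        print(response.choices[0].message.content)
        engine.terminate()

    asyncio.run(main())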

python/mlc_llm/cli/calibrate.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from mlc_llm.interface.help import HELP
 from mlc_llm.support.argparse import ArgumentParser
 
-from .serve import EngineAndModelConfigOverride
+from .serve import EngineConfigOverride
 
 
 def main(argv):
@@ -51,7 +51,7 @@ def main(argv):
     )
     parser.add_argument(
         "--overrides",
-        type=EngineAndModelConfigOverride.from_str,
+        type=EngineConfigOverride.from_str,
         default="",
         help=HELP["overrides_serve"],
     )

python/mlc_llm/cli/chat.py

Lines changed: 1 addition & 2 deletions
@@ -1,8 +1,7 @@
 """Command line entrypoint of chat."""
 
-from mlc_llm.interface.chat import chat
+from mlc_llm.interface.chat import ModelConfigOverride, chat
 from mlc_llm.interface.help import HELP
-from mlc_llm.serve.config import ModelConfigOverride
 from mlc_llm.support.argparse import ArgumentParser
 

python/mlc_llm/cli/compile.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 
 
 def main(argv):
-    """Parse command line argumennts and call `mlc_llm.compiler.compile`."""
+    """Parse command line arguments and call `mlc_llm.compiler.compile`."""
 
     def _parse_output(path: Union[str, Path]) -> Path:
         path = Path(path)

python/mlc_llm/cli/serve.py

Lines changed: 8 additions & 19 deletions
@@ -7,13 +7,12 @@
 
 from mlc_llm.interface.help import HELP
 from mlc_llm.interface.serve import serve
-from mlc_llm.serve.config import ModelConfigOverride
 from mlc_llm.support import argparse
 from mlc_llm.support.argparse import ArgumentParser
 
 
 @dataclasses.dataclass
-class EngineAndModelConfigOverride:  # pylint: disable=too-many-instance-attributes
+class EngineConfigOverride:  # pylint: disable=too-many-instance-attributes
     """Arguments for overriding engine config."""
 
     # Overrides for EngineConfig (runtime)
@@ -24,8 +23,6 @@ class EngineAndModelConfigOverride:  # pylint: disable=too-many-instance-attribu
     gpu_memory_utilization: Optional[float] = None
     spec_draft_length: Optional[int] = None
     prefix_cache_max_num_recycling_seqs: Optional[int] = None
-
-    # Overrides for model config (compile time)
     context_window_size: Optional[int] = None
     sliding_window_size: Optional[int] = None
     attention_sink_size: Optional[int] = None
@@ -51,7 +48,7 @@ def __repr__(self) -> str:
         return out.getvalue().rstrip()
 
     @staticmethod
-    def from_str(source: str) -> "EngineAndModelConfigOverride":
+    def from_str(source: str) -> "EngineConfigOverride":
         """Parse engine config override values from a string."""
         parser = argparse.ArgumentParser(description="Engine config override values")
 
@@ -67,7 +64,7 @@ def from_str(source: str) -> "EngineAndModelConfigOverride":
         parser.add_argument("--attention_sink_size", type=int, default=None)
         parser.add_argument("--tensor_parallel_shards", type=int, default=None)
         results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
-        return EngineAndModelConfigOverride(
+        return EngineConfigOverride(
             max_num_sequence=results.max_num_sequence,
             max_total_seq_length=results.max_total_seq_length,
             prefill_chunk_size=results.prefill_chunk_size,
@@ -81,17 +78,6 @@ def from_str(source: str) -> "EngineAndModelConfigOverride":
             tensor_parallel_shards=results.tensor_parallel_shards,
         )
 
-    def to_model_config_overrides(self) -> ModelConfigOverride:
-        """Extract the model config overrides."""
-        return ModelConfigOverride(
-            context_window_size=self.context_window_size,
-            sliding_window_size=self.sliding_window_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            attention_sink_size=self.attention_sink_size,
-            max_batch_size=self.max_num_sequence,
-            tensor_parallel_shards=self.tensor_parallel_shards,
-        )
-
 
 def main(argv):
     """Parse command line arguments and call `mlc_llm.interface.serve`."""
@@ -145,7 +131,7 @@ def main(argv):
     )
     parser.add_argument(
         "--overrides",
-        type=EngineAndModelConfigOverride.from_str,
+        type=EngineConfigOverride.from_str,
         default="",
         help=HELP["overrides_serve"],
     )
@@ -199,16 +185,19 @@ def main(argv):
         mode=parsed.mode,
         enable_debug=parsed.enable_debug,
         additional_models=additional_models,
+        tensor_parallel_shards=parsed.overrides.tensor_parallel_shards,
         speculative_mode=parsed.speculative_mode,
         prefix_cache_mode=parsed.prefix_cache_mode,
         max_num_sequence=parsed.overrides.max_num_sequence,
         max_total_sequence_length=parsed.overrides.max_total_seq_length,
+        max_single_sequence_length=parsed.overrides.context_window_size,
         prefill_chunk_size=parsed.overrides.prefill_chunk_size,
+        sliding_window_size=parsed.overrides.sliding_window_size,
+        attention_sink_size=parsed.overrides.attention_sink_size,
         max_history_size=parsed.overrides.max_history_size,
         gpu_memory_utilization=parsed.overrides.gpu_memory_utilization,
         spec_draft_length=parsed.overrides.spec_draft_length,
         prefix_cache_max_num_recycling_seqs=parsed.overrides.prefix_cache_max_num_recycling_seqs,
-        model_config_overrides=parsed.overrides.to_model_config_overrides(),
         enable_tracing=parsed.enable_tracing,
         host=parsed.host,
         port=parsed.port,
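The from_str helper above rewrites a ;-separated key=value string into argparse-style flags and parses it. A standalone, runnable sketch of the same trick using only the standard library (the two knobs shown are just examples):

    # Standalone sketch of the parsing trick used by EngineConfigOverride.from_str:
    # "a=1;b=2" becomes ["--a=1", "--b=2"] and is handed to argparse.
    import argparse

    parser = argparse.ArgumentParser(description="override values")
    parser.add_argument("--context_window_size", type=int, default=None)
    parser.add_argument("--prefill_chunk_size", type=int, default=None)

    source = "context_window_size=1024;prefill_chunk_size=128"
    # Empty segments (e.g. a trailing ";") are skipped, matching the code above.
    results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
    assert results.context_window_size == 1024
    assert results.prefill_chunk_size == 128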

python/mlc_llm/interface/chat.py

Lines changed: 35 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 from mlc_llm.json_ffi import JSONFFIEngine
 from mlc_llm.protocol import openai_api_protocol
-from mlc_llm.serve.config import EngineConfig, ModelConfigOverride
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.serve.engine import MLCEngine
 from mlc_llm.serve.engine_base import _query_engine_metrics
 from mlc_llm.support import argparse
@@ -79,6 +79,36 @@ def from_str(source: str) -> "ChatCompletionOverride":
         )
 
 
+@dataclasses.dataclass
+class ModelConfigOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes
+    """Flags for overriding model config."""
+
+    context_window_size: Optional[int] = None
+    sliding_window_size: Optional[int] = None
+    prefill_chunk_size: Optional[int] = None
+    attention_sink_size: Optional[int] = None
+    tensor_parallel_shards: Optional[int] = None
+
+    @staticmethod
+    def from_str(source: str) -> "ModelConfigOverride":
+        """Parse model config override values from a string."""
+        parser = argparse.ArgumentParser(description="model config override values")
+        parser.add_argument("--tensor_parallel_shards", type=int, default=None)
+        parser.add_argument("--context_window_size", type=int, default=None)
+        parser.add_argument("--sliding_window_size", type=int, default=None)
+        parser.add_argument("--prefill_chunk_size", type=int, default=None)
+        parser.add_argument("--attention_sink_size", type=int, default=None)
+
+        results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
+        return ModelConfigOverride(
+            tensor_parallel_shards=results.tensor_parallel_shards,
+            context_window_size=results.context_window_size,
+            sliding_window_size=results.sliding_window_size,
+            prefill_chunk_size=results.prefill_chunk_size,
+            attention_sink_size=results.attention_sink_size,
+        )
+
+
 class ChatState:
     """Simple helper class to manage chat state.
 
@@ -255,8 +285,11 @@ def chat(
             model_lib=model_lib,
             mode="interactive",
             engine_config=EngineConfig(
+                max_single_sequence_length=overrides.context_window_size,
                 prefill_chunk_size=overrides.prefill_chunk_size,
+                sliding_window_size=overrides.sliding_window_size,
+                attention_sink_size=overrides.attention_sink_size,
+                tensor_parallel_shards=overrides.tensor_parallel_shards,
             ),
-            model_config_overrides=overrides,
         )
     ).chat()
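Taken together with the chat() hunk above, the CLI round trip is: parse --overrides into ModelConfigOverride, then copy its fields onto EngineConfig (note that context_window_size maps to max_single_sequence_length). A hedged sketch of that mapping, with an illustrative override string:

    from mlc_llm.interface.chat import ModelConfigOverride
    from mlc_llm.serve.config import EngineConfig

    # Parse a CLI-style override string, then forward the fields the same
    # way chat() does in the diff above.
    overrides = ModelConfigOverride.from_str("context_window_size=1024;tensor_parallel_shards=2")
    engine_config = EngineConfig(
        max_single_sequence_length=overrides.context_window_size,
        prefill_chunk_size=overrides.prefill_chunk_size,
        sliding_window_size=overrides.sliding_window_size,
        attention_sink_size=overrides.attention_sink_size,
        tensor_parallel_shards=overrides.tensor_parallel_shards,
    )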

python/mlc_llm/interface/help.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@
     "modelconfig_overrides": """
 Model configuration override. Supports overriding,
 `context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`,
-`max_batch_size` and `tensor_parallel_shards`. The overrides could be explicitly
+`max_num_sequence` and `tensor_parallel_shards`. The overrides could be explicitly
 specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128".
 """.strip(),
     "debug_dump": """

python/mlc_llm/interface/serve.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,6 @@
 
 from mlc_llm.protocol import error_protocol
 from mlc_llm.serve import engine
-from mlc_llm.serve.config import ModelConfigOverride
 from mlc_llm.serve.entrypoints import (
     debug_entrypoints,
     metrics_entrypoints,
@@ -27,16 +26,19 @@ def serve(
     mode: Literal["local", "interactive", "server"],
     enable_debug: bool,
     additional_models: List[Union[str, Tuple[str, str]]],
+    tensor_parallel_shards: Optional[int],
     max_num_sequence: Optional[int],
     max_total_sequence_length: Optional[int],
+    max_single_sequence_length: Optional[int],
     prefill_chunk_size: Optional[int],
+    sliding_window_size: Optional[int],
+    attention_sink_size: Optional[int],
     max_history_size: Optional[int],
     gpu_memory_utilization: Optional[float],
     speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"],
     spec_draft_length: Optional[int],
     prefix_cache_mode: Literal["disable", "radix"],
     prefix_cache_max_num_recycling_seqs: Optional[int],
-    model_config_overrides: Optional[ModelConfigOverride],
     enable_tracing: bool,
     host: str,
     port: int,
@@ -54,17 +56,20 @@ def serve(
         mode=mode,
         engine_config=engine.EngineConfig(
             additional_models=additional_models,
+            tensor_parallel_shards=tensor_parallel_shards,
             max_num_sequence=max_num_sequence,
             max_total_sequence_length=max_total_sequence_length,
+            max_single_sequence_length=max_single_sequence_length,
             prefill_chunk_size=prefill_chunk_size,
+            sliding_window_size=sliding_window_size,
+            attention_sink_size=attention_sink_size,
             max_history_size=max_history_size,
             gpu_memory_utilization=gpu_memory_utilization,
             speculative_mode=speculative_mode,
             spec_draft_length=spec_draft_length,
             prefix_cache_mode=prefix_cache_mode,
             prefix_cache_max_num_recycling_seqs=prefix_cache_max_num_recycling_seqs,
         ),
-        model_config_overrides=model_config_overrides,
         enable_tracing=enable_tracing,
     )
 
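After this refactor, serve() packs every override into a single engine.EngineConfig rather than threading a separate ModelConfigOverride through. A sketch of the resulting config object, with illustrative values for the new compile-time fields (every field below appears in the serve() diff above):

    from mlc_llm.serve import engine

    # Illustrative values only; the field names come from the diff above.
    engine_config = engine.EngineConfig(
        additional_models=[],
        tensor_parallel_shards=2,
        max_num_sequence=4,
        max_total_sequence_length=8192,
        max_single_sequence_length=4096,
        prefill_chunk_size=1024,
        sliding_window_size=None,
        attention_sink_size=None,
        gpu_memory_utilization=0.85,
    )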
