
Commit e283cd0

[JIT] Support overriding optimization flags in JIT (#3032)
This PR adds an optimization-flags override (`"opt"`) for MLCEngine, chat, and serve when running JIT compilation. Prior to this PR, JIT compilation always used O2 as the optimization flags.
1 parent 3578e79 commit e283cd0
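
As a quick illustration of the new knob, here is a minimal sketch of setting `opt` from the Python API. The model id is a placeholder, and `engine_config` is assumed to be the `MLCEngine` keyword that accepts an `EngineConfig` (consistent with the fields added below):

```python
from mlc_llm import MLCEngine
from mlc_llm.serve.config import EngineConfig

# Placeholder model id; "opt" overrides the previously hard-coded O2 default.
engine = MLCEngine(
    model="HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
    engine_config=EngineConfig(opt="O3"),
)
response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(response.choices[0].message.content)
engine.terminate()
```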

File tree

5 files changed: +26 −0 lines changed

python/mlc_llm/cli/serve.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@ class EngineConfigOverride:  # pylint: disable=too-many-instance-attributes
     attention_sink_size: Optional[int] = None
     tensor_parallel_shards: Optional[int] = None
     pipeline_parallel_stages: Optional[int] = None
+    opt: Optional[str] = None

     def __repr__(self) -> str:
         out = StringIO()
@@ -53,6 +54,7 @@ def __repr__(self) -> str:
         print(f";attention_sink_size={self.attention_sink_size}", file=out, end="")
         print(f";tensor_parallel_shards={self.tensor_parallel_shards}", file=out, end="")
         print(f";pipeline_parallel_stages={self.pipeline_parallel_stages}", file=out, end="")
+        print(f";opt={self.opt}", file=out, end="")
         return out.getvalue().rstrip()

     @staticmethod
@@ -75,6 +77,7 @@ def from_str(source: str) -> "EngineConfigOverride":
         parser.add_argument("--attention_sink_size", type=int, default=None)
         parser.add_argument("--tensor_parallel_shards", type=int, default=None)
         parser.add_argument("--pipeline_parallel_stages", type=int, default=None)
+        parser.add_argument("--opt", type=str, default=None)
         results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
         return EngineConfigOverride(
             max_num_sequence=results.max_num_sequence,
@@ -92,6 +95,7 @@ def from_str(source: str) -> "EngineConfigOverride":
             attention_sink_size=results.attention_sink_size,
             tensor_parallel_shards=results.tensor_parallel_shards,
             pipeline_parallel_stages=results.pipeline_parallel_stages,
+            opt=results.opt,
         )

@@ -210,6 +214,7 @@ def main(argv):
         additional_models=additional_models,
         tensor_parallel_shards=parsed.overrides.tensor_parallel_shards,
         pipeline_parallel_stages=parsed.overrides.pipeline_parallel_stages,
+        opt=parsed.overrides.opt,
         speculative_mode=parsed.speculative_mode,
         prefix_cache_mode=parsed.prefix_cache_mode,
         max_num_sequence=parsed.overrides.max_num_sequence,
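
For reference, the `from_str` helper above uses a small argparse trick: each semicolon-separated `key=value` pair in the override string is rewritten as a `--key=value` token and handed to a parser. A self-contained sketch of that mechanism, with just two of the fields for brevity:

```python
import argparse


def parse_overrides(source: str) -> argparse.Namespace:
    """Sketch of the EngineConfigOverride.from_str mechanism shown above."""
    parser = argparse.ArgumentParser(description="engine config override values")
    parser.add_argument("--tensor_parallel_shards", type=int, default=None)
    parser.add_argument("--opt", type=str, default=None)
    # "opt=O3;tensor_parallel_shards=2" -> ["--opt=O3", "--tensor_parallel_shards=2"]
    return parser.parse_args([f"--{i}" for i in source.split(";") if i])


overrides = parse_overrides("opt=O3;tensor_parallel_shards=2")
print(overrides.opt)                     # O3
print(overrides.tensor_parallel_shards)  # 2
```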

python/mlc_llm/interface/chat.py

Lines changed: 4 additions & 0 deletions
@@ -89,13 +89,15 @@ class ModelConfigOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes
     attention_sink_size: Optional[int] = None
     tensor_parallel_shards: Optional[int] = None
     pipeline_parallel_stages: Optional[int] = None
+    opt: Optional[str] = None

     @staticmethod
     def from_str(source: str) -> "ModelConfigOverride":
         """Parse model config override values from a string."""
         parser = argparse.ArgumentParser(description="model config override values")
         parser.add_argument("--tensor_parallel_shards", type=int, default=None)
         parser.add_argument("--pipeline_parallel_stages", type=int, default=None)
+        parser.add_argument("--opt", type=str, default=None)
         parser.add_argument("--context_window_size", type=int, default=None)
         parser.add_argument("--sliding_window_size", type=int, default=None)
         parser.add_argument("--prefill_chunk_size", type=int, default=None)
@@ -105,6 +107,7 @@ def from_str(source: str) -> "ModelConfigOverride":
         return ModelConfigOverride(
             tensor_parallel_shards=results.tensor_parallel_shards,
             pipeline_parallel_stages=results.pipeline_parallel_stages,
+            opt=results.opt,
             context_window_size=results.context_window_size,
             sliding_window_size=results.sliding_window_size,
             prefill_chunk_size=results.prefill_chunk_size,
@@ -294,6 +297,7 @@ def chat(
             attention_sink_size=overrides.attention_sink_size,
             tensor_parallel_shards=overrides.tensor_parallel_shards,
             pipeline_parallel_stages=overrides.pipeline_parallel_stages,
+            opt=overrides.opt,
         ),
     )
 ).chat()
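
One consequence of the `default=None` arguments above: an override string that never mentions `opt` leaves the field as `None`, so the JIT default (O2) stays in effect. A small check, assuming the class is imported from the file shown here:

```python
from mlc_llm.interface.chat import ModelConfigOverride

explicit = ModelConfigOverride.from_str("opt=O1")
implicit = ModelConfigOverride.from_str("context_window_size=4096")
print(explicit.opt)  # "O1" -- forwarded to JIT compilation
print(implicit.opt)  # None -- JIT keeps its O2 default
```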

python/mlc_llm/interface/serve.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@ def serve(
     additional_models: List[Union[str, Tuple[str, str]]],
     tensor_parallel_shards: Optional[int],
     pipeline_parallel_stages: Optional[int],
+    opt: Optional[str],
     max_num_sequence: Optional[int],
     max_total_sequence_length: Optional[int],
     max_single_sequence_length: Optional[int],
@@ -61,6 +62,7 @@ def serve(
         additional_models=additional_models,
         tensor_parallel_shards=tensor_parallel_shards,
         pipeline_parallel_stages=pipeline_parallel_stages,
+        opt=opt,
         max_num_sequence=max_num_sequence,
         max_total_sequence_length=max_total_sequence_length,
         max_single_sequence_length=max_single_sequence_length,

python/mlc_llm/serve/config.py

Lines changed: 14 additions & 0 deletions
@@ -46,9 +46,22 @@ class EngineConfig:  # pylint: disable=too-many-instance-attributes

     tensor_parallel_shards : Optional[int]
         Number of shards to split the model into in tensor parallelism multi-gpu inference.
+        When "model_lib" is given, this field will be ignored, and the tensor_parallel_shards
+        in the model_lib metadata will be used.

     pipeline_parallel_stages : Optional[int]
         Number of pipeline stages to split the model layers for pipeline parallelism.
+        When "model_lib" is given, this field will be ignored, and the pipeline_parallel_stages
+        in the model_lib metadata will be used.
+
+    opt : Optional[str]
+        The optimization flags for JIT compilation.
+        When "model_lib" is given, this field will be ignored.
+        MLC LLM maintains a predefined set of optimization flags,
+        denoted as O0, O1, O2, O3, where O0 means no optimization, O2 enables the majority
+        of them, and O3 represents extreme optimization that could potentially break the system.
+        Alternatively, optimization flags can be explicitly specified via detailed knobs, e.g.
+        "cublas_gemm=1;cudagraph=0".

     gpu_memory_utilization : Optional[float]
         A number in (0, 1) denoting the fraction of GPU memory used by the server in total.
@@ -127,6 +140,7 @@ class EngineConfig:  # pylint: disable=too-many-instance-attributes
     mode: Optional[Literal["local", "interactive", "server"]] = None
     tensor_parallel_shards: Optional[int] = None
     pipeline_parallel_stages: Optional[int] = None
+    opt: Optional[str] = None
     gpu_memory_utilization: Optional[float] = None
     kv_cache_page_size: int = 16
     max_num_sequence: Optional[int] = None
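
The docstring above allows `opt` in two shapes: a preset level (`O0`–`O3`) or explicit `knob=value` pairs. A hypothetical helper (not MLC LLM's actual parser) that distinguishes the two forms:

```python
from typing import Dict

PRESET_LEVELS = {"O0", "O1", "O2", "O3"}


def interpret_opt(opt: str) -> Dict[str, str]:
    """Illustrative only: classify an "opt" string per the docstring above."""
    if opt in PRESET_LEVELS:
        return {"preset": opt}
    # Explicit knobs, e.g. "cublas_gemm=1;cudagraph=0"
    return dict(pair.split("=", 1) for pair in opt.split(";") if pair)


print(interpret_opt("O3"))                         # {'preset': 'O3'}
print(interpret_opt("cublas_gemm=1;cudagraph=0"))  # {'cublas_gemm': '1', 'cudagraph': '0'}
```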

python/mlc_llm/serve/engine_base.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]:
         "tensor_parallel_shards": engine_config.tensor_parallel_shards,
         "pipeline_parallel_stages": engine_config.pipeline_parallel_stages,
         "max_batch_size": engine_config.max_num_sequence,
+        "opt": engine_config.opt,
     }

     model_lib = jit.jit(
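
With `"opt"` now part of the metadata handed to `jit.jit`, the compile step can honor the override instead of the fixed O2. The selection logic below is an illustrative sketch, not the actual `jit` module code:

```python
from typing import Optional


def select_opt(opt: Optional[str]) -> str:
    # Illustrative: an explicit override wins; otherwise keep the
    # historical default the commit message describes.
    return opt if opt is not None else "O2"


assert select_opt(None) == "O2"  # pre-PR behavior for every JIT compile
assert select_opt("O3") == "O3"  # now reachable via EngineConfig.opt
```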
