Skip to content

Commit a231ae1

Browse files
authored
[Delivery] Update model delivery script (#2565)
Some improvements to the delivery script: provide different overrides for different quantizations (e.g. we can change the prefill chunk size for q0/q3/q4); re-run gen-config only when conv_template is the only change; do NOT recreate the HF repo when it already exists, which preserves commit history; add dry-run validation.
1 parent 42f146d commit a231ae1

File tree

1 file changed

+120
-33
lines changed

1 file changed

+120
-33
lines changed

python/mlc_llm/cli/delivery.py

Lines changed: 120 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,18 @@
3030
T = TypeVar("T", bound="BaseModel")
3131

3232

33+
class OverrideConfigs(BaseModel):
34+
"""
35+
The class that specifies the override configurations.
36+
"""
37+
38+
context_window_size: Optional[int] = None
39+
sliding_window_size: Optional[int] = None
40+
prefill_chunk_size: Optional[int] = None
41+
attention_sink_size: Optional[int] = None
42+
tensor_parallel_shards: Optional[int] = None
43+
44+
3345
class ModelDeliveryTask(BaseModel):
3446
"""
3547
Example:
@@ -38,21 +50,21 @@ class ModelDeliveryTask(BaseModel):
3850
"model": "HF://microsoft/Phi-3-mini-128k-instruct",
3951
"conv_template": "phi-3",
4052
"quantization": ["q3f16_1"],
41-
"context_window_size": 4096
53+
"overrides": {
54+
"q3f16_1": {
55+
"context_window_size": 512
56+
}
57+
}
4258
}
4359
"""
4460

4561
model_id: str
4662
model: str
4763
conv_template: str
48-
quantization: Optional[Union[List[str], str]] = Field(default_factory=list)
64+
quantization: Union[List[str], str] = Field(default_factory=list)
65+
overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)
4966
destination: Optional[str] = None
50-
51-
context_window_size: Optional[int] = None
52-
sliding_window_size: Optional[int] = None
53-
prefill_chunk_size: Optional[int] = None
54-
attention_sink_size: Optional[int] = None
55-
tensor_parallel_shards: Optional[int] = None
67+
gen_config_only: Optional[bool] = False
5668

5769

5870
class ModelDeliveryList(BaseModel):
@@ -63,7 +75,8 @@ class ModelDeliveryList(BaseModel):
6375
tasks: List[ModelDeliveryTask]
6476
# For delivered log, the default destination and quantization fields are optional
6577
default_destination: Optional[str] = None
66-
default_quantization: Optional[List[str]] = None
78+
default_quantization: List[str] = Field(default_factory=list)
79+
default_overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)
6780

6881
@classmethod
6982
def from_json(cls: Type[T], json_dict: Dict[str, Any]) -> T:
@@ -115,10 +128,7 @@ def _run_quantization(
115128
except HfHubHTTPError as error:
116129
if error.response.status_code != 409:
117130
raise
118-
logger.info("[HF] Repo already exists. Recreating...")
119-
api.delete_repo(repo_id=repo)
120-
api.create_repo(repo_id=repo, private=False)
121-
logger.info("[HF] Repo recreated")
131+
logger.info("[HF] Repo already exists. Skipping creation.")
122132
succeeded = True
123133
log_path = Path(output_dir) / "logs.txt"
124134
with log_path.open("a", encoding="utf-8") as log_file:
@@ -147,21 +157,24 @@ def _run_quantization(
147157

148158
print(" ".join(cmd), file=log_file, flush=True)
149159
subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
150-
cmd = [
151-
sys.executable,
152-
"-m",
153-
"mlc_llm",
154-
"convert_weight",
155-
str(model_info.model),
156-
"--quantization",
157-
model_info.quantization,
158-
"--output",
159-
output_dir,
160-
]
161-
print(" ".join(cmd), file=log_file, flush=True)
162-
subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
160+
if not model_info.gen_config_only:
161+
cmd = [
162+
sys.executable,
163+
"-m",
164+
"mlc_llm",
165+
"convert_weight",
166+
str(model_info.model),
167+
"--quantization",
168+
model_info.quantization,
169+
"--output",
170+
output_dir,
171+
]
172+
print(" ".join(cmd), file=log_file, flush=True)
173+
subprocess.run(
174+
cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ
175+
)
163176
logger.info("[MLC] Complete!")
164-
if not (Path(output_dir) / "ndarray-cache.json").exists():
177+
if not (Path(output_dir) / "ndarray-cache.json").exists() and not model_info.gen_config_only:
165178
logger.error(
166179
"[%s] Model %s. Quantization %s. No weights metadata found.",
167180
red("FAILED"),
@@ -175,7 +188,7 @@ def _run_quantization(
175188
api.upload_folder(
176189
folder_path=output_dir,
177190
repo_id=repo,
178-
commit_message="Initial commit",
191+
ignore_patterns=["logs.txt"],
179192
)
180193
except Exception as exc: # pylint: disable=broad-except
181194
logger.error("[%s] %s. Retrying...", red("FAILED"), exc)
@@ -198,38 +211,99 @@ def _get_current_log(log: str) -> ModelDeliveryList:
198211
return current_log
199212

200213

214+
def _generate_model_delivery_diff( # pylint: disable=too-many-locals
215+
spec: ModelDeliveryList, log: ModelDeliveryList
216+
) -> ModelDeliveryList:
217+
diff_tasks = []
218+
default_quantization = spec.default_quantization
219+
default_overrides = spec.default_overrides
220+
221+
for task in spec.tasks:
222+
model_id = task.model_id
223+
conv_template = task.conv_template
224+
quantization = task.quantization
225+
overrides = {**default_overrides, **task.overrides}
226+
227+
logger.info("Checking task: %s %s %s %s", model_id, conv_template, quantization, overrides)
228+
log_tasks = [t for t in log.tasks if t.model_id == model_id]
229+
delivered_quantizations = set()
230+
gen_config_only = set()
231+
232+
for log_task in log_tasks:
233+
log_quantization = log_task.quantization
234+
assert isinstance(log_quantization, str)
235+
log_override = log_task.overrides.get(log_quantization, OverrideConfigs())
236+
override = overrides.get(log_quantization, OverrideConfigs())
237+
if log_override == override:
238+
if log_task.conv_template == conv_template:
239+
delivered_quantizations.add(log_quantization)
240+
else:
241+
gen_config_only.add(log_quantization)
242+
243+
all_quantizations = set(default_quantization) | set(quantization)
244+
quantization_diff = all_quantizations - set(delivered_quantizations)
245+
246+
if quantization_diff:
247+
for q in quantization_diff:
248+
logger.info("Adding task %s %s %s to the diff.", model_id, conv_template, q)
249+
task_copy = task.model_copy()
250+
task_copy.quantization = [q]
251+
task_copy.overrides = {q: overrides.get(q, OverrideConfigs())}
252+
task_copy.gen_config_only = task_copy.gen_config_only or q in gen_config_only
253+
diff_tasks.append(task_copy)
254+
else:
255+
logger.info("Task %s %s %s is up-to-date.", model_id, conv_template, quantization)
256+
257+
diff_config = spec.model_copy()
258+
diff_config.default_quantization = []
259+
diff_config.default_overrides = {}
260+
diff_config.tasks = diff_tasks
261+
262+
logger.info("Model delivery diff: %s", diff_config.model_dump_json(indent=4, exclude_none=True))
263+
264+
return diff_config
265+
266+
201267
def _main( # pylint: disable=too-many-locals, too-many-arguments
202268
username: str,
203269
api: HfApi,
204270
spec: ModelDeliveryList,
205271
log: str,
206272
hf_local_dir: Optional[str],
207273
output: str,
274+
dry_run: bool,
208275
):
276+
delivery_diff = _generate_model_delivery_diff(spec, _get_current_log(log))
277+
if dry_run:
278+
logger.info("Dry run. No actual delivery.")
279+
return
280+
209281
failed_cases: List[Tuple[str, str]] = []
210282
delivered_log = _get_current_log(log)
211-
for task_index, task in enumerate(spec.tasks, 1):
283+
for task_index, task in enumerate(delivery_diff.tasks, 1):
212284
logger.info(
213285
bold("[{task_index}/{total_tasks}] Processing model: ").format(
214286
task_index=task_index,
215-
total_tasks=len(spec.tasks),
287+
total_tasks=len(delivery_diff.tasks),
216288
)
217289
+ green(task.model_id)
218290
)
219291
model = _clone_repo(task.model, hf_local_dir)
220292

221293
quantizations = []
222294

223-
if spec.default_quantization:
224-
quantizations += spec.default_quantization
295+
if delivery_diff.default_quantization:
296+
quantizations += delivery_diff.default_quantization
225297

226298
if task.quantization:
227299
if isinstance(task.quantization, str):
228300
quantizations.append(task.quantization)
229301
else:
230302
quantizations += task.quantization
231303

232-
default_destination = spec.default_destination or "{username}/{model_id}-{quantization}-MLC"
304+
default_destination = (
305+
delivery_diff.default_destination or "{username}/{model_id}-{quantization}-MLC"
306+
)
233307
for quantization in quantizations:
234308
repo = default_destination.format(
235309
username=username,
@@ -260,12 +334,19 @@ def _main( # pylint: disable=too-many-locals, too-many-arguments
260334
(task.model_id, quantization),
261335
)
262336
else:
337+
delivered_log.tasks = [
338+
task
339+
for task in delivered_log.tasks
340+
if task.model_id != model_info.model_id
341+
or task.quantization != model_info.quantization
342+
]
263343
delivered_log.tasks.append(model_info)
264344
if failed_cases:
265345
logger.info("Total %s %s:", len(failed_cases), red("failures"))
266346
for model_id, quantization in failed_cases:
267347
logger.info(" Model %s. Quantization %s.", model_id, quantization)
268348

349+
delivered_log.tasks.sort(key=lambda task: task.model_id)
269350
logger.info("Writing log to %s", log)
270351
with open(log, "w", encoding="utf-8") as o_f:
271352
json.dump(delivered_log.to_json(), o_f, indent=4)
@@ -336,6 +417,11 @@ def _get_default_hf_token() -> str:
336417
required=False,
337418
help="Local directory to store the downloaded HuggingFace model",
338419
)
420+
parser.add_argument(
421+
"--dry-run",
422+
action="store_true",
423+
help="Dry run without uploading to HuggingFace Hub",
424+
)
339425
parsed = parser.parse_args()
340426
_main(
341427
parsed.username,
@@ -344,6 +430,7 @@ def _get_default_hf_token() -> str:
344430
api=HfApi(token=parsed.token),
345431
hf_local_dir=parsed.hf_local_dir,
346432
output=parsed.output,
433+
dry_run=parsed.dry_run,
347434
)
348435

349436

0 commit comments

Comments
 (0)