 T = TypeVar("T", bound="BaseModel")
 
 
+class OverrideConfigs(BaseModel):
+    """
+    The class that specifies the override configurations.
+    """
+
+    context_window_size: Optional[int] = None
+    sliding_window_size: Optional[int] = None
+    prefill_chunk_size: Optional[int] = None
+    attention_sink_size: Optional[int] = None
+    tensor_parallel_shards: Optional[int] = None
+
+
 class ModelDeliveryTask(BaseModel):
     """
     Example:
@@ -38,21 +50,21 @@ class ModelDeliveryTask(BaseModel):
         "model": "HF://microsoft/Phi-3-mini-128k-instruct",
         "conv_template": "phi-3",
         "quantization": ["q3f16_1"],
-        "context_window_size": 4096
+        "overrides": {
+            "q3f16_1": {
+                "context_window_size": 512
+            }
+        }
     }
     """
 
     model_id: str
     model: str
     conv_template: str
-    quantization: Optional[Union[List[str], str]] = Field(default_factory=list)
+    quantization: Union[List[str], str] = Field(default_factory=list)
+    overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)
     destination: Optional[str] = None
-
-    context_window_size: Optional[int] = None
-    sliding_window_size: Optional[int] = None
-    prefill_chunk_size: Optional[int] = None
-    attention_sink_size: Optional[int] = None
-    tensor_parallel_shards: Optional[int] = None
+    gen_config_only: Optional[bool] = False
 
 
 class ModelDeliveryList(BaseModel):
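
For reference, a minimal sketch (not part of the patch; the field values are illustrative) of how a spec entry using the new per-quantization `overrides` field maps onto these Pydantic models:

task = ModelDeliveryTask(
    model_id="Phi-3-mini-128k-instruct",
    model="HF://microsoft/Phi-3-mini-128k-instruct",
    conv_template="phi-3",
    quantization=["q3f16_1"],
    overrides={"q3f16_1": OverrideConfigs(context_window_size=512)},
)

# The plain-dict form from the JSON spec parses to the same object.
same_task = ModelDeliveryTask.model_validate(
    {
        "model_id": "Phi-3-mini-128k-instruct",
        "model": "HF://microsoft/Phi-3-mini-128k-instruct",
        "conv_template": "phi-3",
        "quantization": ["q3f16_1"],
        "overrides": {"q3f16_1": {"context_window_size": 512}},
    }
)
assert task == same_task
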
@@ -63,7 +75,8 @@ class ModelDeliveryList(BaseModel):
     tasks: List[ModelDeliveryTask]
     # For delivered log, the default destination and quantization fields are optional
     default_destination: Optional[str] = None
-    default_quantization: Optional[List[str]] = None
+    default_quantization: List[str] = Field(default_factory=list)
+    default_overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)
 
     @classmethod
     def from_json(cls: Type[T], json_dict: Dict[str, Any]) -> T:
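
A short sketch (values are made up) of how the list-level `default_overrides` combine with a task's own `overrides`, mirroring the `{**default_overrides, **task.overrides}` merge in `_generate_model_delivery_diff` below: the merge is per quantization key, so a task-level entry replaces the default entry wholesale rather than field by field.

default_overrides = {"q4f16_1": OverrideConfigs(context_window_size=4096)}
task_overrides = {"q4f16_1": OverrideConfigs(prefill_chunk_size=1024)}

merged = {**default_overrides, **task_overrides}
# The task entry wins for "q4f16_1"; the default context_window_size is not
# carried over, because the dict merge replaces the whole OverrideConfigs value.
assert merged["q4f16_1"] == OverrideConfigs(prefill_chunk_size=1024)
assert merged["q4f16_1"].context_window_size is None
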
@@ -115,10 +128,7 @@ def _run_quantization(
     except HfHubHTTPError as error:
         if error.response.status_code != 409:
             raise
-        logger.info("[HF] Repo already exists. Recreating...")
-        api.delete_repo(repo_id=repo)
-        api.create_repo(repo_id=repo, private=False)
-        logger.info("[HF] Repo recreated")
+        logger.info("[HF] Repo already exists. Skipping creation.")
     succeeded = True
     log_path = Path(output_dir) / "logs.txt"
     with log_path.open("a", encoding="utf-8") as log_file:
@@ -147,21 +157,24 @@ def _run_quantization(
 
         print(" ".join(cmd), file=log_file, flush=True)
         subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
-        cmd = [
-            sys.executable,
-            "-m",
-            "mlc_llm",
-            "convert_weight",
-            str(model_info.model),
-            "--quantization",
-            model_info.quantization,
-            "--output",
-            output_dir,
-        ]
-        print(" ".join(cmd), file=log_file, flush=True)
-        subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
+        if not model_info.gen_config_only:
+            cmd = [
+                sys.executable,
+                "-m",
+                "mlc_llm",
+                "convert_weight",
+                str(model_info.model),
+                "--quantization",
+                model_info.quantization,
+                "--output",
+                output_dir,
+            ]
+            print(" ".join(cmd), file=log_file, flush=True)
+            subprocess.run(
+                cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ
+            )
     logger.info("[MLC] Complete!")
-    if not (Path(output_dir) / "ndarray-cache.json").exists():
+    if not (Path(output_dir) / "ndarray-cache.json").exists() and not model_info.gen_config_only:
         logger.error(
             "[%s] Model %s. Quantization %s. No weights metadata found.",
             red("FAILED"),
@@ -175,7 +188,7 @@ def _run_quantization(
         api.upload_folder(
             folder_path=output_dir,
             repo_id=repo,
-            commit_message="Initial commit",
+            ignore_patterns=["logs.txt"],
         )
     except Exception as exc:  # pylint: disable=broad-except
         logger.error("[%s] %s. Retrying...", red("FAILED"), exc)
@@ -198,38 +211,99 @@ def _get_current_log(log: str) -> ModelDeliveryList:
     return current_log
 
 
+def _generate_model_delivery_diff(  # pylint: disable=too-many-locals
+    spec: ModelDeliveryList, log: ModelDeliveryList
+) -> ModelDeliveryList:
+    diff_tasks = []
+    default_quantization = spec.default_quantization
+    default_overrides = spec.default_overrides
+
+    for task in spec.tasks:
+        model_id = task.model_id
+        conv_template = task.conv_template
+        quantization = task.quantization
+        overrides = {**default_overrides, **task.overrides}
+
+        logger.info("Checking task: %s %s %s %s", model_id, conv_template, quantization, overrides)
+        log_tasks = [t for t in log.tasks if t.model_id == model_id]
+        delivered_quantizations = set()
+        gen_config_only = set()
+
+        for log_task in log_tasks:
+            log_quantization = log_task.quantization
+            assert isinstance(log_quantization, str)
+            log_override = log_task.overrides.get(log_quantization, OverrideConfigs())
+            override = overrides.get(log_quantization, OverrideConfigs())
+            if log_override == override:
+                if log_task.conv_template == conv_template:
+                    delivered_quantizations.add(log_quantization)
+                else:
+                    gen_config_only.add(log_quantization)
+
+        all_quantizations = set(default_quantization) | set(quantization)
+        quantization_diff = all_quantizations - set(delivered_quantizations)
+
+        if quantization_diff:
+            for q in quantization_diff:
+                logger.info("Adding task %s %s %s to the diff.", model_id, conv_template, q)
+                task_copy = task.model_copy()
+                task_copy.quantization = [q]
+                task_copy.overrides = {q: overrides.get(q, OverrideConfigs())}
+                task_copy.gen_config_only = task_copy.gen_config_only or q in gen_config_only
+                diff_tasks.append(task_copy)
+        else:
+            logger.info("Task %s %s %s is up-to-date.", model_id, conv_template, quantization)
+
+    diff_config = spec.model_copy()
+    diff_config.default_quantization = []
+    diff_config.default_overrides = {}
+    diff_config.tasks = diff_tasks
+
+    logger.info("Model delivery diff: %s", diff_config.model_dump_json(indent=4, exclude_none=True))
+
+    return diff_config
+
+
 def _main(  # pylint: disable=too-many-locals, too-many-arguments
     username: str,
     api: HfApi,
     spec: ModelDeliveryList,
     log: str,
     hf_local_dir: Optional[str],
     output: str,
+    dry_run: bool,
 ):
+    delivery_diff = _generate_model_delivery_diff(spec, _get_current_log(log))
+    if dry_run:
+        logger.info("Dry run. No actual delivery.")
+        return
+
     failed_cases: List[Tuple[str, str]] = []
     delivered_log = _get_current_log(log)
-    for task_index, task in enumerate(spec.tasks, 1):
+    for task_index, task in enumerate(delivery_diff.tasks, 1):
         logger.info(
             bold("[{task_index}/{total_tasks}] Processing model: ").format(
                 task_index=task_index,
-                total_tasks=len(spec.tasks),
+                total_tasks=len(delivery_diff.tasks),
             )
             + green(task.model_id)
         )
         model = _clone_repo(task.model, hf_local_dir)
 
         quantizations = []
 
-        if spec.default_quantization:
-            quantizations += spec.default_quantization
+        if delivery_diff.default_quantization:
+            quantizations += delivery_diff.default_quantization
 
         if task.quantization:
             if isinstance(task.quantization, str):
                 quantizations.append(task.quantization)
             else:
                 quantizations += task.quantization
 
-        default_destination = spec.default_destination or "{username}/{model_id}-{quantization}-MLC"
+        default_destination = (
+            delivery_diff.default_destination or "{username}/{model_id}-{quantization}-MLC"
+        )
         for quantization in quantizations:
             repo = default_destination.format(
                 username=username,
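
A worked sketch of what `_generate_model_delivery_diff` produces (model and quantization names are illustrative): a quantization already recorded in the log with identical overrides and conv_template is skipped, a quantization missing from the log comes back as its own single-quantization task, and a log entry whose only mismatch is the conv_template would instead mark the re-queued task as gen_config_only.

log = ModelDeliveryList(
    tasks=[
        ModelDeliveryTask(
            model_id="Phi-3-mini-128k-instruct",
            model="HF://microsoft/Phi-3-mini-128k-instruct",
            conv_template="phi-3",
            quantization="q3f16_1",  # the delivered log stores a single string
            overrides={"q3f16_1": OverrideConfigs(context_window_size=512)},
        )
    ]
)
spec = ModelDeliveryList(
    default_quantization=["q3f16_1", "q4f16_1"],
    tasks=[
        ModelDeliveryTask(
            model_id="Phi-3-mini-128k-instruct",
            model="HF://microsoft/Phi-3-mini-128k-instruct",
            conv_template="phi-3",
            overrides={"q3f16_1": OverrideConfigs(context_window_size=512)},
        )
    ],
)

diff = _generate_model_delivery_diff(spec, log)
# q3f16_1 matches the log entry (same overrides, same conv_template) and is
# dropped; only q4f16_1 remains, as a single-quantization task of its own.
assert [t.quantization for t in diff.tasks] == [["q4f16_1"]]
assert diff.default_quantization == [] and diff.default_overrides == {}
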
@@ -260,12 +334,19 @@ def _main(  # pylint: disable=too-many-locals, too-many-arguments
                     (task.model_id, quantization),
                 )
             else:
+                delivered_log.tasks = [
+                    task
+                    for task in delivered_log.tasks
+                    if task.model_id != model_info.model_id
+                    or task.quantization != model_info.quantization
+                ]
                 delivered_log.tasks.append(model_info)
     if failed_cases:
         logger.info("Total %s %s:", len(failed_cases), red("failures"))
         for model_id, quantization in failed_cases:
             logger.info("  Model %s. Quantization %s.", model_id, quantization)
 
+    delivered_log.tasks.sort(key=lambda task: task.model_id)
     logger.info("Writing log to %s", log)
     with open(log, "w", encoding="utf-8") as o_f:
         json.dump(delivered_log.to_json(), o_f, indent=4)
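
A tiny sketch of the new log-rewrite behavior (entries are made up): any existing entry with the same model_id and quantization is filtered out before the fresh result is appended, so re-delivering a model updates its log entry instead of duplicating it, and the final sort keeps the written log stable across runs.

old_entry = ModelDeliveryTask(
    model_id="Phi-3-mini-128k-instruct",
    model="HF://microsoft/Phi-3-mini-128k-instruct",
    conv_template="phi-3",
    quantization="q3f16_1",
)
model_info = old_entry.model_copy(
    update={"overrides": {"q3f16_1": OverrideConfigs(context_window_size=512)}}
)
delivered_log = ModelDeliveryList(tasks=[old_entry])

# Same filter-then-append pattern as in _main above.
delivered_log.tasks = [
    t
    for t in delivered_log.tasks
    if t.model_id != model_info.model_id or t.quantization != model_info.quantization
]
delivered_log.tasks.append(model_info)
assert delivered_log.tasks == [model_info]
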
@@ -336,6 +417,11 @@ def _get_default_hf_token() -> str:
         required=False,
         help="Local directory to store the downloaded HuggingFace model",
     )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Dry run without uploading to HuggingFace Hub",
+    )
     parsed = parser.parse_args()
     _main(
         parsed.username,
@@ -344,6 +430,7 @@ def _get_default_hf_token() -> str:
         api=HfApi(token=parsed.token),
         hf_local_dir=parsed.hf_local_dir,
         output=parsed.output,
+        dry_run=parsed.dry_run,
     )
 
 