4
4
import argparse
5
5
import dataclasses
6
6
import json
7
+ import logging
7
8
import os
8
9
import uuid
9
10
15
16
TensorizerConfig ,
16
17
tensorize_lora_adapter ,
17
18
tensorize_vllm_model ,
19
+ tensorizer_kwargs_arg ,
18
20
)
19
21
from vllm .utils import FlexibleArgumentParser
20
22
23
+ logger = logging .getLogger ()
24
+
25
+
21
26
# yapf conflicts with isort for this docstring
22
27
# yapf: disable
23
28
"""
119
124
"""
120
125
121
126
122
- def parse_args ():
127
+ def get_parser ():
123
128
parser = FlexibleArgumentParser (
124
129
description = "An example script that can be used to serialize and "
125
130
"deserialize vLLM models. These models "
@@ -135,13 +140,13 @@ def parse_args():
135
140
required = False ,
136
141
help = "Path to a LoRA adapter to "
137
142
"serialize along with model tensors. This can then be deserialized "
138
- "along with the model by passing a tensorizer_config kwarg to "
139
- "LoRARequest with type TensorizerConfig. See the docstring for this "
140
- "for a usage example. "
141
-
143
+ "along with the model by instantiating a TensorizerConfig object, "
144
+ "creating a dict from it with TensorizerConfig.to_serializable(), "
145
+ "and passing it to LoRARequest's initializer with the kwarg "
146
+ "tensorizer_config_dict."
142
147
)
143
148
144
- subparsers = parser .add_subparsers (dest = 'command' )
149
+ subparsers = parser .add_subparsers (dest = 'command' , required = True )
145
150
146
151
serialize_parser = subparsers .add_parser (
147
152
'serialize' , help = "Serialize a model to `--serialized-directory`" )
@@ -171,6 +176,14 @@ def parse_args():
171
176
"where `suffix` is given by `--suffix` or a random UUID if not "
172
177
"provided." )
173
178
179
+ serialize_parser .add_argument (
180
+ "--serialization-kwargs" ,
181
+ type = tensorizer_kwargs_arg ,
182
+ required = False ,
183
+ help = ("A JSON string containing additional keyword arguments to "
184
+ "pass to Tensorizer's TensorSerializer during "
185
+ "serialization." ))
186
+
174
187
serialize_parser .add_argument (
175
188
"--keyfile" ,
176
189
type = str ,
@@ -186,21 +199,45 @@ def parse_args():
186
199
deserialize_parser .add_argument (
187
200
"--path-to-tensors" ,
188
201
type = str ,
189
- required = True ,
202
+ required = False ,
190
203
help = "The local path or S3 URI to the model tensors to deserialize. " )
191
204
205
+ deserialize_parser .add_argument (
206
+ "--serialized-directory" ,
207
+ type = str ,
208
+ required = False ,
209
+ help = "Directory with model artifacts for loading. Assumes a "
210
+ "model.tensors file exists therein. Can supersede "
211
+ "--path-to-tensors." )
212
+
192
213
deserialize_parser .add_argument (
193
214
"--keyfile" ,
194
215
type = str ,
195
216
required = False ,
196
217
help = ("Path to a binary key to use to decrypt the model weights,"
197
218
" if the model was serialized with encryption" ))
198
219
199
- TensorizerArgs .add_cli_args (deserialize_parser )
220
+ deserialize_parser .add_argument (
221
+ "--deserialization-kwargs" ,
222
+ type = tensorizer_kwargs_arg ,
223
+ required = False ,
224
+ help = ("A JSON string containing additional keyword arguments to "
225
+ "pass to Tensorizer's `TensorDeserializer` during "
226
+ "deserialization." ))
200
227
201
- return parser . parse_args ( )
228
+ TensorizerArgs . add_cli_args ( deserialize_parser )
202
229
230
+ return parser
203
231
232
def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
                                              cfg: TensorizerConfig):
    """Overlay matching entries of ``extra_cfg`` onto ``cfg`` in place.

    Each key of ``extra_cfg`` that names an attribute already present on
    ``cfg`` has its value assigned to that attribute, and the key is
    logged. Keys with no matching attribute are skipped without comment.

    Args:
        extra_cfg: Mapping of attribute names to override values (the
            log message attributes it to ``--model-loader-extra-config``).
        cfg: The config object to update; it is mutated, not copied.

    Returns:
        None.
    """
    for attr_name in extra_cfg:
        # Guard clause: only attributes the config already exposes are
        # overridden, so unrelated keys never create new attributes.
        if not hasattr(cfg, attr_name):
            continue

        override = extra_cfg[attr_name]
        setattr(cfg, attr_name, override)
        logger.info(
            "Updating TensorizerConfig with %s from "
            "--model-loader-extra-config provided",
            attr_name,
        )
204
241
205
242
def deserialize (args , tensorizer_config ):
206
243
if args .lora_path :
@@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config):
230
267
lora_request = LoRARequest ("sql-lora" ,
231
268
1 ,
232
269
args .lora_path ,
233
- tensorizer_config = tensorizer_config )
270
+ tensorizer_config_dict = tensorizer_config
271
+ .to_serializable ())
234
272
)
235
273
)
236
274
else :
@@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config):
243
281
244
282
245
283
def main ():
246
- args = parse_args ()
284
+ parser = get_parser ()
285
+ args = parser .parse_args ()
247
286
248
287
s3_access_key_id = (getattr (args , 's3_access_key_id' , None )
249
288
or os .environ .get ("S3_ACCESS_KEY_ID" , None ))
@@ -265,13 +304,24 @@ def main():
265
304
else :
266
305
keyfile = None
267
306
307
+ extra_config = {}
268
308
if args .model_loader_extra_config :
269
- config = json .loads (args .model_loader_extra_config )
270
- tensorizer_args = \
271
- TensorizerConfig (** config )._construct_tensorizer_args ()
272
- tensorizer_args .tensorizer_uri = args .path_to_tensors
273
- else :
274
- tensorizer_args = None
309
+ extra_config = json .loads (args .model_loader_extra_config )
310
+
311
+
312
+ tensorizer_dir = (args .serialized_directory or
313
+ extra_config .get ("tensorizer_dir" ))
314
+ tensorizer_uri = (getattr (args , "path_to_tensors" , None )
315
+ or extra_config .get ("tensorizer_uri" ))
316
+
317
+ if tensorizer_dir and tensorizer_uri :
318
+ parser .error ("--serialized-directory and --path-to-tensors "
319
+ "cannot both be provided" )
320
+
321
+ if not tensorizer_dir and not tensorizer_uri :
322
+ parser .error ("Either --serialized-directory or --path-to-tensors "
323
+ "must be provided" )
324
+
275
325
276
326
if args .command == "serialize" :
277
327
eng_args_dict = {f .name : getattr (args , f .name ) for f in
@@ -281,7 +331,7 @@ def main():
281
331
argparse .Namespace (** eng_args_dict )
282
332
)
283
333
284
- input_dir = args . serialized_directory .rstrip ('/' )
334
+ input_dir = tensorizer_dir .rstrip ('/' )
285
335
suffix = args .suffix if args .suffix else uuid .uuid4 ().hex
286
336
base_path = f"{ input_dir } /vllm/{ model_ref } /{ suffix } "
287
337
if engine_args .tensor_parallel_size > 1 :
@@ -292,21 +342,29 @@ def main():
292
342
tensorizer_config = TensorizerConfig (
293
343
tensorizer_uri = model_path ,
294
344
encryption_keyfile = keyfile ,
295
- ** credentials )
345
+ serialization_kwargs = args .serialization_kwargs or {},
346
+ ** credentials
347
+ )
296
348
297
349
if args .lora_path :
298
350
tensorizer_config .lora_dir = tensorizer_config .tensorizer_dir
299
351
tensorize_lora_adapter (args .lora_path , tensorizer_config )
300
352
353
+ merge_extra_config_with_tensorizer_config (extra_config ,
354
+ tensorizer_config )
301
355
tensorize_vllm_model (engine_args , tensorizer_config )
302
356
303
357
elif args .command == "deserialize" :
304
- if not tensorizer_args :
305
- tensorizer_config = TensorizerConfig (
306
- tensorizer_uri = args .path_to_tensors ,
307
- encryption_keyfile = keyfile ,
308
- ** credentials
309
- )
358
+ tensorizer_config = TensorizerConfig (
359
+ tensorizer_uri = args .path_to_tensors ,
360
+ tensorizer_dir = args .serialized_directory ,
361
+ encryption_keyfile = keyfile ,
362
+ deserialization_kwargs = args .deserialization_kwargs or {},
363
+ ** credentials
364
+ )
365
+
366
+ merge_extra_config_with_tensorizer_config (extra_config ,
367
+ tensorizer_config )
310
368
deserialize (args , tensorizer_config )
311
369
else :
312
370
raise ValueError ("Either serialize or deserialize must be specified." )
0 commit comments