@@ -651,10 +651,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
        _get_source_transforms(
-            modelname=args.model,
            dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
        )
    )

@@ -1145,23 +1166,65 @@ def _load_llama_model(


def _get_source_transforms(  # noqa
-    modelname: str,
    dtype_override: DType,
    *,
+    checkpoint: Optional[str] = None,
    checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    """
    Return a list of functions that transform a graph.

    Args:
-        modelname: The name of the model.
        dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
        checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
            it means that you want to run quantize transformations on the weights represented
            in their original dtype, while the overall dtype of the model maybe something
            different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

    Returns:
        A list of transformation functions.
@@ -1172,21 +1235,21 @@ def _get_source_transforms( # noqa

    transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_native_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
        """
        When this option is selected, it finds all embedding layers and transforms
        into quantized embedding equivalent module.
@@ -1196,12 +1259,25 @@ def _get_source_transforms( # noqa
        transformations based on the given checkpoint first. In those cases,
        this wil be a no-op.
        """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

    # quantization_mode should be applied after embedding_quantize
    # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
        """
        When this option is selected, it finds all linear layers and transforms
        into quantized linear equivalent module.
@@ -1215,7 +1291,19 @@ def _get_source_transforms( # noqa
        There are cases where this may be a no-op, namely, if all linears are
        quantized in the checkpoint.
        """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
        transforms.append(
            get_quant_weight_transform(
                args=args,
@@ -1224,15 +1312,12 @@ def _get_source_transforms( # noqa
            )
        )

-    if args.expand_rope_table:
+    if expand_rope_table:
        transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
        transforms.append(replace_kv_cache_with_custom_kv_cache)
        # todo: do this optionally
        # if use attention mask instead of causal attention
@@ -1244,23 +1329,23 @@ def _get_source_transforms( # noqa
        else:
            transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
        # Right now
        transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
            from executorch.backends.qualcomm.utils.utils import (
                convert_linear_to_conv2d,
            )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                    )
                transforms.append(replace_attention_to_attention_sha)
                transforms.append(replace_causal_mask)
@@ -1272,29 +1357,29 @@ def _get_source_transforms( # noqa
                transforms.append(replace_sdpa_with_flex_sdpa)
                transforms.append(replace_causal_mask)
                transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                    )
                # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
            # Currently mps doesn't support sdpa op, use the simpler decomposition
            # to get free perf gain.
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
            # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                transforms.append(replace_sdpa_with_coreml_sdpa)
            else:
                transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
        transforms.append(replace_with_vulkan_rotary_emb)

    return transforms
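
Usage sketch (illustrative, not part of the diff): with the explicit keyword parameters, callers can build the transform list directly instead of threading an argparse namespace through. The values below are hypothetical placeholders, the snippet assumes the module's existing imports (torch, DType, _get_source_transforms) and that DType exposes fp32/fp16 members, and load_llama_module is a made-up helper standing in for however the torch.nn.Module is obtained.

# Hypothetical caller, sketched against the new signature above.
transforms = _get_source_transforms(
    dtype_override=DType.fp32,
    checkpoint="/tmp/llama/consolidated.00.pth",  # placeholder path
    checkpoint_dtype=DType.fp16,
    tokenizer_path="/tmp/llama/tokenizer.model",  # placeholder path
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)
model = load_llama_module()  # hypothetical helper returning a torch.nn.Module
for transform in transforms:
    model = transform(model)  # each entry maps nn.Module -> nn.Module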