Commit 51a0583

refactor _get_source_transforms to remove args parameter
1 parent 2837867 commit 51a0583

1 file changed: +117 −32 lines changed
examples/models/llama/export_llama_lib.py (+117 −32)
@@ -651,10 +651,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1145,23 +1166,65 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1172,21 +1235,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1196,12 +1259,25 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1215,7 +1291,19 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1224,15 +1312,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1244,23 +1329,23 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1272,29 +1357,29 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
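
With the args parameter gone, _get_source_transforms can be driven by plain keyword arguments instead of an argparse.Namespace. Below is a minimal, hypothetical sketch of a non-CLI caller based on the new signature in this commit; the checkpoint path and dtype values are purely illustrative, and the DType import path is assumed to match the builder module this file already imports from.

    # Hypothetical caller of the refactored function; values are illustrative.
    import torch

    from executorch.examples.models.llama.export_llama_lib import _get_source_transforms
    from executorch.extension.llm.export.builder import DType  # assumed import path

    transforms = _get_source_transforms(
        dtype_override=DType.fp32,
        checkpoint="/tmp/consolidated.00.pth",  # illustrative path
        checkpoint_dtype=DType.fp32,
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
    )

    # Each entry is a Callable[[torch.nn.Module], torch.nn.Module]; apply them in order.
    model = torch.nn.Identity()  # stand-in for the loaded llama module
    for transform in transforms:
        model = transform(model)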

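Inside the function, the commit bridges to get_quant_embedding_transform and get_quant_weight_transform, which still expect an argparse-style object, by building a throwaway Args instance and copying the relevant parameters onto it. A types.SimpleNamespace would be a hypothetical, equivalent way to express the same shim; the attribute names below mirror the ones the commit sets, and the values are illustrative only, not defaults from the CLI.

    from types import SimpleNamespace

    # Hypothetical stand-in for the `class Args: pass` shim; not what the commit ships.
    quant_args = SimpleNamespace(
        checkpoint="/tmp/consolidated.00.pth",  # illustrative
        tokenizer_path="/tmp/tokenizer.model",  # illustrative
        embedding_quantize="8,0",
        use_shared_embedding=False,
        use_qat=False,
        use_lora=0,
        preq_mode=None,
        preq_group_size=32,
        preq_embedding_quantize="8,0",
    )

    # Attribute access behaves like an argparse.Namespace:
    print(quant_args.embedding_quantize)  # -> "8,0"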