
Get rid of llm_config enums #11810


Closed
wants to merge 1 commit into from
169 changes: 87 additions & 82 deletions examples/models/llama/config/llm_config.py
@@ -16,7 +16,6 @@
import ast
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, List, Optional


@@ -25,32 +24,27 @@
################################################################################


class ModelType(str, Enum):
STORIES110M = "stories110m"
LLAMA2 = "llama2"
LLAMA3 = "llama3"
LLAMA3_1 = "llama3_1"
LLAMA3_2 = "llama3_2"
LLAMA3_2_VISION = "llama3_2_vision"
STATIC_LLAMA = "static_llama"
QWEN2_5 = "qwen2_5"
QWEN3_0_6B = "qwen3-0_6b"
QWEN3_1_7B = "qwen3-1_7b"
QWEN3_4B = "qwen3-4b"
PHI_4_MINI = "phi_4_mini"
SMOLLM2 = "smollm2"
MODEL_TYPE_OPTIONS = [
"stories110m",
"llama2",
"llama3",
"llama3_1",
"llama3_2",
"llama3_2_vision",
"static_llama",
"qwen2_5",
"qwen3-0_6b",
"qwen3-1_7b",
"qwen3-4b",
"phi_4_mini",
"smollm2",
]


class PreqMode(str, Enum):
"""
If you are dealing with pre-quantized checkpoints, this used to
be the way to specify them. Now you don't need to specify these
options if you use a TorchAo-prequantized checkpoint, but they
are still around to preserve backward compatibility.
"""

PREQ_8DA4W = "8da4w"
PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
PREQ_MODE_OPTIONS = [
"8da4w",
"8da4w_output_8da8w",
]


@dataclass
@@ -82,34 +76,36 @@ class BaseConfig:
are loaded.
"""

model_class: ModelType = ModelType.LLAMA3
model_class: str = "llama3"
params: Optional[str] = None
checkpoint: Optional[str] = None
checkpoint_dir: Optional[str] = None
tokenizer_path: Optional[str] = None
metadata: Optional[str] = None
use_lora: int = 0
fairseq2: bool = False
preq_mode: Optional[PreqMode] = None
preq_mode: Optional[str] = None
preq_group_size: int = 32
preq_embedding_quantize: str = "8,0"

def __post_init__(self):
if self.model_class not in MODEL_TYPE_OPTIONS:
raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")

if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")


################################################################################
################################# ModelConfig ##################################
################################################################################


class DtypeOverride(str, Enum):
"""
DType of the model. Highly recommended to use "fp32", unless you want to
export without a backend, in which case you can also use "bf16". "fp16"
is not recommended.
"""

FP32 = "fp32"
FP16 = "fp16"
BF16 = "bf16"
DTYPE_OVERRIDE_OPTIONS = [
"fp32",
"fp16",
"bf16",
]


@dataclass
@@ -147,7 +143,7 @@ class ModelConfig:
[16] pattern specifies all layers have a sliding window of 16.
"""

dtype_override: DtypeOverride = DtypeOverride.FP32
dtype_override: str = "fp32"
enable_dynamic_shape: bool = True
use_shared_embedding: bool = False
use_sdpa_with_kv_cache: bool = False
@@ -160,6 +156,9 @@ class ModelConfig:
local_global_attention: Optional[List[int]] = None

def __post_init__(self):
if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")

self._validate_attention_sink()
self._validate_local_global_attention()
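A hypothetical call-site sketch (not taken from this PR): since `dtype_override` is now a plain string validated against `DTYPE_OVERRIDE_OPTIONS`, downstream code would branch on string literals instead of `DtypeOverride` members. The dtype mapping below is illustrative only:

```python
# Hypothetical consumer code; the mapping is illustrative, not from this PR.
import torch

from executorch.examples.models.llama.config.llm_config import ModelConfig

model_cfg = ModelConfig(dtype_override="bf16")

# Previously: model_cfg.dtype_override == DtypeOverride.BF16
if model_cfg.dtype_override == "bf16":
    torch_dtype = torch.bfloat16
elif model_cfg.dtype_override == "fp16":
    torch_dtype = torch.float16
else:  # "fp32" is the only other value DTYPE_OVERRIDE_OPTIONS allows
    torch_dtype = torch.float32
```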

@@ -261,31 +260,25 @@ class DebugConfig:
################################################################################


class Pt2eQuantize(str, Enum):
"""
Type of backend-specific Pt2e quantization strategy to use.

Pt2e uses a different quantization library that is graph-based
compared to `qmode`, which is also specified in the QuantizationConfig
and is source transform-based.
"""
PT2E_QUANTIZE_OPTIONS = [
"xnnpack_dynamic",
"xnnpack_dynamic_qc4",
"qnn_8a8w",
"qnn_16a16w",
"qnn_16a4w",
"coreml_c4w",
"coreml_8a_c8w",
"coreml_8a_c4w",
"coreml_baseline_8a_c8w",
"coreml_baseline_8a_c4w",
"vulkan_8w",
]

XNNPACK_DYNAMIC = "xnnpack_dynamic"
XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
QNN_8A8W = "qnn_8a8w"
QNN_16A16W = "qnn_16a16w"
QNN_16A4W = "qnn_16a4w"
COREML_C4W = "coreml_c4w"
COREML_8A_C8W = "coreml_8a_c8w"
COREML_8A_C4W = "coreml_8a_c4w"
COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
CUDA = "cuda"
NATIVE = "native"
SPIN_QUANT_OPTIONS = [
"cuda",
"native",
]


@dataclass
@@ -320,16 +313,22 @@ class QuantizationConfig:

qmode: Optional[str] = None
embedding_quantize: Optional[str] = None
pt2e_quantize: Optional[Pt2eQuantize] = None
pt2e_quantize: Optional[str] = None
group_size: Optional[int] = None
use_spin_quant: Optional[SpinQuant] = None
use_spin_quant: Optional[str] = None
use_qat: bool = False
calibration_tasks: Optional[List[str]] = None
calibration_limit: Optional[int] = None
calibration_seq_length: Optional[int] = None
calibration_data: str = "Once upon a time"

def __post_init__(self):
if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")

if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")

if self.qmode:
self._validate_qmode()

@@ -377,16 +376,18 @@ class XNNPackConfig:
extended_ops: bool = False


class CoreMLQuantize(str, Enum):
B4W = "b4w"
C4W = "c4w"
COREML_QUANTIZE_OPTIONS = [
"b4w",
"c4w",
]


class CoreMLComputeUnit(str, Enum):
CPU_ONLY = "cpu_only"
CPU_AND_GPU = "cpu_and_gpu"
CPU_AND_NE = "cpu_and_ne"
ALL = "all"
COREML_COMPUTE_UNIT_OPTIONS = [
"cpu_only",
"cpu_and_gpu",
"cpu_and_ne",
"all",
]


@dataclass
@@ -398,11 +399,17 @@ class CoreMLConfig:
enabled: bool = False
enable_state: bool = False
preserve_sdpa: bool = False
quantize: Optional[CoreMLQuantize] = None
quantize: Optional[str] = None
ios: int = 15
compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
compute_units: str = "cpu_only"

def __post_init__(self):
if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")

if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")

if self.ios not in (15, 16, 17, 18):
raise ValueError(f"Invalid coreml ios version: {self.ios}")

@@ -481,7 +488,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

# BaseConfig
if hasattr(args, "model"):
llm_config.base.model_class = ModelType(args.model)
llm_config.base.model_class = args.model
if hasattr(args, "params"):
llm_config.base.params = args.params
if hasattr(args, "checkpoint"):
@@ -499,15 +506,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

# PreqMode settings
if hasattr(args, "preq_mode") and args.preq_mode:
llm_config.base.preq_mode = PreqMode(args.preq_mode)
llm_config.base.preq_mode = args.preq_mode
if hasattr(args, "preq_group_size"):
llm_config.base.preq_group_size = args.preq_group_size
if hasattr(args, "preq_embedding_quantize"):
llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize

# ModelConfig
if hasattr(args, "dtype_override"):
llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
llm_config.model.dtype_override = args.dtype_override
if hasattr(args, "enable_dynamic_shape"):
llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
if hasattr(args, "use_shared_embedding"):
@@ -549,11 +556,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
if hasattr(args, "embedding_quantize"):
llm_config.quantization.embedding_quantize = args.embedding_quantize
if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
llm_config.quantization.pt2e_quantize = args.pt2e_quantize
if hasattr(args, "group_size"):
llm_config.quantization.group_size = args.group_size
if hasattr(args, "use_spin_quant") and args.use_spin_quant:
llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
llm_config.quantization.use_spin_quant = args.use_spin_quant
if hasattr(args, "use_qat"):
llm_config.quantization.use_qat = args.use_qat
if hasattr(args, "calibration_tasks"):
@@ -581,13 +588,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
args, "coreml_preserve_sdpa", False
)
if hasattr(args, "coreml_quantize") and args.coreml_quantize:
llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
llm_config.backend.coreml.quantize = args.coreml_quantize
if hasattr(args, "coreml_ios"):
llm_config.backend.coreml.ios = args.coreml_ios
if hasattr(args, "coreml_compute_units"):
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
args.coreml_compute_units
)
llm_config.backend.coreml.compute_units = args.coreml_compute_units

# Vulkan
if hasattr(args, "vulkan"):
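Finally, a sketch of the `from_args` path (illustrative; the Namespace below is a hypothetical, partial arg set): argparse values now flow through as plain strings rather than being wrapped in enum constructors.

```python
# Illustrative sketch only; from_args guards each field with hasattr,
# so a partial Namespace is enough for this example.
import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig

args = argparse.Namespace(
    model="qwen3-4b",
    dtype_override="fp32",
    pt2e_quantize="xnnpack_dynamic",
    coreml_compute_units="cpu_and_ne",
)
llm_config = LlmConfig.from_args(args)
print(llm_config.base.model_class)  # "qwen3-4b"
```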
31 changes: 29 additions & 2 deletions examples/models/llama/config/test_llm_config.py
@@ -11,7 +11,6 @@
from executorch.examples.models.llama.config.llm_config import (
BackendConfig,
BaseConfig,
CoreMLComputeUnit,
CoreMLConfig,
DebugConfig,
ExportConfig,
@@ -66,6 +65,34 @@ def test_shared_embedding_without_lowbit(self):
with self.assertRaises(ValueError):
LlmConfig(model=model_cfg, quantization=qcfg)

def test_invalid_model_type(self):
with self.assertRaises(ValueError):
BaseConfig(model_class="invalid_model")

def test_invalid_dtype_override(self):
with self.assertRaises(ValueError):
ModelConfig(dtype_override="invalid_dtype")

def test_invalid_preq_mode(self):
with self.assertRaises(ValueError):
BaseConfig(preq_mode="invalid_preq")

def test_invalid_pt2e_quantize(self):
with self.assertRaises(ValueError):
QuantizationConfig(pt2e_quantize="invalid_pt2e")

def test_invalid_spin_quant(self):
with self.assertRaises(ValueError):
QuantizationConfig(use_spin_quant="invalid_spin")

def test_invalid_coreml_quantize(self):
with self.assertRaises(ValueError):
CoreMLConfig(quantize="invalid_quantize")

def test_invalid_coreml_compute_units(self):
with self.assertRaises(ValueError):
CoreMLConfig(compute_units="invalid_compute_units")


class TestValidConstruction(unittest.TestCase):

@@ -94,7 +121,7 @@ def test_valid_llm_config(self):
backend=BackendConfig(
xnnpack=XNNPackConfig(enabled=False),
coreml=CoreMLConfig(
enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL
enabled=True, ios=17, compute_units="all"
),
),
)