40 changes: 27 additions & 13 deletions backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -25,6 +25,7 @@
get_xnnpack_executorch_backend_config,
)
from executorch.export import (
AOQuantizationConfig,
BackendRecipeProvider,
ExportRecipe,
LoweringRecipe,
@@ -57,31 +58,37 @@ def create_recipe(
if recipe_type == XNNPackRecipeType.FP32:
return self._build_fp32_recipe(recipe_type)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
elif recipe_type == XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL:
return self._build_quantized_recipe(
recipe_type, is_per_channel=True, is_dynamic=True
)

elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL:
return self._build_quantized_recipe(
recipe_type, is_per_channel=True, is_dynamic=False
)

elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR:
return self._build_quantized_recipe(
recipe_type, is_per_channel=False, is_dynamic=False
)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
return self._build_int8da_intx_weight_recipe(
elif (
recipe_type
== XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL
):
return self._build_torchao_quantized_recipe(
recipe_type=recipe_type,
is_per_channel=True,
weight_dtype=torch.int4,
)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
elif (
recipe_type
== XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
):
group_size = kwargs.get("group_size", 32)
return self._build_int8da_intx_weight_recipe(
return self._build_torchao_quantized_recipe(
recipe_type=recipe_type,
is_per_channel=False,
weight_dtype=torch.int4,
@@ -132,7 +139,7 @@ def _build_quantized_recipe(
executorch_backend_config=get_xnnpack_executorch_backend_config(),
)

def _build_int8da_intx_weight_recipe(
def _build_torchao_quantized_recipe(
self,
recipe_type: RecipeType,
is_per_channel: bool = True,
@@ -141,17 +148,21 @@ def _build_int8da_intx_weight_recipe(
) -> ExportRecipe:
if is_per_channel:
weight_granularity = PerAxis(axis=0)
assert weight_dtype == torch.int4 or weight_dtype == torch.int8
else:
weight_granularity = PerGroup(group_size=group_size)
assert weight_dtype == torch.int4

config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=weight_granularity,
config = AOQuantizationConfig(
Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=weight_granularity,
)
)

quant_recipe = QuantizationRecipe(
quantizers=None,
ao_base_config=[config],
ao_quantization_configs=[config],
)

return ExportRecipe(
@@ -162,7 +173,10 @@
)

def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
if (
recipe_type
== XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
):
expected_keys = {"group_size"}
unexpected = set(kwargs.keys()) - expected_keys
if unexpected:
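For reference, a minimal sketch of driving the updated provider directly. The class name XNNPACKRecipeProvider and the import paths are assumed from the file names rather than shown in this diff; the per-tensor TorchAO recipe takes an optional group_size kwarg (default 32), and _validate_recipe_kwargs rejects anything else.

from executorch.backends.xnnpack.recipes.xnnpack_recipe_provider import (
    XNNPACKRecipeProvider,  # assumed class name, not visible in this diff
)
from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType

provider = XNNPACKRecipeProvider()

# Default group_size (32) for the per-tensor TorchAO recipe.
recipe = provider.create_recipe(
    XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
)

# Override group_size; unexpected or mistyped kwargs raise ValueError.
recipe = provider.create_recipe(
    XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
    group_size=64,
)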
21 changes: 12 additions & 9 deletions backends/xnnpack/recipes/xnnpack_recipe_types.py
@@ -13,19 +13,22 @@ class XNNPackRecipeType(RecipeType):
"""XNNPACK-specific recipe types"""

FP32 = "fp32"

## PT2E-based quantization recipes
# INT8 Dynamic Quantization
INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel"
# INT8 Static Quantization, needs calibration dataset
PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel"
PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor"

## TorchAO-based quantization recipes
# INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = (
"torchao_int8da_int4w_per_channel"
)
# INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
# can be overridden by group_size kwarg
INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
# INT8 Static Activations INT4 Weight Quantization
INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int44w_per_tensor"
# INT8 Static Quantization, needs calibration dataset
INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"
TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor"

@classmethod
def get_backend_name(cls) -> str:
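With the rename, callers pick recipes through the prefixed names. A short sketch of the user-facing path, assuming the same import paths the updated tests use:

from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
from executorch.export import ExportRecipe

# PT2E-based INT8 dynamic quantization (previously INT8_DYNAMIC_PER_CHANNEL).
pt2e_recipe = ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL)

# TorchAO-based int8 dynamic activation / int4 weight, per-tensor; group_size defaults to 32.
torchao_recipe = ExportRecipe.get_recipe(
    XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
    group_size=32,
)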
94 changes: 54 additions & 40 deletions backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -18,9 +18,10 @@
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export, ExportRecipe, recipe_registry
from executorch.export import export, ExportRecipe, recipe_registry, StageType
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules
from torchao.quantization.utils import compute_error


class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +39,29 @@ def check_fully_delegated(self, program: Program) -> None:
self.assertEqual(len(instructions), 1)
self.assertIsInstance(instructions[0].instr_args, DelegateCall)

# pyre-ignore
def _compare_eager_quantized_model_outputs(
self, session, example_inputs, atol: float
) -> None:
"""Utility to compare eager quantized model output with session output after xnnpack lowering"""
torch_export_stage_output = session.get_stage_artifacts()[
StageType.TORCH_EXPORT
]
eager_quantized_model = torch_export_stage_output.data["forward"].module()
output = session.run_method("forward", example_inputs[0])[0]
expected = eager_quantized_model(*example_inputs[0])
Tester._assert_outputs_equal(output, expected, atol=atol)

def _compare_eager_unquantized_model_outputs(
self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
):
"""Utility to compare eager unquantized model output with session output using SQNR"""
quantized_output = session.run_method("forward", example_inputs[0])[0]
original_output = eager_unquantized_model(*example_inputs[0])
error = compute_error(original_output, quantized_output)
print(f"{self._testMethodName} - SQNR: {error} dB")
self.assertTrue(error > sqnr_threshold)

def test_basic_recipe(self) -> None:
m_eager = TestHelperModules.TwoLinearModule().eval()
example_inputs = [(torch.randn(9, 8),)]
@@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-3,
)
)
self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
self.check_fully_delegated(session.get_executorch_program())
self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs)

def test_int8_dynamic_quant_recipe(self) -> None:
test_cases = [
ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL),
]

for export_recipe in test_cases:
@@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-1
)
self.check_fully_delegated(session.get_executorch_program())
self._compare_eager_unquantized_model_outputs(
session, m_eager, example_inputs
)

def test_int8_static_quant_recipe(self) -> None:
test_cases = [
ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL),
ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR),
]

for export_recipe in test_cases:
@@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-2
)
self.check_fully_delegated(session.get_executorch_program())
self._compare_eager_unquantized_model_outputs(
session, m_eager, example_inputs
)

def test_8a4w_recipe(self) -> None:
class SimpleLinearModel(nn.Module):
@@ -116,40 +133,36 @@ def forward(self, x) -> torch.Tensor:

test_cases = [
ExportRecipe.get_recipe(
XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
),
ExportRecipe.get_recipe(
XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
group_size=32,
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
group_size=8,
),
]

for export_recipe in test_cases:
with self.subTest(export_recipe=export_recipe):
model = SimpleLinearModel()
model = SimpleLinearModel().eval()
example_inputs = [(torch.randn(1, 32),)]
session = export(
model=model,
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
model(*example_inputs[0]),
atol=1e-2,
)
)
self.check_fully_delegated(session.get_executorch_program())
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-3
)

def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
# Map QuantType to corresponding recipe name.
if quant_type == QuantType.STATIC_PER_CHANNEL:
return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
elif quant_type == QuantType.STATIC_PER_TENSOR:
return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
elif quant_type == QuantType.NONE:
return XNNPackRecipeType.FP32
else:
@@ -224,12 +237,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(

# Should not raise any exception
recipe_w_default_group = provider.create_recipe(
XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
)
self.assertIsNotNone(recipe_w_default_group)

recipe = provider.create_recipe(
XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
group_size=64,
)
self.assertIsNotNone(recipe)

@@ -240,7 +254,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(

with self.assertRaises(ValueError) as cm:
provider.create_recipe(
XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
group_size="32", # String instead of int
)

9 changes: 8 additions & 1 deletion export/__init__.py
@@ -15,12 +15,19 @@
"""

from .export import export, ExportSession
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
from .recipe import (
AOQuantizationConfig,
ExportRecipe,
LoweringRecipe,
QuantizationRecipe,
RecipeType,
)
from .recipe_provider import BackendRecipeProvider
from .recipe_registry import recipe_registry
from .types import StageType

__all__ = [
"AOQuantizationConfig",
"StageType",
"ExportRecipe",
"LoweringRecipe",
23 changes: 20 additions & 3 deletions export/recipe.py
@@ -6,7 +6,9 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import List, Optional, Sequence
from typing import Callable, List, Optional, Sequence

import torch

from executorch.exir._warnings import experimental

@@ -64,6 +66,20 @@ class Mode(str, Enum):
RELEASE = "release"


@dataclass
class AOQuantizationConfig:
"""
Configuration for torchao quantization with optional filter function.

Attributes:
ao_base_config: The AOBaseConfig for quantization
filter_fn: Optional filter function to selectively apply quantization
"""

ao_base_config: AOBaseConfig
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None


@dataclass
class QuantizationRecipe:
"""
@@ -73,11 +89,12 @@ class QuantizationRecipe:

Attributes:
quantizers: Optional list of quantizers for model quantization
ao_base_config: Optional list of AO base configurations
ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
AOBaseConfig with optional filter functions
"""

quantizers: Optional[List[Quantizer]] = None
ao_base_config: Optional[List[AOBaseConfig]] = None
ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None

def get_quantizers(self) -> Optional[List[Quantizer]]:
"""
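A minimal sketch of populating the new ao_quantization_configs field, mirroring the provider change above. The torchao import paths and the nn.Linear filter are assumptions for illustration, not part of this PR.

import torch
from executorch.export import AOQuantizationConfig, QuantizationRecipe
from torchao.quantization.granularity import PerAxis  # assumed import path
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,  # assumed import path
)

config = AOQuantizationConfig(
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerAxis(axis=0),
    ),
    # Optional filter with signature (module, fqn) -> bool; here, only quantize Linear modules.
    filter_fn=lambda module, fqn: isinstance(module, torch.nn.Linear),
)

quant_recipe = QuantizationRecipe(
    quantizers=None,
    ao_quantization_configs=[config],
)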