[Feat] Allow symmetric_no_clipping_error for KleidiAI kernels, update Readme and validate Kleidi INT4 quantization path #2570

Merged: 7 commits merged on Jul 26, 2025
2 changes: 2 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
Copyright 2023 Meta
All contributions by Arm:
Copyright (c) 2024-2025 Arm Limited and/or its affiliates

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

32 changes: 0 additions & 32 deletions torchao/experimental/docs/readme.md
@@ -96,38 +96,6 @@ quantize_(
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() is also supported, but much slower on CPU
),
)
```

KleidiAI Int4 kernels can be utilized on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:

```python
import torch

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_primitives import MappingType

my_model = Model()

quantize_(
    my_model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),  # PerRow() is also supported
        has_weight_zeros=True,  # Should be True
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC can also be used but increases error
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)
```

If you get stuck, consult
`torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
@@ -54,6 +55,7 @@ class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
for weight_mapping_type in [
MappingType.SYMMETRIC,
MappingType.ASYMMETRIC,
MappingType.SYMMETRIC_NO_CLIPPING_ERR,
]
for weight_granularity in [
PerGroup(128),
@@ -71,6 +73,12 @@ def test_accuracy(
"""
Checks the accuracy of packed layouts
"""
# SYMMETRIC_NO_CLIPPING_ERR is not supported for 1-bit weights, so skip that combination
if (
weight_dtype == torch.int1
and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR
):
return

m = 3
n = 1071
k = 2048
34 changes: 34 additions & 0 deletions torchao/quantization/README.md
@@ -205,6 +205,40 @@ quantize_(model, FPXWeightOnlyConfig(3, 2))

You can find more information [here](../dtypes/floatx/README.md). Note that while most other TorchAO APIs and benchmarks have focused on applying techniques on top of a bf16 model, fp6 works primarily with the fp16 dtype.

KleidiAI Int4 kernels can be utilized on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:

```python
import torch

from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    Target,
)
from torchao.quantization.granularity import PerGroup, PerAxis
from torchao.quantization.quant_primitives import MappingType

my_model = Model()

# Set quantization layout
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN)

quantize_(
    my_model,
    Int8DynamicActivationIntxWeightConfig(
        weight_scale_dtype=torch.float32,
        weight_granularity=PerGroup(32),  # PerAxis is also supported
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC can also be used but increases error
        layout=layout,
        weight_dtype=torch.int4,
    ),
)
```
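The quality difference between the two mapping types comes down to how the weight scale is chosen. The sketch below (plain per-tensor float math; the helper names are illustrative, not torchao APIs) shows why a single symmetric scale can clip the positive extreme of the asymmetric int4 range, while the no-clipping-error variant cannot:

```python
def symmetric_scale(min_val, max_val, quant_min=-8, quant_max=7):
    # MappingType.SYMMETRIC: one scale derived from the largest magnitude.
    # Because the int4 range [-8, 7] is asymmetric, the positive extreme
    # can map above quant_max and get clipped.
    return max(abs(min_val), abs(max_val)) / ((quant_max - quant_min) / 2)

def symmetric_no_clip_scale(min_val, max_val, quant_min=-8, quant_max=7):
    # MappingType.SYMMETRIC_NO_CLIPPING_ERR: take the larger of the two
    # per-side scales so both extremes land inside [quant_min, quant_max].
    return max(min_val / quant_min, max_val / quant_max)

s_sym = symmetric_scale(-1.0, 1.0)          # ~0.1333: 1.0 maps to 7.5, clipped to 7
s_nce = symmetric_no_clip_scale(-1.0, 1.0)  # ~0.1429: 1.0 maps to exactly 7
```

With the no-clipping-error scale, both extremes round into range, so the only remaining error is ordinary rounding error.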

## Affine Quantization Details
Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.
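As a concrete illustration of the affine transform above, here is a minimal pure-Python sketch using an int4 range (for intuition only; not the torchao implementation):

```python
def affine_quantize(x, scale, zero_point, quant_min=-8, quant_max=7):
    # quantized_val = round(high_precision_float_val / scale) + zero_point,
    # clamped into the representable integer range (int4 here)
    q = round(x / scale) + zero_point
    return max(quant_min, min(quant_max, q))

def affine_dequantize(q, scale, zero_point):
    # Inverse transform: values inside the range round-trip up to rounding error
    return (q - zero_point) * scale

q = affine_quantize(-0.25, scale=0.125, zero_point=0)    # -2
x_hat = affine_dequantize(q, scale=0.125, zero_point=0)  # -0.25
```

Values that fall outside the representable range are clamped to `quant_min`/`quant_max`, which is exactly the clipping error the symmetric-no-clipping-error mapping type is designed to avoid when choosing `scale`.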

13 changes: 10 additions & 3 deletions torchao/quantization/quant_api.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

@@ -862,8 +862,9 @@ def __post_init__(self):
assert self.weight_mapping_type in [
MappingType.ASYMMETRIC,
MappingType.SYMMETRIC,
MappingType.SYMMETRIC_NO_CLIPPING_ERR,
], (
f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.weight_mapping_type}"
f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}"
)
assert self.act_mapping_type in [
MappingType.ASYMMETRIC,
@@ -917,6 +918,12 @@ def _int8_dynamic_activation_intx_weight_transform(
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]

# We quantize with QDQLayout, and then construct the packed weight tensor later
# set preserve_zero based on weight mapping type
preserve_zero = weight_mapping_type in [
MappingType.SYMMETRIC,
MappingType.SYMMETRIC_NO_CLIPPING_ERR,
]

weight = to_affine_quantized_intx(
input_float=weight,
mapping_type=weight_mapping_type,
@@ -926,7 +933,7 @@
quant_max=quant_max,
scale_dtype=weight_scale_dtype,
zero_point_dtype=torch.int8,
preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC),
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.INT,
_layout=QDQLayout(),
)