
Commit eb61a24

[Fix]: Enable SYMMETRIC_NO_CLIPPING_ERR Mapping type and tests
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
1 parent ca8a5f1 commit eb61a24

File tree

3 files changed (+117, -2 lines)


torchao/experimental/docs/readme.md

Lines changed: 31 additions & 0 deletions
@@ -98,6 +98,37 @@ quantize_(
 )
 ```

+KleidiAI int4 kernels can be used on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+import torch
+
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.quantization.quant_primitives import MappingType
+
+my_model = Model()
+
+quantize_(
+    my_model,
+    int8_dynamic_activation_intx_weight(
+        weight_dtype=torch.int4,
+        granularity=PerGroup(32),  # PerRow() is also supported
+        has_weight_zeros=True,  # must be True for the "aten" target
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC also works but increases quantization error
+        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
+    ),
+)
+```
+
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
 for a working example.
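For readers choosing between `PerGroup(32)` and `PerRow()` in the snippet above, the short sketch below illustrates the usual group-wise versus row-wise convention. It is an editorial illustration, not code from this commit, and the shapes are arbitrary:

```python
# Rough illustration (assumed convention, not torchao source): how many
# (scale, zero_point) pairs each granularity implies for a Linear weight
# of shape [n, k] = [128, 256].
n, k = 128, 256
group_size = 32

per_group_params = n * (k // group_size)  # PerGroup(32): one pair per group of 32 input channels in each row
per_row_params = n                        # PerRow(): one pair per output row

print(per_group_params, per_row_params)  # 1024 128
```

Smaller groups track the weight distribution more closely, at the cost of storing more quantization parameters.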

torchao/experimental/quant_api.py

Lines changed: 2 additions & 2 deletions
@@ -614,13 +614,13 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
     if layout.target == Target.ATEN:
         if weight_dtype != torch.int4 or \
            has_weight_zeros != True or \
-           weight_mapping_type != MappingType.SYMMETRIC:
+           weight_mapping_type == MappingType.ASYMMETRIC:
             raise NotImplementedError(
                 f"target 'aten' requires:\n"
                 f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
                 f"- has_weight_zeros to be True,\n"
                 f"- weight_dtype to be torch.int4,\n"
-                f"- weight_mapping_type to be MappingType.SYMMETRIC"
+                f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
             )
         assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
         if torch.backends.kleidiai.is_available():
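For context on the relaxed check above, the sketch below contrasts how a per-tensor scale could be chosen under the two symmetric mapping types. This is a hedged illustration, not torchao's actual `choose_qparams` code; the helper name `choose_symmetric_scale` is hypothetical, and torchao applies the same idea per row or per group rather than per tensor:

```python
import torch

def choose_symmetric_scale(
    w: torch.Tensor,
    quant_min: int = -8,
    quant_max: int = 7,
    no_clipping_err: bool = False,
) -> torch.Tensor:
    """Pick a per-tensor scale for symmetric int4 ([-8, 7]) weight quantization."""
    min_val = torch.clamp(w.min(), max=0.0)  # negative extent of the weights
    max_val = torch.clamp(w.max(), min=0.0)  # positive extent of the weights
    if no_clipping_err:
        # SYMMETRIC_NO_CLIPPING_ERR-style: pick the scale so that neither extreme
        # is clipped, i.e. min_val/scale >= quant_min and max_val/scale <= quant_max.
        return torch.max(min_val / quant_min, max_val / quant_max)
    # SYMMETRIC-style: one scale from the larger magnitude over half the range;
    # with the asymmetric int4 range [-8, 7] the positive extreme can clip.
    return torch.max(-min_val, max_val) / ((quant_max - quant_min) / 2)

w = torch.randn(128, 256)
print(choose_symmetric_scale(w))                        # SYMMETRIC-style scale
print(choose_symmetric_scale(w, no_clipping_err=True))  # SYMMETRIC_NO_CLIPPING_ERR-style scale
```

Avoiding that clipping is why the readme snippet notes that plain `MappingType.SYMMETRIC` still works with the aten target but tends to increase quantization error.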
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass
from torchao.quantization.quant_primitives import MappingType


class TestPackedLinearInt8DynamicActivationIntxWeightLayoutAten(unittest.TestCase):
    def test_accuracy(self):
        """
        Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing
        its results to the results of a reference model that uses PlainLayout()
        """
        granularities = [PerRow()]
        m = 32
        n = 128
        k = 256
        activations = torch.randn(m, k)
        weight_mapping_type = MappingType.SYMMETRIC_NO_CLIPPING_ERR
        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

        for weight_dtype in [
            torch.int4,
        ]:
            for has_weight_zeros in [True]:
                for granularity in granularities:
                    print(
                        f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}, granularity={granularity}"
                    )
                    quantized_model = copy.deepcopy(model)
                    quantize_(
                        quantized_model,
                        int8_dynamic_activation_intx_weight(
                            weight_dtype=weight_dtype,
                            granularity=granularity,
                            has_weight_zeros=has_weight_zeros,
                            weight_mapping_type=weight_mapping_type,
                            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                                target="aten"
                            ),  # default
                        ),
                    )

                    quantized_model_reference = copy.deepcopy(model)
                    quantize_(
                        quantized_model_reference,
                        int8_dynamic_activation_intx_weight(
                            weight_dtype=weight_dtype,
                            granularity=granularity,
                            has_weight_zeros=has_weight_zeros,
                            layout=PlainLayout(),
                        ),
                    )

                    with torch.no_grad():
                        res = quantized_model(activations)
                        ref = quantized_model_reference(activations)

                    mean_err = ((res - ref).abs() / ref).mean()
                    self.assertTrue(mean_err < 0.04)


if __name__ == "__main__":
    unittest.main()
