
Commit 8fdf6a9

[Fix]: Use Custom Tensor for dyn_quant_matmul_4bit aten op
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
1 parent de51c5f commit 8fdf6a9

4 files changed, +107 -30 lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 6 deletions
@@ -204,7 +204,6 @@ def from_hp_to_intx(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
-        bias: Optional[torch.Tensor] = None
     ):
         """Convert a high precision tensor to an integer affine quantized tensor."""
         original_shape = input_float.shape
@@ -277,11 +276,7 @@ def from_hp_to_intx(
 
         data = _layout.post_process(data)
         tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        args = [data, scale, zero_point, _layout]
-        # Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias
-        if bias is not None:
-            args.append(bias)
-        tensor_impl = tensor_impl_ctr(*args)
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
         return cls(
             tensor_impl,
             block_size,

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 81 additions & 0 deletions
@@ -12,6 +12,8 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    get_tensor_impl_constructor,
     register_layout,
 )
 from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -20,7 +22,11 @@
 from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
+    MappingType,
+    choose_qparams_affine,
+    quantize_affine,
 )
+
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_6,
 )
@@ -325,3 +331,78 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     _linear_check,
     _linear_impl,
 )
+
+
+class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
+    """
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class.
+    """
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
+        use_hqq: bool = False,
+        bias: Optional[torch.Tensor] = None
+    ):
+        assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
+        assert isinstance(
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_affine(
+            input_float,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+            preserve_zero,
+            zero_point_domain,
+        )
+        # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
+        # TODO: should probably consolidate ZeroPointDomain.NONE and None
+        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            zero_point = None
+        data = quantize_affine(
+            input_float,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        # Note: output will be uint8 tensor for sub byte tensors for now
+
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
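
The point of the new subclass is visible in the signatures: after this commit the base `AffineQuantizedTensor.from_hp_to_intx` (first file above) no longer takes a `bias`, while the Aten-target subclass does and forwards it to the packed tensor impl. A minimal sketch to confirm this, assuming a torchao build that contains this commit; it only inspects signatures and runs no kernels:

```python
import inspect

from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightAtenTensor,
)

# After this commit the base constructor no longer accepts a bias...
print("bias" in inspect.signature(AffineQuantizedTensor.from_hp_to_intx).parameters)  # False
# ...while the new aten-target subclass does, and passes it to the packed tensor impl.
print(
    "bias"
    in inspect.signature(
        PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
    ).parameters
)  # True
```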

torchao/experimental/quant_api.py

Lines changed: 23 additions & 21 deletions
@@ -494,6 +494,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
+    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
 )
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
@@ -576,6 +577,7 @@ def int8_dynamic_activation_intx_weight(
     )
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout
+    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"
 
     def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
@@ -591,6 +593,10 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
 
         layout = layout_arg
         scale_dtype = None
+        tensor_quantizer = to_affine_quantized_intx
+        quant_min = -(1 << (bit_width - 1))
+        quant_max = (1 << (bit_width - 1)) - 1
+
         if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
             assert (
                 weight.device == torch.device("cpu")
@@ -613,26 +619,23 @@
             if torch.backends.kleidiai.is_available():
                 if isinstance(granularity, PerGroup):
                     scale_dtype = torch.bfloat16  # KleidiAI kernel requires bfloat16 scale_dtype
-
-        quant_min = -(1 << (bit_width - 1))
-        quant_max = (1 << (bit_width - 1)) - 1
-        weight = to_affine_quantized_intx(
-            weight,
-            mapping_type=weight_mapping_type,
-            block_size=(1, group_size),
-            target_dtype=torch.int32,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=torch.finfo(torch.float32).eps,
-            scale_dtype=scale_dtype,
-            zero_point_dtype=torch.int8,
-            preserve_zero=has_weight_zeros,
-            zero_point_domain=ZeroPointDomain.INT
-            if has_weight_zeros
-            else ZeroPointDomain.NONE,
-            _layout=layout,
-            bias=bias
-        )
+            tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx
+
+        quantizer_args = [weight,
+                          weight_mapping_type,
+                          (1, group_size),
+                          torch.int32,
+                          quant_min,
+                          quant_max,
+                          torch.finfo(torch.float32).eps,
+                          scale_dtype,
+                          torch.int8,
+                          has_weight_zeros,
+                          ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
+                          layout,
+                          False] + ([bias] if propagate_bias else [])
+
+        weight = tensor_quantizer(*quantizer_args)
 
         # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
         # with the kernel and it should not be applied separately
@@ -650,7 +653,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
             weight = to_linear_activation_quantized(weight, activation_quant_func)
         return weight
 
-    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"
     return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)
 
 
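For context, this is roughly how the quantizer selection above gets exercised end to end through the public API. A hedged sketch, assuming a torchao build containing this commit, PyTorch >= 2.6 on an ARM CPU with KleidiAI enabled, and that the keyword names below (lifted from this diff and torchao's experimental API) are accepted as shown; treat it as illustrative rather than the canonical recipe:

```python
import torch
from torch import nn

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

# Weight must be float32 on CPU per the asserts in the diff; 256 % group_size == 0.
model = nn.Sequential(nn.Linear(256, 128, bias=True)).eval()

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,  # mapped to bit_width=4 via dtype_to_bit_width
        granularity=PerGroup(32),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)

# With target="aten", propagate_bias is True, so the Linear bias is packed with
# the weights and applied inside the fused dyn_quant_matmul_4bit aten kernel.
out = model(torch.randn(1, 256, dtype=torch.float32))
```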
torchao/quantization/quant_api.py

Lines changed: 2 additions & 3 deletions
@@ -457,11 +457,10 @@ def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, pro
 
     def insert_subclass(lin):
         requires_grad = allow_requires_grad and lin.weight.requires_grad
-        args = [lin.weight]
         if propagate_bias == True:
-            args.append(lin.bias)
+            kwargs["bias"] = lin.bias
         lin.weight = torch.nn.Parameter(
-            constructor(*args, **kwargs), requires_grad=requires_grad
+            constructor(lin.weight, **kwargs), requires_grad=requires_grad
         )
         lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
         return lin
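
The practical effect of this last hunk is that the optional bias now reaches the weight constructor as a keyword argument rather than as a second positional argument, so the unchanged to_affine_quantized_intx path keeps its single-positional call. An illustrative stand-in (not torchao code; `fake_constructor` is hypothetical and only mirrors the calling convention):

```python
from typing import Optional

import torch


def fake_constructor(weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
    # Stands in for to_packedlinearint8dynamicactivationintxweight_quantized_intx,
    # whose from_hp_to_intx ends with a `bias` parameter (see the second file above).
    return (weight, bias)


lin = torch.nn.Linear(4, 2, bias=True)
kwargs = {}
propagate_bias = True
if propagate_bias:
    kwargs["bias"] = lin.bias  # new behaviour: passed by keyword, not appended positionally
new_weight = fake_constructor(lin.weight, **kwargs)
```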
