Skip to content

Commit 5a31ec8

Browse files
authored
Fix QDQ layout slice operation when zero_point is None (#2054)
up
1 parent 04259eb commit 5a31ec8

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

torchao/dtypes/uintx/q_dq_layout.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
181181

182182
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
183183
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
184-
zero_point = aten.slice.Tensor(
185-
zero_point, dim, start_scale, end_scale, step
186-
)
187-
# this is to handle padding
188-
int_data, scale, zero_point = self._layout.post_process(
189-
int_data, scale, zero_point, self.block_size
190-
)
184+
if zero_point is not None:
185+
zero_point = aten.slice.Tensor(
186+
zero_point, dim, start_scale, end_scale, step
187+
)
191188
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
192189
return return_and_correct_aliasing(func, args, kwargs, sliced)
193190
else:

torchao/quantization/quant_api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
697697
Weights are quantized with scales and optionally zeros (controlled by weight_zero_point_domain) in a groupwise or
698698
channelwise manner using the number of bits specified by weight_dtype.
699699
700+
This layout is identical to Int8DynamicActivationInt4WeightConfig when weight_dtype = torch.int4 and other args
701+
are the same. However, this layout is more general and supports other weight dtypes.
702+
700703
args:
701704
weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8.
702705
torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6
@@ -796,6 +799,9 @@ def _int8_dynamic_activation_intx_weight_transform(
796799

797800
# We quantize with QDQLayout, and then construct the packed weight tensor later
798801
has_weight_zeros = weight_zero_point_domain == ZeroPointDomain.INT
802+
preserve_zero = (weight_mapping_type == MappingType.SYMMETRIC) or (
803+
weight_zero_point_domain == ZeroPointDomain.NONE
804+
)
799805
weight = to_affine_quantized_intx(
800806
input_float=weight,
801807
mapping_type=weight_mapping_type,
@@ -806,8 +812,7 @@ def _int8_dynamic_activation_intx_weight_transform(
806812
eps=torch.finfo(torch.float32).eps,
807813
scale_dtype=weight_scale_dtype,
808814
zero_point_dtype=torch.int8 if has_weight_zeros else None,
809-
preserve_zero=has_weight_zeros
810-
or (weight_mapping_type == MappingType.SYMMETRIC),
815+
preserve_zero=preserve_zero,
811816
zero_point_domain=weight_zero_point_domain,
812817
_layout=QDQLayout(),
813818
)

0 commit comments

Comments (0)