Skip to content

Commit c1b564a

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
optional zero points on dequantize_per_channel_group (#56)
Summary: Pull Request resolved: #56 Reviewed By: jerryzh168 Differential Revision: D54885425 fbshipit-source-id: 90fb97c605b98e59202019b831d6f929100a893f
1 parent c7fbf5a commit c1b564a

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Optional, Tuple
8+
79
import torch
810
from torch._dynamo import is_compiling as dynamo_is_compiling
911
from torch._higher_order_ops.out_dtype import out_dtype
@@ -12,7 +14,6 @@
1214
quantized_decomposed_lib,
1315
)
1416
from torch.library import impl
15-
from typing import Tuple
1617

1718
__all__ = [
1819
"safe_int_mm",
@@ -899,7 +900,7 @@ def group_quantize_tensor_symmetric(
899900

900901

901902
quantized_decomposed_lib.define(
902-
"dequantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
903+
"dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
903904
"int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
904905
)
905906

@@ -912,7 +913,7 @@ def group_quantize_tensor_symmetric(
912913
def dequantize_per_channel_group(
913914
w_int8: torch.Tensor,
914915
scales: torch.Tensor,
915-
zero_points: torch.Tensor,
916+
zero_points: Optional[torch.Tensor],
916917
quant_min: int,
917918
quant_max: int,
918919
dtype: torch.dtype,
@@ -947,10 +948,8 @@ def dequantize_per_channel_group(
947948

948949
w_int8_grouped = w_int8.reshape(-1, group_size)
949950
scales = scales.reshape(-1, 1)
950-
zero_points = zero_points.reshape(-1, 1)
951-
w_dq = (
952-
w_int8_grouped.sub(zero_points).mul(scales).reshape_as(w_int8).to(output_dtype)
953-
)
951+
zp = zero_points.reshape(-1, 1) if zero_points is not None else 0
952+
w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype)
954953
return w_dq
955954

956955

0 commit comments

Comments
 (0)