4
4
# This source code is licensed under the license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from typing import Optional , Tuple
8
+
7
9
import torch
8
10
from torch ._dynamo import is_compiling as dynamo_is_compiling
9
11
from torch ._higher_order_ops .out_dtype import out_dtype
12
14
quantized_decomposed_lib ,
13
15
)
14
16
from torch .library import impl
15
- from typing import Tuple
16
17
17
18
__all__ = [
18
19
"safe_int_mm" ,
@@ -899,7 +900,7 @@ def group_quantize_tensor_symmetric(
899
900
900
901
901
902
quantized_decomposed_lib .define (
902
- "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
903
+ "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
903
904
"int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
904
905
)
905
906
@@ -912,7 +913,7 @@ def group_quantize_tensor_symmetric(
912
913
def dequantize_per_channel_group (
913
914
w_int8 : torch .Tensor ,
914
915
scales : torch .Tensor ,
915
- zero_points : torch .Tensor ,
916
+ zero_points : Optional [ torch .Tensor ] ,
916
917
quant_min : int ,
917
918
quant_max : int ,
918
919
dtype : torch .dtype ,
@@ -947,10 +948,8 @@ def dequantize_per_channel_group(
947
948
948
949
w_int8_grouped = w_int8 .reshape (- 1 , group_size )
949
950
scales = scales .reshape (- 1 , 1 )
950
- zero_points = zero_points .reshape (- 1 , 1 )
951
- w_dq = (
952
- w_int8_grouped .sub (zero_points ).mul (scales ).reshape_as (w_int8 ).to (output_dtype )
953
- )
951
+ zp = zero_points .reshape (- 1 , 1 ) if zero_points is not None else 0
952
+ w_dq = w_int8_grouped .sub (zp ).mul (scales ).reshape_as (w_int8 ).to (output_dtype )
954
953
return w_dq
955
954
956
955
0 commit comments