 import torch
 from torch._dynamo import is_compiling as dynamo_is_compiling
-from segment_anything_fast.int_mm import _int_mm_dequant
+from torch._higher_order_ops.out_dtype import out_dtype

 def quantize_activation_per_token(t, scales):
     t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)
@@ -14,10 +14,13 @@ def quantize_activation_per_token(t, scales):
 def quantize_activation_per_token_absmax(t):
     n_bits = 8
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
-    # want float scales to avoid overflows
-    scales = t.abs().max(dim=-1, keepdim=True)[0].float()
+
+    # dequant with fp16 scale can cause overflow
+    # otherwise leave scales as same dtype
+    if t.dtype == torch.float16:
+        t = t.float()
     q_max = 2 ** (n_bits - 1) - 1
-    scales.clamp_(min=1e-5).div_(q_max)
+    scales = t.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(q_max)
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
     # if we don't clamp. TODO(future) look into this further.
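For reference, a minimal sketch of the per-token absmax path as it reads after this hunk, assembled from the +/- lines above; the final round/clamp/cast step is not shown in the hunk and is assumed here to mirror quantize_activation_per_token:

import torch

def quantize_activation_per_token_absmax(t):
    n_bits = 8
    # per-token: one scale per row along the last dim; [B, N, K] -> scales [B, N, 1]
    # an fp16 scale can overflow during dequant, so promote fp16 inputs to fp32
    if t.dtype == torch.float16:
        t = t.float()
    q_max = 2 ** (n_bits - 1) - 1  # 127 for int8
    scales = t.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(q_max)
    # assumed final step, mirroring quantize_activation_per_token above
    t = torch.round(t / scales).clamp(-q_max, q_max).to(torch.int8)
    return t, scales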
@@ -58,18 +61,18 @@ def quant_int8_dynamic_per_token_linear(
     w_vals_int8_t,
     w_scales,
     bias,
-    out_dtype=torch.float32,
+    output_dtype=torch.float32,
 ):

     def quant_int8_per_token_matmul(
         x_vals_int8,
         x_scales,
         w_vals_int8_t,
         w_scales,
-        out_dtype=torch.float32,
+        output_dtype=torch.float32,
     ):
         # Quantized matmul of int8 operands that accumulates to int32 and returns
-        # out_dtype. For now, this is written for approximate numerical
+        # output_dtype. For now, this is written for approximate numerical
         # Assumes that activation and weight quantization are symmetric,
         # i.e. act_zp and w_zp is 0.
         # Assumes that weight quantization is per-channel.
@@ -78,7 +81,7 @@ def quant_int8_per_token_matmul(
         # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
         # for an overview of quantized matmul compute

-        # in scalar form, assuming out_dtype is fp32 and zw == 0:
+        # in scalar form, assuming output_dtype is fp32 and zw == 0:
         #
         # Y_i_j_fp32 = sx * sw dot(X_i, W_j)
         #
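A toy numeric check of the scalar form above, with made-up values and fp32 scales:

import torch

x_int8 = torch.tensor([[10, -3]], dtype=torch.int8)    # one token, K = 2
w_int8_t = torch.tensor([[5], [7]], dtype=torch.int8)  # K = 2, one output channel
sx = 0.02   # per-token activation scale
sw = 0.1    # per-channel weight scale

# dot(X_i, W_j) accumulated in int32: 10*5 + (-3)*7 = 29
y_int32 = (x_int8.to(torch.int32) * w_int8_t.t().to(torch.int32)).sum(dim=-1)
y_fp32 = sx * sw * y_int32                             # 0.02 * 0.1 * 29 = 0.058
ref = (x_int8.float() * sx) @ (w_int8_t.float() * sw)  # dequantize-then-matmul reference, also 0.058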
@@ -87,19 +90,16 @@ def quant_int8_per_token_matmul(
             f'x dtype {x_vals_int8.dtype} not yet supported'
         assert w_vals_int8_t.dtype == torch.int8, \
             f'w dtype {w_vals_int8_t.dtype} not yet supported'
-        assert w_scales.dtype == out_dtype, \
-            f'{w_scales.dtype} does not match {out_dtype}'
+        assert w_scales.dtype == output_dtype, \
+            f'{w_scales.dtype} does not match {output_dtype}'

         #
         # 1. do the matrix form of dot(X_i, W_j)
         #

-        # TODO(before land): add test case for input with bsz
         tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-        x_scales_flat = x_scales.view(tmp.size(0), 1)
-        w_scales_flat = w_scales.unsqueeze(0)
-        # y_dot_int32 = torch._int_mm(tmp, w_vals_int8_t)
+        y_dot_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32, tmp, w_vals_int8_t)

         #
         # 2. rescale the output
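The out_dtype higher-order op used above wraps a regular aten op call and forces the accumulation/result dtype, replacing the custom int_mm_dequant kernel (and the commented-out torch._int_mm) with something that stays traceable for compile/export. A minimal standalone sketch with hypothetical shapes; the comparison against a widened int32 matmul reflects my assumption about its eager semantics:

import torch
from torch._higher_order_ops.out_dtype import out_dtype

a_int8 = torch.randint(-127, 127, (32, 64), dtype=torch.int8)
b_int8 = torch.randint(-127, 127, (64, 16), dtype=torch.int8)

# mm over int8 operands, with the result dtype forced to int32
y_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32, a_int8, b_int8)
assert y_int32.dtype == torch.int32

# reference: widen to int32 first, then matmul (exact integer arithmetic at these magnitudes)
ref = a_int8.to(torch.int32) @ b_int8.to(torch.int32)
assert torch.equal(y_int32, ref)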
@@ -108,21 +108,16 @@ def quant_int8_per_token_matmul(
         # large that y_dot_int32 * a float16 scale is greater than the maximum
         # value of a float 16, (which results in a value of inf even if multiplying
         # by the other scale would bring it within the expected range)
+        assert x_scales.dtype != torch.float16, f"can have overflows happen when x_scales is float16"

-        assert x_scales.dtype == torch.float, f"x_scales needs to be a torch.float32 but got {x_scales.dtype}"
-
-        # y = y_dot_int32 * x_scales_flat * w_scales_flat
-        # # can downcast only at the very end
-        # y = y.to(out_dtype)
+        y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales)

-        # y = _int_mm_dequant(tmp, w_vals_int8_t, x_scales_flat, w_scales_flat, out_dtype)
-        y = torch.ops.custom_int_mm.int_mm_dequant(tmp, w_vals_int8_t, x_scales_flat, w_scales_flat, out_dtype)
-        return y.reshape(*x_vals_int8.shape[:-1], -1)
+        return y.reshape(*x_vals_int8.shape[:-1], -1).to(output_dtype)

     # like F.linear, but with int8 dynamic quantization of activation,
     # and a quantized weight
     mm_out = quant_int8_per_token_matmul(
-        x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype)
+        x_vals_int8, x_scales, w_vals_int8_t, w_scales, output_dtype)
     if bias is not None:
         mm_out += bias
     return mm_out
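A toy illustration, with hypothetical magnitudes, of the float16 overflow the new assert guards against: an int32 accumulator times an fp16 activation scale can exceed fp16's maximum (~65504) and become inf, even though multiplying by the weight scale afterwards would bring the value back into range.

import torch

y_dot_int32 = torch.tensor([4_000_000], dtype=torch.int32)
x_scale_fp16 = torch.tensor([0.05], dtype=torch.float16)
w_scale = torch.tensor([0.001])  # fp32

bad = y_dot_int32 * x_scale_fp16           # 200000 does not fit in fp16 -> inf
good = y_dot_int32 * x_scale_fp16.float()  # fp32 intermediate stays finite
print(bad)             # tensor([inf], dtype=torch.float16)
print(good * w_scale)  # tensor([200.])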