@@ -20,6 +20,7 @@
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -49,6 +50,7 @@ def quantize(
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -63,6 +65,7 @@ def quantize(
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: fake quantized tensor
     """
 
@@ -75,6 +78,7 @@ def quantize(
         do_quantize=True,
         do_dequantize=False,
         g_idx=g_idx,
+        global_scale=global_scale,
     )
 
 
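The new `global_scale` argument is threaded from the public `quantize`, `dequantize`, and `fake_quantize` entry points down into the private `_quantize`/`_dequantize` helpers, where it divides the local scale before QDQ. A minimal sketch of the intended arithmetic, as a standalone illustration rather than the library code itself:

```python
import torch

# Two-level scaling as introduced in this diff: the stored local `scale`
# is divided by `global_scale`, so the effective step size used for the
# quantize/dequantize round trip is scale / global_scale.
x = torch.randn(4, 8)
scale = torch.tensor(0.05)        # local quantization scale
global_scale = torch.tensor(2.0)  # optional global scale

effective_scale = scale.to(global_scale.dtype) / global_scale
x_q = torch.round(x / effective_scale)  # quantize step
x_dq = x_q * effective_scale            # dequantize step
```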
@@ -86,6 +90,7 @@ def dequantize(
     args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -97,6 +102,7 @@ def dequantize(
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: dequantized float tensor
     """
     if args is None:
@@ -128,6 +134,7 @@ def dequantize(
         do_dequantize=True,
         dtype=dtype,
         g_idx=g_idx,
+        global_scale=global_scale,
     )
 
 
@@ -138,6 +145,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -151,6 +159,7 @@ def fake_quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -161,6 +170,7 @@ def fake_quantize(
         do_quantize=True,
         do_dequantize=True,
         g_idx=g_idx,
+        global_scale=global_scale,
     )
 
 
@@ -174,6 +184,7 @@ def _process_quantization(
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
@@ -221,35 +232,44 @@ def _process_quantization(
             end = start + group_count
             if do_quantize:
                 output[:, start:end] = _quantize(
-                    x[:, start:end],
-                    sc,
-                    zp,
-                    q_min,
-                    q_max,
-                    args,
+                    x=x[:, start:end],
+                    scale=sc,
+                    zero_point=zp,
+                    q_min=q_min,
+                    q_max=q_max,
+                    args=args,
                     dtype=dtype,
+                    global_scale=global_scale,
                 )
 
             if do_dequantize:
                 input = output[:, start:end] if do_quantize else x[:, start:end]
-                output[:, start:end] = _dequantize(input, sc, zp)
+                output[:, start:end] = _dequantize(
+                    x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
+                )
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
-                x,
-                scale,
-                zero_point,
-                q_min,
-                q_max,
-                args,
+                x=x,
+                scale=scale,
+                zero_point=zero_point,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
                 dtype=dtype,
+                global_scale=global_scale,
            )
         if do_dequantize:
-            output = _dequantize(output if do_quantize else x, scale, zero_point)
+            output = _dequantize(
+                output if do_quantize else x,
+                scale=scale,
+                zero_point=zero_point,
+                global_scale=global_scale,
+            )
 
     return output
 
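In the grouped branch above, each group gets its own local scale `sc`, while the single `global_scale` is shared across all groups. A rough sketch of the resulting effective scales, with shapes that are illustrative assumptions rather than taken from the diff:

```python
import torch

# Hypothetical GROUP-strategy shapes: one local scale per (row, group),
# one scalar global scale shared by the whole tensor.
local_scales = torch.rand(64, 8) + 0.01  # [num_rows, num_groups]
global_scale = torch.tensor(3.0)         # scalar shared by all groups

# effective per-group step size after the diff's adjustment
effective = local_scales.to(global_scale.dtype) / global_scale
print(effective.shape)  # torch.Size([64, 8])
```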
@@ -330,6 +350,7 @@ def forward_quantize(
         return value
 
     g_idx = getattr(module, "weight_g_idx", None)
+    global_scale = getattr(module, f"{base_name}_global_scale", None)
 
     if args.dynamic:
         # dynamic quantization - determine the scale/zp on the fly
@@ -345,6 +366,7 @@ def forward_quantize(
         zero_point=zero_point,
         args=args,
         g_idx=g_idx,
+        global_scale=global_scale,
     )
 
 
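`forward_quantize` now mirrors the existing `weight_g_idx` lookup and fetches an optional `{base_name}_global_scale` attribute from the module. A hedged sketch of how such an attribute might be attached; the registration below is an assumption for illustration, not part of the diff:

```python
import torch

# Hypothetical setup: attach a scalar `weight_global_scale` alongside the
# usual quantization parameters so the getattr lookup can find it.
linear = torch.nn.Linear(16, 16)
linear.weight_global_scale = torch.nn.Parameter(
    torch.tensor(448.0), requires_grad=False
)

# forward_quantize(module, value, base_name="weight", args=...) would then
# resolve getattr(module, "weight_global_scale", None) to this parameter,
# while modules without the attribute fall back to None and are unaffected.
```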
@@ -357,11 +379,18 @@ def _quantize(
     q_max: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 
+    # if a global scale is optionally provided, use it
+    # to further scale the local `scale` parameter
+    if global_scale is not None:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     scaled = x / scale
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
+
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -381,7 +410,14 @@ def _dequantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+
+    # if a global scale is optionally provided, use it
+    # to further scale the local `scale` parameter
+    if global_scale is not None:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
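Putting the two private helpers together, here is a self-contained approximation of the fake-quantize round trip with a global scale. It is a simplified sketch: plain rounding stands in for `round_to_quantized_type`, and zero-point handling is omitted:

```python
import torch
from typing import Optional


def fake_quantize_sketch(
    x: torch.Tensor,
    scale: torch.Tensor,
    q_min: float,
    q_max: float,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # fold the global scale into the local scale, as _quantize/_dequantize do
    if global_scale is not None:
        scale = scale.to(global_scale.dtype) / global_scale
    # quantize: scale down, clamping before the round/cast since casts
    # are not guaranteed to saturate (e.g. for fp8)
    x_q = torch.round(torch.clamp(x / scale, q_min, q_max))
    # dequantize: scale back up with the same effective scale
    return x_q * scale


x = torch.randn(2, 4)
out = fake_quantize_sketch(x, torch.tensor(0.1), -6.0, 6.0, torch.tensor(2.0))
assert out.shape == x.shape
```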