@@ -50,6 +50,7 @@ def quantize(
50
50
args : QuantizationArgs ,
51
51
dtype : Optional [torch .dtype ] = None ,
52
52
g_idx : Optional [torch .Tensor ] = None ,
53
+ global_scale : Optional [torch .Tensor ] = None ,
53
54
) -> torch .Tensor :
54
55
"""
55
56
Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -76,6 +77,7 @@ def quantize(
76
77
do_quantize = True ,
77
78
do_dequantize = False ,
78
79
g_idx = g_idx ,
80
+ global_scale = global_scale ,
79
81
)
80
82
81
83
@@ -87,6 +89,7 @@ def dequantize(
87
89
args : Optional [QuantizationArgs ] = None ,
88
90
dtype : Optional [torch .dtype ] = None ,
89
91
g_idx : Optional [torch .Tensor ] = None ,
92
+ global_scale : Optional [torch .Tensor ] = None ,
90
93
) -> torch .Tensor :
91
94
"""
92
95
Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -129,6 +132,7 @@ def dequantize(
129
132
do_dequantize = True ,
130
133
dtype = dtype ,
131
134
g_idx = g_idx ,
135
+ global_scale = global_scale ,
132
136
)
133
137
134
138
@@ -139,6 +143,7 @@ def fake_quantize(
139
143
zero_point : torch .Tensor ,
140
144
args : QuantizationArgs ,
141
145
g_idx : Optional [torch .Tensor ] = None ,
146
+ global_scale : Optional [torch .Tensor ] = None ,
142
147
) -> torch .Tensor :
143
148
"""
144
149
Fake quantize the input tensor x by quantizing then dequantizing with
@@ -162,6 +167,7 @@ def fake_quantize(
162
167
do_quantize = True ,
163
168
do_dequantize = True ,
164
169
g_idx = g_idx ,
170
+ global_scale = global_scale ,
165
171
)
166
172
167
173
@@ -175,6 +181,7 @@ def _process_quantization(
175
181
dtype : Optional [torch .dtype ] = None ,
176
182
do_quantize : bool = True ,
177
183
do_dequantize : bool = True ,
184
+ global_scale : Optional [torch .Tensor ] = None ,
178
185
) -> torch .Tensor :
179
186
q_min , q_max = calculate_range (args , x .device )
180
187
group_size = args .group_size
@@ -222,35 +229,44 @@ def _process_quantization(
222
229
end = start + group_count
223
230
if do_quantize :
224
231
output [:, start :end ] = _quantize (
225
- x [:, start :end ],
226
- sc ,
227
- zp ,
228
- q_min ,
229
- q_max ,
230
- args ,
232
+ x = x [:, start :end ],
233
+ scale = sc ,
234
+ zero_point = zp ,
235
+ q_min = q_min ,
236
+ q_max = q_max ,
237
+ args = args ,
231
238
dtype = dtype ,
239
+ global_scale = global_scale ,
232
240
)
233
241
234
242
if do_dequantize :
235
243
input = output [:, start :end ] if do_quantize else x [:, start :end ]
236
- output [:, start :end ] = _dequantize (input , sc , zp )
244
+ output [:, start :end ] = _dequantize (
245
+ x = input , scale = sc , zero_point = zp , global_scale = global_scale
246
+ )
237
247
238
248
if not is_column_order :
239
249
output = safe_permute (output , torch .argsort (perm ), dim = 1 )
240
250
241
251
else : # covers channel, token and tensor strategies
242
252
if do_quantize :
243
253
output = _quantize (
244
- x ,
245
- scale ,
246
- zero_point ,
247
- q_min ,
248
- q_max ,
249
- args ,
254
+ x = x ,
255
+ scale = scale ,
256
+ zero_point = zero_point ,
257
+ q_min = q_min ,
258
+ q_max = q_max ,
259
+ args = args ,
250
260
dtype = dtype ,
261
+ global_scale = global_scale ,
251
262
)
252
263
if do_dequantize :
253
- output = _dequantize (output if do_quantize else x , scale , zero_point )
264
+ output = _dequantize (
265
+ output if do_quantize else x ,
266
+ scale = scale ,
267
+ zero_point = zero_point ,
268
+ global_scale = global_scale ,
269
+ )
254
270
255
271
return output
256
272
@@ -331,6 +347,7 @@ def forward_quantize(
331
347
return value
332
348
333
349
g_idx = getattr (module , "weight_g_idx" , None )
350
+ global_scale = getattr (module , f"{ base_name } _global_scale" , None )
334
351
335
352
if args .dynamic :
336
353
# dynamic quantization - determine the scale/zp on the fly
@@ -346,6 +363,7 @@ def forward_quantize(
346
363
zero_point = zero_point ,
347
364
args = args ,
348
365
g_idx = g_idx ,
366
+ global_scale = global_scale ,
349
367
)
350
368
351
369
@@ -358,11 +376,16 @@ def _quantize(
358
376
q_max : torch .Tensor ,
359
377
args : QuantizationArgs ,
360
378
dtype : Optional [torch .dtype ] = None ,
379
+ global_scale : Optional [torch .Tensor ] = None ,
361
380
) -> torch .Tensor :
362
381
382
+ if global_scale is not None :
383
+ scale = scale .to (global_scale .dtype ) * global_scale
384
+
363
385
scaled = x / scale
364
386
if zero_point is not None :
365
387
scaled += zero_point .to (x .dtype )
388
+
366
389
# clamp first because cast isn't guaranteed to be saturated (ie for fp8)
367
390
clamped_value = torch .clamp (
368
391
scaled ,
@@ -382,8 +405,12 @@ def _dequantize(
382
405
scale : torch .Tensor ,
383
406
zero_point : torch .Tensor = None ,
384
407
dtype : Optional [torch .dtype ] = None ,
408
+ global_scale : Optional [torch .Tensor ] = None ,
385
409
) -> torch .Tensor :
386
410
411
+ if global_scale is not None :
412
+ scale = scale .to (global_scale .dtype ) * global_scale
413
+
387
414
dequant_value = x_q .to (scale .dtype )
388
415
389
416
if zero_point is not None :
0 commit comments