import contextlib
from dataclasses import dataclass
from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional

import torch
from torch._higher_order_ops.out_dtype import out_dtype
    _replace_literals_with_new_placeholders,
    remove_tensor_overload_for_qdq_ops,
)
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op

try:
    from torch._export.utils import _disable_aten_to_metadata_assertions
except:
    _disable_aten_to_metadata_assertions = contextlib.nullcontext

+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)

__all__ = [
    "reference_representation_rewrite",
+    "_qdq_dynamic_quantized_linear_4bit_groupwise",
+    "_reference_dynamic_quantized_linear_4bit_groupwise",
]

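As an aside (illustrative, not part of the diff): `_register_custom_op(quant_lib)` above produces the `register_custom_op` decorator applied to `_reference_dqlinear_int4` further down, which is what makes the reference op reachable through the torch custom-op namespace. A minimal sketch, assuming torchao with this change installed and the module path as guessed here:

import torch
import torchao.quantization.pt2e.reference_representation_rewrite  # noqa: F401  (path assumed; importing triggers registration)

# The decorated function is now dispatchable like any other custom op:
print(torch.ops.torchao.reference_dqlinear_int4)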
@@ -203,6 +210,221 @@ def _reference_dynamic_quantized_linear(
    return out_fp32


+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
def _qdq_quantized_conv2d(
    x_i8,
    x_scale,
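For illustration, a minimal sketch (not part of this diff) of exercising the two new patterns above against each other. Shapes mirror the example inputs added further down; it assumes torchao with this change applied is importable, so the affine quant ops and the `torchao::reference_dqlinear_int4` op registered above are available, and that both functions are in scope:

import torch

x = torch.randn(2, 32)
eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)  # 4-bit values stored as int8
weight_scale = torch.rand(8, 4) + 1e-3  # [out_features, in_features // group_size]
weight_zero_point = torch.zeros(8, 4, dtype=torch.int32)  # unused, weight is symmetric
bias = torch.randn(8)

qdq_out = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x, eps, weight_i4, weight_scale, weight_zero_point, bias, 8
)
ref_out = _reference_dynamic_quantized_linear_4bit_groupwise(
    x, eps, weight_i4, weight_scale, weight_zero_point, bias, 8
)
# The reference path accumulates in int32 and rounds scales through bf16,
# so compare with a loose tolerance rather than exact equality.
print((qdq_out - ref_out).abs().max())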
@@ -739,6 +961,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
        127,
    )

+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
    _REWRITE_INFO_LIST = [
        _RewriteInfo(
            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +987,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
            ),
        ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            ignore_literals=True,
+        ),
        _RewriteInfo(
            _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
            WrapperModule(_qdq_quantized_linear),
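As a reading aid (illustrative, not part of the diff): the keys of `literal_to_ph_idx` are literals in the traced pattern and the values are placeholder indices, i.e. positions in the pattern function's argument list, consistent with the existing dynamic-linear entry above. For the 4-bit groupwise pattern the arguments of `_qdq_dynamic_quantized_linear_4bit_groupwise` are numbered as in the sketch below:

# 0: x_fp32, 1: x_eps, 2: weight_i4, 3: weight_scale,
# 4: weight_zero_point, 5: bias_fp32, 6: group_size
#
# so the eps literal maps back to placeholder 1 (x_eps), and the (1, 8)
# block-size literal (built from the example group_size of 8) maps to
# placeholder 6 (group_size):
import torch

literal_to_ph_idx = {
    torch.finfo(torch.float32).eps: 1,
    (1, 8): 6,
}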