import contextlib
from dataclasses import dataclass
from functools import partial
- from typing import Any, Callable, Optional
+ from typing import Any, Callable, List, Optional

import torch
from torch._higher_order_ops.out_dtype import out_dtype
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.fx import GraphModule
+ from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

from torchao.quantization.pt2e.export_utils import WrapperModule
from torchao.quantization.pt2e.utils import (
    _replace_literals_with_existing_placeholders,
    _replace_literals_with_new_placeholders,
    remove_tensor_overload_for_qdq_ops,
)
+ from torchao.quantization.quant_primitives import MappingType
+ from torchao.quantization.utils import _get_per_token_block_size
+ from torchao.utils import _register_custom_op

try:
    from torch._export.utils import _disable_aten_to_metadata_assertions
except:
    _disable_aten_to_metadata_assertions = contextlib.nullcontext

+ quant_lib = torch.library.Library("torchao", "FRAGMENT")
+ register_custom_op = _register_custom_op(quant_lib)
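+ # `register_custom_op` registers the decorated function as a custom op in the
+ # torchao library fragment; `_reference_dqlinear_int4` below is exposed as
+ # `torch.ops.torchao.reference_dqlinear_int4`.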

__all__ = [
    "reference_representation_rewrite",

@@ -203,6 +209,252 @@ def _reference_dynamic_quantized_linear(
    return out_fp32


+ def _qdq_dynamic_quantized_linear_4bit_groupwise(
+     x_fp32,
+     x_eps,
+     weight_i4,
+     weight_scale,
+     weight_zero_point,
+     bias_fp32,
+     group_size,
+ ):
+     # Dynamic quantization of activation
+     x_mapping_type = MappingType.ASYMMETRIC
+     per_token_block_size = _get_per_token_block_size(x_fp32)
+     x_quant_min = -128
+     x_quant_max = 127
+     x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+         x_fp32,
+         x_mapping_type.name,
+         per_token_block_size,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         x_eps,
+         torch.float32,
+         torch.int32,
+     )
+     x_i8 = torch.ops.torchao.quantize_affine(
+         x_fp32,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+     )
+     x_fp32 = torch.ops.torchao.dequantize_affine(
+         x_i8,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         torch.float32,
+     )
+
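+     # The weight is quantized ahead of time, so only dequantization happens here.
+     # block_size (1, group_size) below means each output channel carries one scale
+     # and zero point per contiguous group of `group_size` input channels.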
+     assert group_size > 0, "Group size must be positive"
+     assert (
+         weight_i4.shape[1] % group_size == 0
+     ), "Weight must be divisible by group_size"
+     assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+     block_size = (1, group_size)
+     weight_fp32 = torch.ops.torchao.dequantize_affine(
+         weight_i4,
+         block_size,
+         weight_scale,
+         weight_zero_point,
+         torch.int8,
+         -8,
+         7,
+     )
+
+     out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+     return out_fp32
+
+
+ @register_custom_op
+ def _reference_dqlinear_int4(
+     x_fp32: torch.Tensor,
+     x_eps: float,
+     weight_i4: torch.Tensor,
+     weight_scale: torch.Tensor,
+     weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+     bias_fp32: Optional[torch.Tensor],
+     group_size: List[int],
+ ) -> torch.Tensor:
+     """
+     Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+     This implementation emulates actual numerics of on-device integer compute.
+
+     Args:
+         x_fp32: Input activation tensor in fp32
+         x_eps: Epsilon for quantization parameter computation
+         weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+         weight_scale: Groupwise scales for weight dequantization
+         weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+         bias_fp32: Optional bias tensor in fp32
+         group_size: Size of each group for groupwise quantization
+
+     Returns:
+         Output tensor in fp32
+     """
+     # Dynamic quantization of activation
+     group_size = group_size[1]
+     x_mapping_type = MappingType.ASYMMETRIC
+     per_token_block_size = _get_per_token_block_size(x_fp32)
+     x_quant_min = -128
+     x_quant_max = 127
+     x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+         x_fp32,
+         x_mapping_type.name,
+         per_token_block_size,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         x_eps,
+         torch.float32,
+         torch.int32,
+     )
+     x_i8 = torch.ops.torchao.quantize_affine(
+         x_fp32,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+     )
+
+     # For groupwise quantization, we need to handle the computation differently
+     # weight_i4 shape: [out_features, in_features]
+     # weight_scale shape: [out_features, in_features // group_size]
+     # weight_zero_point shape: [out_features, in_features // group_size]
+     out_features, in_features = weight_i4.shape
+     num_groups = in_features // group_size
+
+     # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+     weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+     # Reshape for group-wise processing
+     # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+     x_orig_shape = x_i8.shape
+     k_dim = x_i8.shape[-1]
+     x_i8 = x_i8.view(-1, k_dim)
+     batch_size = x_i8.shape[0]
+     x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+     # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+     weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+     # Convert to int32 for computation
+     x_i32_grouped = x_i8_grouped.to(torch.int32)
+     weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+     # Perform groupwise integer linear operation
+     acc_fp32 = torch.zeros(
+         batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+     )
+     out_shape = list(x_orig_shape)
+     out_shape[-1] = out_features
+
+     if weight_scale.ndim == 1:
+         weight_scale = weight_scale.unsqueeze(0)
+
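+     # Each loop iteration accumulates, for group g:
+     #     acc_fp32 += (x_g @ w_g.T - x_zero_point * colsum(w_g)) * weight_scale_g
+     # The per-token x_scale is common to all groups, so it is applied once after
+     # the loop, giving out = x_scale * acc_fp32 (+ bias).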
+     for group_idx in range(num_groups):
+         # Extract current group
+         x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+         weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+         weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+         # Get scale for this group
+         weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+         # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+         group_acc = out_dtype(
+             torch.ops.aten.linear.default,
+             torch.int32,
+             x_group,
+             weight_group,
+             None,
+         )
+
+         # Output has to be scaled by x_scale * weight_scale_group
+         # However we will first scale by weight_scale_group, that is accounting
+         # only for scale of weight, and then scale by x_scale at the end because
+         # x_scale applies to all groups
+         acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+             1, -1
+         )
+
+         # we must also subtract x_zero_point * weight_group_sum
+         # since (X - x_zero_point) * W = X * W - x_zero_point * W
+         weights_col_sum_adjusted = (
+             weight_group_col_sum.to(torch.float32).view(1, -1)
+             * x_zero_point.view(-1, 1)
+             * weight_scale_group.view(1, -1)
+         )
+         acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+     x_scale_multiplier = x_scale.view(-1, 1)
+     out_fp32 = acc_fp32 * x_scale_multiplier
+     if bias_fp32 is not None:
+         out_fp32 = out_fp32 + bias_fp32
+
+     return out_fp32.view(out_shape)
+
+
+ def _reference_dynamic_quantized_linear_4bit_groupwise(
+     x_fp32,
+     x_eps,
+     weight_i4,
+     weight_scale,
+     weight_zero_point,  # Not used because assuming weight is symmetric
+     bias_fp32,
+     group_size,
+ ):
+     """
+     Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+     This function now delegates to the custom op implementation.
+     """
+     return torch.ops.torchao.reference_dqlinear_int4(
+         x_fp32,
+         x_eps,
+         weight_i4,
+         weight_scale,
+         weight_zero_point,
+         bias_fp32,
+         (1, group_size),
+     )
+
+
+ def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
+     match,
+     original_graph,
+     pattern_graph,
+ ) -> bool:
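+     # Only rewrite matches where the weight dequantize_affine uses the int4 range
+     # (quant_min == -8, quant_max == 7) and the activation quantize_affine outputs
+     # torch.int8.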
+     weight_is_int4 = False
+     act_quant_is_int8 = False
+     for node in match.nodes_map.values():
+         if (
+             isinstance(node, torch.fx.Node)
+             and node.op == "call_function"
+             and node.target == torch.ops.torchao.dequantize_affine.default
+         ):
+             args = node.args
+             if len(args) >= 7:
+                 weight_is_int4 = args[5] == -8 and args[6] == 7
+         if (
+             isinstance(node, torch.fx.Node)
+             and node.op == "call_function"
+             and node.target == torch.ops.torchao.quantize_affine.default
+         ):
+             args = node.args
+             if len(args) >= 5:
+                 act_quant_is_int8 = args[4] == torch.int8
+     return weight_is_int4 and act_quant_is_int8
+
+
def _qdq_quantized_conv2d(
    x_i8,
    x_scale,

@@ -627,6 +879,9 @@ class _RewriteInfo:
    # post transformation on the exported pattern and replacement GraphModule
    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
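+     # Optional match filters forwarded to replace_pattern_with_filters; each filter
+     # receives (InternalMatch, original_graph, pattern_graph) and returns bool.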
+     filter_fn: Optional[
+         list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
+     ] = None
    ignore_literals: bool = False


@@ -739,6 +994,31 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
        127,
    )

+     _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1 = (
+         torch.randn((1, 32), dtype=torch.float),  # x_fp32
+         torch.finfo(torch.float32).eps,  # x_eps
+         torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+         torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+         torch.zeros(
+             8, 4, dtype=torch.int
+         ),  # weight_zero_point [out_features, num_groups]
+         torch.randn(8, dtype=torch.float),  # bias_fp32
+         8,  # group_size
+     )
+
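+     # weight_i4 is (out_features=8, in_features=32); with group_size 8 that gives
+     # 32 / 8 = 4 groups, hence the (8, 4) weight_scale / weight_zero_point shapes.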
+     # A second example with a 3D activation so the pattern also matches > 2 dim inputs. Hacky.
+     _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 = (
+         torch.randn((1, 1, 32), dtype=torch.float),  # x_fp32
+         torch.finfo(torch.float32).eps,  # x_eps
+         torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+         torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+         torch.zeros(
+             8, 4, dtype=torch.int
+         ),  # weight_zero_point [out_features, num_groups]
+         torch.randn(8, dtype=torch.float),  # bias_fp32
+         8,  # group_size
+     )
+
    _REWRITE_INFO_LIST = [
        _RewriteInfo(
            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,

@@ -753,6 +1033,48 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
            ),
        ),
+         _RewriteInfo(
+             _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1,
+             WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+             WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
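+             # In the example inputs, placeholder 1 is x_eps and placeholder 6 is
+             # group_size; the eps and (1, 8) block-size literals baked into the traced
+             # pattern and replacement graphs are remapped onto those placeholders.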
+ partial (
1041
+ _replace_literals_with_existing_placeholders ,
1042
+ literal_to_ph_idx = {
1043
+ torch .finfo (torch .float32 ).eps : 1 ,
1044
+ (1 , 8 ): 6 ,
1045
+ },
1046
+ ),
1047
+ partial (
1048
+ _replace_literals_with_existing_placeholders ,
1049
+ literal_to_ph_idx = {
1050
+ torch .finfo (torch .float32 ).eps : 1 ,
1051
+ (1 , 8 ): 6 ,
1052
+ },
1053
+ ),
1054
+ filter_fn = [_filter_fn_for_dynamic_quantized_linear_4bit_groupwise ],
1055
+ ignore_literals = True ,
1056
+ ),
1057
+ _RewriteInfo (
1058
+ _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 ,
1059
+ WrapperModule (_qdq_dynamic_quantized_linear_4bit_groupwise ),
1060
+ WrapperModule (_reference_dynamic_quantized_linear_4bit_groupwise ),
1061
+ partial (
1062
+ _replace_literals_with_existing_placeholders ,
1063
+ literal_to_ph_idx = {
1064
+ torch .finfo (torch .float32 ).eps : 1 ,
1065
+ (1 , 8 ): 6 ,
1066
+ },
1067
+ ),
1068
+ partial (
1069
+ _replace_literals_with_existing_placeholders ,
1070
+ literal_to_ph_idx = {
1071
+ torch .finfo (torch .float32 ).eps : 1 ,
1072
+ (1 , 8 ): 6 ,
1073
+ },
1074
+ ),
1075
+ filter_fn = [_filter_fn_for_dynamic_quantized_linear_4bit_groupwise ],
1076
+ ignore_literals = True ,
1077
+ ),
756
1078
_RewriteInfo (
757
1079
_QUANTIZED_LINEAR_EXAMPLE_INPUTS ,
758
1080
WrapperModule (_qdq_quantized_linear ),

@@ -835,7 +1157,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
            model,
            pattern,
            replacement,
-             match_filters=None,
+             match_filters=rewrite_info.filter_fn,
            ignore_literals=rewrite_info.ignore_literals,
        )  # type: ignore[arg-type]