@@ -8,7 +8,7 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -24,14 +24,22 @@
     remove_tensor_overload_for_qdq_ops,
 )
 
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
+
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
+    "_qdq_dynamic_quantized_linear_4bit_groupwise",
+    "_reference_dynamic_quantized_linear_4bit_groupwise",
 ]
 
 
@@ -203,6 +211,221 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_col_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -739,6 +962,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +988,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
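
For reference, a minimal sketch of exercising the two new patterns side by side. The import path below is an assumption inferred from the symbols in this diff (the file path is not shown), and the shapes mirror _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS above:

# Sketch only: module path is assumed, not confirmed by the diff.
import torch
from torchao.quantization.pt2e.reference_representation_rewrite import (
    _qdq_dynamic_quantized_linear_4bit_groupwise,
    _reference_dynamic_quantized_linear_4bit_groupwise,
)

x_fp32 = torch.randn(2, 32)                                  # activations
x_eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)  # 4-bit values in [-8, 7] stored as int8
weight_scale = torch.rand(8, 4) + 0.1                        # [out_features, num_groups], positive
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)       # unused (weight assumed symmetric)
bias_fp32 = torch.randn(8)
group_size = 8                                               # 32 in_features / 4 groups

# Decomposed q/dq pattern (what the exported graph contains before rewrite).
qdq_out = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
# Integer-arithmetic reference that the rewrite substitutes in.
ref_out = _reference_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
# The reference path emulates on-device integer accumulation and rounds scales
# through bf16, so expect the two results to be close but not bit-identical.
print((qdq_out - ref_out).abs().max())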