import contextlib
from dataclasses import dataclass
from functools import partial
- from typing import Any, Callable, Optional
+ from typing import Any, Callable, List, Optional

import torch
from torch._higher_order_ops.out_dtype import out_dtype
@@ -23,12 +23,17 @@
    _replace_literals_with_new_placeholders,
    remove_tensor_overload_for_qdq_ops,
)
+ from torchao.quantization.quant_primitives import MappingType
+ from torchao.quantization.utils import _get_per_token_block_size
+ from torchao.utils import _register_custom_op

try:
    from torch._export.utils import _disable_aten_to_metadata_assertions
except:
    _disable_aten_to_metadata_assertions = contextlib.nullcontext

+ quant_lib = torch.library.Library("torchao", "FRAGMENT")
+ register_custom_op = _register_custom_op(quant_lib)

__all__ = [
    "reference_representation_rewrite",
@@ -203,6 +208,221 @@ def _reference_dynamic_quantized_linear(
    return out_fp32


+ def _qdq_dynamic_quantized_linear_4bit_groupwise(
+     x_fp32,
+     x_eps,
+     weight_i4,
+     weight_scale,
+     weight_zero_point,
+     bias_fp32,
+     group_size,
+ ):
+     # Dynamic quantization of activation
+     x_mapping_type = MappingType.ASYMMETRIC
+     per_token_block_size = _get_per_token_block_size(x_fp32)
+     x_quant_min = -128
+     x_quant_max = 127
+     x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+         x_fp32,
+         x_mapping_type.name,
+         per_token_block_size,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         x_eps,
+         torch.float32,
+         torch.int32,
+     )
+     x_i8 = torch.ops.torchao.quantize_affine(
+         x_fp32,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+     )
+     x_fp32 = torch.ops.torchao.dequantize_affine(
+         x_i8,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         torch.float32,
+     )
+
+     assert group_size > 0, "Group size must be positive"
+     assert (
+         weight_i4.shape[1] % group_size == 0
+     ), "Weight in_features must be divisible by group_size"
+     assert weight_i4.dim() == 2, "Weight must be a 2D tensor"
+     block_size = (1, group_size)
+     weight_fp32 = torch.ops.torchao.dequantize_affine(
+         weight_i4,
+         block_size,
+         weight_scale,
+         weight_zero_point,
+         torch.int8,
+         -8,
+         7,
+     )
+
+     out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+     return out_fp32
+
+
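A minimal sketch of how the qdq pattern above could be sanity-checked against a plain fp32 linear. It assumes torchao is imported so the torchao.* affine ops are registered; the shapes, the scale values, and the repeat_interleave baseline are illustrative assumptions, not part of this change:

import torch

x = torch.randn(2, 32)                                   # [batch, in_features]
w_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)   # int4 values stored in int8
w_scale = torch.rand(8, 4) * 0.1 + 0.01                  # [out_features, in_features // group_size]
w_zp = torch.zeros(8, 4, dtype=torch.int32)              # symmetric weight quantization
bias = torch.randn(8)

out = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x, torch.finfo(torch.float32).eps, w_i4, w_scale, w_zp, bias, 8
)
# fp32 baseline: expand each group scale across its group and run a regular linear;
# the difference is bounded by the int8 activation quantization error.
w_fp32 = w_i4.float() * w_scale.repeat_interleave(8, dim=1)
baseline = torch.nn.functional.linear(x, w_fp32, bias)
print((out - baseline).abs().max())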
+ @register_custom_op
+ def _reference_dqlinear_int4(
+     x_fp32: torch.Tensor,
+     x_eps: float,
+     weight_i4: torch.Tensor,
+     weight_scale: torch.Tensor,
+     weight_zero_point: torch.Tensor,  # Not used because weight is assumed to be symmetric
+     bias_fp32: Optional[torch.Tensor],
+     group_size: List[int],
+ ) -> torch.Tensor:
+     """
+     Reference implementation for the dynamically quantized linear 4-bit groupwise operation.
+     This implementation emulates the actual numerics of on-device integer compute.
+
+     Args:
+         x_fp32: Input activation tensor in fp32
+         x_eps: Epsilon for quantization parameter computation
+         weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+         weight_scale: Groupwise scales for weight dequantization
+         weight_zero_point: Groupwise zero points for weight (unused for symmetric quantization)
+         bias_fp32: Optional bias tensor in fp32
+         group_size: Size of each group for groupwise quantization
+
+     Returns:
+         Output tensor in fp32
+     """
+     # Dynamic quantization of activation
+     group_size = group_size[1]
+     x_mapping_type = MappingType.ASYMMETRIC
+     per_token_block_size = _get_per_token_block_size(x_fp32)
+     x_quant_min = -128
+     x_quant_max = 127
+     x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+         x_fp32,
+         x_mapping_type.name,
+         per_token_block_size,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+         x_eps,
+         torch.float32,
+         torch.int32,
+     )
+     x_i8 = torch.ops.torchao.quantize_affine(
+         x_fp32,
+         per_token_block_size,
+         x_scale,
+         x_zero_point,
+         torch.int8,
+         x_quant_min,
+         x_quant_max,
+     )
+
+     # For groupwise quantization, the computation is handled group by group.
+     # weight_i4 shape: [out_features, in_features]
+     # weight_scale shape: [out_features, in_features // group_size]
+     # weight_zero_point shape: [out_features, in_features // group_size]
+     out_features, in_features = weight_i4.shape
+     num_groups = in_features // group_size
+
+     # Scales in XNNPACK are stored as bf16 and converted to fp32 for computation
+     weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+     assert x_i8.dim() == 2, "x_i8 must be a 2D tensor"
+     # Reshape for group-wise processing
+     # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+     batch_size = x_i8.shape[0]
+     x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+     # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+     weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+     # Convert to int32 for computation
+     x_i32_grouped = x_i8_grouped.to(torch.int32)
+     weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+     # Perform the groupwise integer linear operation
+     acc_fp32 = torch.zeros(
+         batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+     )
+
+     if weight_scale.ndim == 1:
+         weight_scale = weight_scale.unsqueeze(0)
+
+     for group_idx in range(num_groups):
+         # Extract the current group
+         x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+         weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+         weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+         # Get the scale for this group
+         weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+         # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+         group_acc = out_dtype(
+             torch.ops.aten.linear.default,
+             torch.int32,
+             x_group,
+             weight_group,
+             None,
+         )
+
+         # The output has to be scaled by x_scale * weight_scale_group.
+         # We first scale by weight_scale_group, accounting only for the weight scale,
+         # and scale by x_scale at the end because x_scale applies to all groups.
+         acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+             1, -1
+         )
+
+         # We must also subtract x_zero_point * weight_group_col_sum,
+         # since (X - x_zero_point) * W = X * W - x_zero_point * W
+         weights_col_sum_adjusted = (
+             weight_group_col_sum.to(torch.float32).view(1, -1)
+             * x_zero_point.view(-1, 1)
+             * weight_scale_group.view(1, -1)
+         )
+         acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+     x_scale_multiplier = x_scale.view(-1, 1)
+     out_fp32 = acc_fp32 * x_scale_multiplier
+     if bias_fp32 is not None:
+         out_fp32 = out_fp32 + bias_fp32
+
+     return out_fp32
+
+
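The column-sum adjustment in the loop above relies on the identity (X - zp) @ W.T = X @ W.T - zp * sum(W, dim=-1). A minimal sketch checking it for a single group with plain torch (the shapes and value ranges are illustrative assumptions):

import torch

x_i32 = torch.randint(-128, 128, (2, 8), dtype=torch.int32)    # one activation group [batch, group_size]
w_i32 = torch.randint(-8, 8, (4, 8), dtype=torch.int32)        # one weight group [out_features, group_size]
x_zp = torch.randint(-10, 10, (2, 1), dtype=torch.int32)       # per-token zero point

lhs = (x_i32 - x_zp) @ w_i32.t()                                # product in zero-point-shifted form
rhs = x_i32 @ w_i32.t() - x_zp * w_i32.sum(dim=-1).view(1, -1)  # integer matmul plus col-sum correction
assert torch.equal(lhs, rhs)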
+ def _reference_dynamic_quantized_linear_4bit_groupwise(
+     x_fp32,
+     x_eps,
+     weight_i4,
+     weight_scale,
+     weight_zero_point,  # Not used because weight is assumed to be symmetric
+     bias_fp32,
+     group_size,
+ ):
+     """
+     Reference implementation for the dynamically quantized linear 4-bit groupwise operation.
+     This function now delegates to the custom op implementation.
+     """
+     return torch.ops.torchao.reference_dqlinear_int4(
+         x_fp32,
+         x_eps,
+         weight_i4,
+         weight_scale,
+         weight_zero_point,
+         bias_fp32,
+         (1, group_size),
+     )
+
+
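A minimal sketch comparing the qdq pattern and the reference numerics on random data. It assumes importing this module has registered the torchao::reference_dqlinear_int4 custom op; the shapes and scale values are illustrative assumptions:

import torch

x = torch.randn(2, 32)
eps = torch.finfo(torch.float32).eps
w_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)
w_scale = torch.rand(8, 4) * 0.1 + 0.01
w_zp = torch.zeros(8, 4, dtype=torch.int32)
bias = torch.randn(8)

qdq_out = _qdq_dynamic_quantized_linear_4bit_groupwise(x, eps, w_i4, w_scale, w_zp, bias, 8)
ref_out = _reference_dynamic_quantized_linear_4bit_groupwise(x, eps, w_i4, w_scale, w_zp, bias, 8)
# The reference emulates integer accumulation and rounds weight scales through bf16
# (as XNNPACK does), so a small mismatch against the qdq form is expected.
print((qdq_out - ref_out).abs().max())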
def _qdq_quantized_conv2d(
    x_i8,
    x_scale,
@@ -739,6 +959,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
        127,
    )

+     _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+         torch.randn((2, 32), dtype=torch.float),  # x_fp32
+         torch.finfo(torch.float32).eps,  # x_eps
+         torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+         torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+         torch.zeros(
+             8, 4, dtype=torch.int
+         ),  # weight_zero_point [out_features, num_groups]
+         torch.randn(8, dtype=torch.float),  # bias_fp32
+         8,  # group_size
+     )
+
    _REWRITE_INFO_LIST = [
        _RewriteInfo(
            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +985,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
            ),
        ),
+         _RewriteInfo(
+             _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+             WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+             WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+             partial(
+                 _replace_literals_with_existing_placeholders,
+                 literal_to_ph_idx={
+                     torch.finfo(torch.float32).eps: 1,
+                     (1, 8): 6,
+                 },
+             ),
+             partial(
+                 _replace_literals_with_existing_placeholders,
+                 literal_to_ph_idx={
+                     torch.finfo(torch.float32).eps: 1,
+                     (1, 8): 6,
+                 },
+             ),
+             ignore_literals=True,
+         ),
        _RewriteInfo(
            _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
            WrapperModule(_qdq_quantized_linear),