@@ -1,15 +1,20 @@
+from dataclasses import dataclass
+
 import torch
 
-from torchao.quantization import int4_weight_only, int8_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 
-def intN_weight_only(group_size=32, n=8, symmetric=False):
+@dataclass
+class IntNWeightOnlyConfig(AOBaseConfig):
     """
-    Apply int N-bit weight only quantization to a linear layer.
+    Configuration for applying int N-bit weight only quantization to a linear layer.
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
@@ -18,6 +23,25 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
     """
 
+    group_size: int = 32
+    n: int = 8
+    symmetric: bool = False
+
+
+# for bc
+intN_weight_only = IntNWeightOnlyConfig
+
+
+@register_quantize_module_handler(IntNWeightOnlyConfig)
+def _intN_weight_only_transform(
+    module: torch.nn.Module,
+    config: IntNWeightOnlyConfig,
+) -> torch.nn.Module:
+    group_size = config.group_size
+    n = config.n
+    symmetric = config.symmetric
+    weight = module.weight
+
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
@@ -64,16 +88,19 @@ def apply_intN_weight_only_quant_sym(weight):
             zero_point_dtype=zero_point_dtype,
         )
 
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
+    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+    if n == 8:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int8_weight_only again"
+        )
+    elif n == 4:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int4_weight_only again"
+        )
+    else:
+        if symmetric:
+            new_weight = apply_intN_weight_only_quant_sym(weight)
         else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception:
-        raise
+            new_weight = apply_intN_weight_only_quant_asym(weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    return module
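For context, a minimal usage sketch of the new config-based flow. It assumes this module is importable (the import path below is illustrative) and uses the `quantize_` entry point from `torchao.quantization`; the toy model and the 6-bit/group-size choices are arbitrary examples, not part of the diff:

import torch
from torchao.quantization import quantize_

from naive_intNwo import IntNWeightOnlyConfig  # hypothetical import path

# A toy model; its nn.Linear layers are the quantization targets.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# quantize_ looks up the handler registered for IntNWeightOnlyConfig via
# @register_quantize_module_handler and applies _intN_weight_only_transform
# to each linear module, swapping module.weight for its quantized form.
quantize_(model, IntNWeightOnlyConfig(n=6, group_size=32, symmetric=True))

With the bc alias, existing call sites that pass `intN_weight_only(n=..., group_size=...)` keep working: they now construct a config object that `quantize_` dispatches on, rather than a function that wraps the weight directly.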