 * then we apply equalization scale to linear activation with to_weight_tensor_with_linear_activation_scale_metadata (input activation will be divided by equalization_scale), and then call F.linear with
   scaled input activation and quantized weight (so we can reuse the efficient quantized linear kernels used by quantized weight)
 """
-import torch
+
 import copy

+import torch
 import torch.nn.functional as F
 from torch import Tensor
+
 from torchao.dtypes import (
-    to_affine_quantized_intx_static,
-    to_affine_quantized_floatx_static,
     Float8Layout,
+    to_affine_quantized_floatx_static,
+    to_affine_quantized_intx_static,
 )
-from torchao.quantization.utils import compute_error
-from torchao.quantization import quantize_
-from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.observer import (
-    AffineQuantizedMinMaxObserver,
+from torchao.quantization import (
+    quantize_,
+    to_weight_tensor_with_linear_activation_scale_metadata,
 )
 from torchao.quantization.granularity import (
     PerAxis,
     PerTensor,
 )
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.quant_primitives import (
     MappingType,
-    FP8_TYPES,
 )
+from torchao.quantization.utils import compute_error


 class ObservedLinear(torch.nn.Linear):
-    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        act_obs: torch.nn.Module,
+        weight_obs: torch.nn.Module,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
         super().__init__(in_features, out_features, bias, device, dtype)
         self.act_obs = act_obs
         self.weight_obs = weight_obs
@@ -46,11 +58,20 @@ def forward(self, input: Tensor):

     @classmethod
     def from_float(cls, float_linear, act_obs, weight_obs):
-        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
+        observed_linear = cls(
+            float_linear.in_features,
+            float_linear.out_features,
+            act_obs,
+            weight_obs,
+            False,
+            device=float_linear.weight.device,
+            dtype=float_linear.weight.dtype,
+        )
         observed_linear.weight = float_linear.weight
         observed_linear.bias = float_linear.bias
         return observed_linear

+
 def insert_observers_(model, act_obs, weight_obs):
     _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

@@ -61,22 +82,39 @@ def replacement_fn(m):

     _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

+
 # converting observed linear module to linear module with quantized weights (and quantized activations)
 # with tensor subclasses
 def apply_awq(target_dtype: torch.dtype):
     # target_dtype = torch.uint8
     def _apply_awq_to_linear(observed_linear):
         # weight quantization
         weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+
         def weight_quant_func(weight):
             block_size = (1, weight.shape[1])
             if target_dtype == torch.uint8:
-                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+                return to_affine_quantized_intx_static(
+                    weight, weight_scale, weight_zero_point, block_size, target_dtype
+                )
             elif target_dtype == torch.float8_e4m3fn:
-                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None))
+                return to_affine_quantized_floatx_static(
+                    weight,
+                    weight_scale,
+                    block_size,
+                    target_dtype,
+                    Float8Layout(mm_config=None),
+                )
             else:
                 raise ValueError(f"Unsupported target dtype {target_dtype}")
-        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            False,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
         linear.weight = observed_linear.weight
         linear.bias = observed_linear.bias
@@ -86,16 +124,22 @@ def weight_quant_func(weight):
         equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
         equalization_scale = torch.ones_like(equalization_scale)

-        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight * equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            weight_quant_func(linear.weight * equalization_scale), requires_grad=False
+        )

-        linear.weight = torch.nn.Parameter(to_weight_tensor_with_linear_activation_scale_metadata(linear.weight, equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            to_weight_tensor_with_linear_activation_scale_metadata(
+                linear.weight, equalization_scale
+            ),
+            requires_grad=False,
+        )

         return linear

     return _apply_awq_to_linear


-
 ######## Test ##########
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=64, n=32, k=64):
@@ -104,7 +148,11 @@ def __init__(self, m=64, n=32, k=64):
         self.linear2 = torch.nn.Linear(k, n, bias=False)

     def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+        return (
+            torch.randn(
+                batch_size, self.linear1.in_features, dtype=dtype, device=device
+            ),
+        )

     def forward(self, x):
         x = self.linear1(x)
@@ -119,16 +167,24 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
     dtype = torch.bfloat16
     m = ToyLinearModel().eval().to(dtype).to("cuda")

-    m_for_test = copy.deepcopy(m)
-
     m_bf16 = copy.deepcopy(m)
     example_inputs = m.example_inputs(dtype=dtype, device="cuda")
     print("example inputs shape:", example_inputs[0].shape)

-    m_bf16 = torch.compile(m_bf16, mode='max-autotune')
-
-    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps)
-    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune")
+
+    act_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerTensor(),
+        eps=torch.finfo(torch.float32).eps,
+    )
+    weight_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerAxis(axis=0),
+        eps=torch.finfo(torch.float32).eps,
+    )

     before_quant = m(*example_inputs)

@@ -137,9 +193,9 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
     for _ in range(10):
         m(*example_inputs)

-    after_obs = m(*example_inputs)
+    m(*example_inputs)

-    m2 = copy.deepcopy(m)
+    copy.deepcopy(m)

     is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
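
To make the mechanism described in the module docstring concrete: after conversion, the linear weight is a quantized tensor wrapped with the equalization scale, and at dispatch time the input activation is divided by that scale before the efficient quantized kernel runs. Below is a minimal sketch of just that arithmetic; linear_with_scale_metadata is an illustrative name, not torchao's actual dispatch code.

import torch
import torch.nn.functional as F


def linear_with_scale_metadata(x, scaled_weight, equalization_scale, bias=None):
    # the input activation is divided by equalization_scale; because the weight
    # was multiplied by the same per-input-feature scale before quantization,
    # (x / s) @ (w * s).T == x @ w.T up to quantization error
    scaled_x = x / equalization_scale
    # F.linear on the already-quantized weight reuses the efficient quantized
    # kernels that the weight tensor subclass dispatches to
    return F.linear(scaled_x, scaled_weight, bias)


# plain-float check that the scales cancel:
x, w = torch.randn(2, 8), torch.randn(4, 8)
s = torch.rand(8) + 0.5  # per-input-feature equalization scale
assert torch.allclose(linear_with_scale_metadata(x, w * s, s), F.linear(x, w), atol=1e-5)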
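
The diff is cut off right after is_observed_linear is defined, so the tail of test_awq is not visible here. Based on the helpers above, the remainder plausibly performs the conversion and an accuracy check along the lines of the sketch below; the exact elided code is an assumption, not a quote.

# assumed continuation of test_awq (not shown in this diff):
# convert each ObservedLinear into an nn.Linear whose weight is a quantized
# tensor subclass carrying the equalization-scale metadata
quantize_(m, apply_awq(target_dtype), is_observed_linear)
after_quant = m(*example_inputs)
# compute_error (imported above) reports the SQNR between the two runs
print("sqnr:", compute_error(before_quant, after_quant))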