
Commit d682cb5

[WIP] Make AWQ more general
Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: TODO
1 parent: e4f2715
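For context, here is a minimal sketch of the calibrate-then-quantize flow this change enables. The names insert_awq_observer_, AWQConfig, AWQObservedLinear, and the base_config parameter come from this diff; quantize_ and Int8DynamicActivationIntxWeightConfig are existing torchao APIs; the toy model, shapes, and calibration data are illustrative assumptions, not part of the commit.

    # Hypothetical usage sketch (toy model and data; not taken verbatim from the commit)
    import torch
    from torchao.prototype.awq import AWQConfig, AWQObservedLinear, insert_awq_observer_
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_

    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).eval()
    base_config = Int8DynamicActivationIntxWeightConfig()  # the wrapped base config

    # 1) Replace Linear layers with AWQObservedLinear so activations get recorded.
    insert_awq_observer_(
        model,
        n_validation_examples=10,
        validation_sequence_len=64,
        base_config=base_config,
    )

    # 2) Calibrate on exemplar data to populate the observers.
    for _ in range(10):
        model(torch.randn(1, 64, 128))

    # 3) Compute equalization scales and quantize via the base config's handler.
    quantize_(
        model,
        AWQConfig(base_config),
        filter_fn=lambda m, fqn: isinstance(m, AWQObservedLinear),
    )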

File tree: 6 files changed, +578 −30 lines


torchao/prototype/awq/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear

 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]

torchao/prototype/awq/api.py

Lines changed: 77 additions & 16 deletions
@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule

 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )

 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,32 @@ def insert_awq_observer_(

     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)

     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
@@ -194,3 +208,50 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
+    """
+    Configuration for quantizing linear layers when passed into quantize_()
+
+    Args:
+        base_config: The base quantization config that is applied to the equalized weight
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
+    """
+
+    base_config: AOBaseConfig
+    set_inductor_config: bool = True
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    observed_linear = module
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    # quantize the equalized weight with the base config's registered handler
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
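Why multiplying the weight by equalization_scale and then wrapping the quantized result with to_weight_tensor_with_linear_activation_scale_metadata is sound: the metadata makes the runtime divide activations by the same scale, and in full precision the two scalings cancel exactly, so only the quantization of weight * scale changes. A small float-only sketch of that identity (plain tensors, no torchao types; all names illustrative):

    import torch

    x = torch.randn(4, 8)     # activations
    w = torch.randn(16, 8)    # weight: out_features x in_features
    s = torch.rand(8) + 0.5   # per-input-channel equalization scale

    y = x @ w.t()                  # original linear
    y_eq = (x / s) @ (w * s).t()   # AWQ-equalized linear

    # Equivalent in full precision; the accuracy win comes from quantizing
    # (w * s) instead of w, which protects salient input channels.
    assert torch.allclose(y, y_eq, atol=1e-4)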

torchao/prototype/awq/core.py

Lines changed: 134 additions & 0 deletions
@@ -8,8 +8,10 @@
 import torch
 import torch.nn.functional as F

+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx
 from torchao.dtypes.uintx.uintx_layout import UintxLayout
+from torchao.quantization import Int8DynamicActivationIntxWeightConfig
 from torchao.quantization.granularity import Granularity
 from torchao.quantization.observer import (
     AffineQuantizedObserverBase,
@@ -18,6 +20,10 @@
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+)
+from torchao.utils import DummyModule


 class AWQObserver(AffineQuantizedObserverBase):
@@ -145,6 +151,134 @@ def calculate_qparams(self):
         return best_scales.detach()


+class AWQObserver2(AffineQuantizedObserverBase):
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        base_config: AOBaseConfig,
+        n_validation_examples: int,
+        validation_sequence_len: int,
+        scale_search_space_size: int = 20,
+    ):
+        """
+        A custom observer for Activation-aware Weight Quantization (AWQ)
+
+        Args:
+            weight: The weight tensor to be observed.
+            bias: The bias tensor to be observed.
+            base_config: The base quantization config whose handler quantizes the scaled weight during the scale search
+            n_validation_examples: Number of examples used to calibrate the observer
+            validation_sequence_len: Number of tokens in each example
+            scale_search_space_size: The number of scales to search over.
+        """
+        self.base_config = base_config
+        quant_min = getattr(base_config, "quant_min", None)
+        quant_max = getattr(base_config, "quant_max", None)
+
+        # TODO: generalize; for now we only know how to pull granularity/dtype/
+        # mapping type out of Int8DynamicActivationIntxWeightConfig
+        assert isinstance(base_config, Int8DynamicActivationIntxWeightConfig)
+        quantization_granularity = base_config.weight_granularity
+        target_dtype = base_config.weight_dtype
+        mapping_type = base_config.weight_mapping_type
+
+        super().__init__(
+            mapping_type,
+            target_dtype,
+            quantization_granularity,
+            quant_min=quant_min,
+            quant_max=quant_max,
+        )
+        self.quantization_granularity = quantization_granularity
+        self.weight = weight
+        self.bias = bias
+        self.n_validation_examples = n_validation_examples
+        self.validation_sequence_len = validation_sequence_len
+        self.calibration_token_count = 0
+        self.inputs = []
+        self.outputs = []
+        self.scale_options = scale_search_space_size
+        self.device = self.weight.device
+        self.average = torch.zeros((1, weight.shape[1]), device=self.device)
+        if self.bias is not None:
+            self.bias = self.bias.to(self.device)
+
+    @torch.no_grad()
+    def forward(self, input: torch.Tensor, output: torch.Tensor):
+        if len(self.inputs) < self.n_validation_examples:
+            self.inputs.append(input.to("cpu"))
+            self.outputs.append(output.to("cpu"))
+        self.calibration_token_count += input.shape[-2]
+        self.average += input.abs().sum(-2)
+
+    def calculate_qparams(self):
+        assert len(self.outputs) > 0, (
+            "calibrate observer first by running model on exemplar data"
+        )
+        self.average /= self.calibration_token_count
+        for i in range(self.n_validation_examples):
+            self.inputs[i] = self.inputs[i].to(self.device)
+            self.outputs[i] = self.outputs[i].to(self.device)

+        best_loss = float("inf")
+        best_scales = None
+        for i in range(self.scale_options):
+            ratio = i / self.scale_options
+            scales = self.average.pow(ratio).to(self.weight.dtype)
+            scales = scales / (scales.max() * scales.min()).sqrt()
+            # quantize the scaled weight using the base config's handler
+            base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)]
+            dummy_mod = DummyModule(self.weight * scales)
+            quant_mod = base_config_handler(dummy_mod, self.base_config)
+            w = quant_mod.weight
+
+            loss = 0
+            for j in range(self.n_validation_examples):
+                q_out = F.linear(self.inputs[j] / scales, w, self.bias)
+                loss += (self.outputs[j] - q_out).pow(2).mean().item()
+            if loss < best_loss:
+                best_scales = scales
+                best_loss = loss
+        for i in range(self.n_validation_examples):
+            self.inputs[i] = self.inputs[i].to("cpu")
+            self.outputs[i] = self.outputs[i].to("cpu")
+        return best_scales.detach()
+
+
 class AWQObservedLinear(torch.nn.Linear):
     def __init__(
         self,
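The heart of AWQObserver2.calculate_qparams is a grid search over ratio = i / scale_search_space_size, raising the mean absolute activation to that power and keeping whichever scales minimize reconstruction error against the float outputs. A self-contained sketch of that loop, with a simple symmetric fake-quantizer standing in for the base-config handler (fake_quant and all data are illustrative stand-ins, not torchao code):

    import torch
    import torch.nn.functional as F

    def fake_quant(w: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
        # Stand-in for the base-config handler: symmetric per-row int quantization.
        qmax = 2 ** (n_bits - 1) - 1
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
        return (w / scale).round().clamp(-qmax - 1, qmax) * scale

    weight = torch.randn(16, 8)
    inputs = [torch.randn(4, 8) for _ in range(3)]    # calibration batches
    outputs = [F.linear(x, weight) for x in inputs]   # float reference outputs
    average = torch.stack([x.abs().mean(0) for x in inputs]).mean(0)

    best_loss, best_scales = float("inf"), None
    scale_options = 20
    for i in range(scale_options):
        ratio = i / scale_options
        scales = average.pow(ratio)
        scales = scales / (scales.max() * scales.min()).sqrt()  # normalize spread
        w_q = fake_quant(weight * scales)
        loss = sum(
            (ref - F.linear(x / scales, w_q)).pow(2).mean().item()
            for x, ref in zip(inputs, outputs)
        )
        if loss < best_loss:
            best_loss, best_scales = loss, scales

    print("best scales:", best_scales)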
