Commit 8b1fca1

[WIP] Make AWQ more general
Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent c561d26
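For orientation, here is a minimal sketch of how the two-step AWQConfig flow added in torchao/prototype/awq/api.py appears intended to be used from user code. The toy model, the calibration batches, and the Int8WeightOnlyConfig base config are illustrative assumptions (any AOBaseConfig with a registered handler should fit), and since this commit is marked WIP the handler details may still change:

import torch
from torchao.prototype.awq import AWQConfig, AWQObservedLinear
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# Toy stand-ins for a real model and calibration set.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(torch.bfloat16)
calibration_batches = [torch.randn(4, 128, dtype=torch.bfloat16) for _ in range(8)]

base_config = Int8WeightOnlyConfig()

# Step 1: replace nn.Linear layers with AWQObservedLinear and record activation stats.
quantize_(model, AWQConfig(base_config, step="calibrate"))
for batch in calibration_batches:
    model(batch)

# Step 2: compute equalization scales, then quantize with the wrapped base config.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, AWQConfig(base_config, step="convert"), is_observed_linear)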

File tree

10 files changed: +809 -41 lines changed
torchao/_models/_eval.py

Lines changed: 15 additions & 6 deletions
@@ -57,8 +57,13 @@ def _model_call(self, inps):
 
         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits
 
     def run_eval(self, tasks, limit):
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]
 
     @property
     def max_length(self):
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded
 
     def tok_encode(self, string: str, **kwargs):
@@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs):
         tokens = [self.tokenizer.bos_id] + tokens
         return tokens
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
+    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    #     super()._model_generate(context, max_length, stop, **generation_kwargs)
 
 
 class LMEvalInputRecorder(TransformerEvalWrapper):
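Context for the CausalLMOutputWithPast change above: Hugging Face causal LM modules return a structured ModelOutput rather than a raw logits tensor, which is why _model_call now unwraps it. A tiny illustration (the tensor shape is arbitrary):

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Hugging Face causal LMs return a structured output object, not a plain tensor.
output = CausalLMOutputWithPast(logits=torch.randn(1, 8, 32000))

logits = output
if isinstance(logits, CausalLMOutputWithPast):  # mirrors the new _model_call logic
    logits = logits.logits
print(logits.shape)  # torch.Size([1, 8, 32000])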

torchao/_models/llama/eval.py

Lines changed: 81 additions & 0 deletions
@@ -237,6 +237,87 @@ def run_evaluation(
             quantize_(
                 model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
             )
+        elif quantization.startswith("awq-uintx"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )
+
+        elif quantization.startswith("awq-8da4w"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

torchao/prototype/awq/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear
 
 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
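With this export in place, the new config can be imported from the prototype namespace alongside the existing entry points, e.g.:

from torchao.prototype.awq import (
    AWQConfig,
    AWQObservedLinear,
    awq_uintx,
    insert_awq_observer_,
)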

torchao/prototype/awq/api.py

Lines changed: 117 additions & 17 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 import torch
 
@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule
 
 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )
 
 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,32 @@
 
     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)
 
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
@@ -194,3 +208,89 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
+    """
+    Configuration for quantizing linear layers when passed into quantize_()
+
+    Args:
+        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
+        group_size: Quantization granularity. Use -1 for channel wise quantization
+        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
+    """
+
+    base_config: AOBaseConfig
+    step: str = "convert"
+    example_input_shape: Optional[List[int]] = None
+    scale_search_space_size: int = 20
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        OPTIONS = ["calibrate", "convert", "load"]
+        assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}"
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    step = config.step
+    scale_search_space_size = config.scale_search_space_size
+    observed_linear = None
+    base_config = config.base_config
+
+    if step == "calibrate":
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        return AWQObservedLinear.from_float(module, observer)
+    elif step == "load":
+        # loading from pre-quantized checkpoint
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        observed_linear = AWQObservedLinear.from_float(module, observer)
+        for _ in range(10):
+            example_input = torch.randn(
+                config.example_input_shape,
+                device=module.weight.device,
+                dtype=module.weight.dtype,
+            )
+            observed_linear(example_input)
+    else:
+        observed_linear = module
+
+    assert observed_linear is not None
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias != None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
+    linear.bias = observed_linear.bias
+    return linear
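The generalization at the heart of _awq_transform is delegation to the wrapped base config: the equalization-scaled weight is wrapped in a DummyModule and pushed through whatever handler is registered for that config, and the result is then tagged with the scale via to_weight_tensor_with_linear_activation_scale_metadata. A standalone sketch of that delegation pattern, assuming Int8WeightOnlyConfig as the base config and arbitrary shapes (the real handler additionally attaches the equalization scale and rebuilds the Linear module):

import torch
from torchao.quantization import Int8WeightOnlyConfig
from torchao.quantization.transform_module import _QUANTIZE_CONFIG_HANDLER
from torchao.utils import DummyModule

# An equalization-scaled weight, as produced by the AWQ observer (shapes illustrative).
weight = torch.randn(256, 256, dtype=torch.bfloat16)
equalization_scale = torch.ones(256, dtype=torch.bfloat16)

base_config = Int8WeightOnlyConfig()
handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]

# Quantize the bare weight by dressing it up as a module, mirroring the new handler.
quant_mod = handler(DummyModule(weight * equalization_scale), base_config)
print(type(quant_mod.weight))  # quantized tensor subclass chosen by the base config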
