Commit 4d7eeb7

[WIP] Make AWQ more general
Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 378e179 commit 4d7eeb7

File tree

10 files changed (+1069, -41 lines)

torchao/_models/_eval.py

Lines changed: 15 additions & 6 deletions
@@ -57,8 +57,13 @@ def _model_call(self, inps):
 
         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits
 
     def run_eval(self, tasks, limit):
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]
 
     @property
     def max_length(self):
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded
 
     def tok_encode(self, string: str, **kwargs):
@@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs):
            tokens = [self.tokenizer.bos_id] + tokens
         return tokens
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
+    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    #     super()._model_generate(context, max_length, stop, **generation_kwargs)
 
 
 class LMEvalInputRecorder(TransformerEvalWrapper):
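The net effect of this hunk is that the wrapper now tolerates models without `setup_caches` and Hugging Face causal LMs, whose forward pass returns a `ModelOutput` rather than a raw tensor. A minimal sketch of the case the new branch handles; the checkpoint name is a hypothetical example, not part of this diff:

import torch
from transformers import AutoModelForCausalLM  # assumes transformers is installed

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical HF causal LM
input_ids = torch.tensor([[1, 2, 3]])

out = hf_model(input_ids)   # returns CausalLMOutputWithPast, not a tensor
logits = out.logits         # the wrapper unwraps this to the (batch, seq_len, vocab) tensor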

torchao/_models/llama/eval.py

Lines changed: 81 additions & 0 deletions
@@ -237,6 +237,87 @@ def run_evaluation(
         quantize_(
             model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
         )
+    elif quantization.startswith("awq-uintx"):
+        from torchao._models._eval import TransformerEvalWrapper
+        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+        if not TORCH_VERSION_AT_LEAST_2_3:
+            print("Awq requires torch2.3+")
+            exit()
+        from torchao.prototype.awq import (
+            AWQObservedLinear,
+            awq_uintx,
+            insert_awq_observer_,
+        )
+
+        quant_dtype = quantization.split("-")[1]
+        group_size = int(quantization.split("-")[2])
+        quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+        model = model.to(device)
+        # get calibration data
+        insert_awq_observer_(
+            model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+        )
+        TransformerEvalWrapper(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            max_seq_length=256,
+            input_prep_func=prepare_inputs_for_model,
+            device=device,
+        ).run_eval(
+            tasks=["wikitext"],
+            limit=1,
+        )
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        use_hqq = "hqq" in quantization
+        quantize_(
+            model,
+            awq_uintx(
+                quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+            ),
+            is_observed_linear,
+        )
+
+    elif quantization.startswith("awq-8da4w"):
+        from torchao._models._eval import TransformerEvalWrapper
+        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+        if not TORCH_VERSION_AT_LEAST_2_3:
+            print("Awq requires torch2.3+")
+            exit()
+        from torchao.prototype.awq import (
+            AWQObservedLinear,
+            awq_uintx,
+            insert_awq_observer_,
+        )
+
+        quant_dtype = quantization.split("-")[1]
+        group_size = int(quantization.split("-")[2])
+        quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+        model = model.to(device)
+        # get calibration data
+        insert_awq_observer_(
+            model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+        )
+        TransformerEvalWrapper(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            max_seq_length=256,
+            input_prep_func=prepare_inputs_for_model,
+            device=device,
+        ).run_eval(
+            tasks=["wikitext"],
+            limit=1,
+        )
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        use_hqq = "hqq" in quantization
+        quantize_(
+            model,
+            awq_uintx(
+                quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+            ),
+            is_observed_linear,
+        )
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
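Stripped of the eval-harness plumbing, the flow this hunk adds is: attach observers, push calibration data through the model, then quantize only the observed linears. A minimal sketch under the same settings as the diff (1 validation example, sequence length 256); `model` and `calibration_batches` are hypothetical stand-ins:

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_

# 1. wrap every nn.Linear with an AWQ observer
insert_awq_observer_(model, 1, 256, quant_dtype=torch.uint4, group_size=64)

# 2. run calibration inputs so the observers can record activation statistics
for batch in calibration_batches:
    model(batch)

# 3. replace only the observed linears with AWQ-quantized weights
quantize_(
    model,
    awq_uintx(quant_dtype=torch.uint4, group_size=64),
    lambda m, fqn: isinstance(m, AWQObservedLinear),
)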

torchao/prototype/awq/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear
 
 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]

torchao/prototype/awq/api.py

Lines changed: 123 additions & 17 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 import torch
 
@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule
 
 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )
 
 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,30 @@ def insert_awq_observer_(
 
     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)
 
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
@@ -194,3 +206,97 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
+    """
+    Configuration for quantizing linear layers when passed into quantize_()
+
+    Args:
+        base_config (AOBaseConfig): The quantization config that we can apply AWQ on top of, e.g. 8da4w, int4 weight only
+        step (str): one of "prepare", "convert" or "load", indicating the step of the AWQ process
+            prepare: insert AWQ observers into linear layers
+            convert: convert the observed linear modules to linear modules with AWQ quantized weights
+            load: convert the floating point model to a dummy AWQ quantized model
+        example_input_shape (Optional[List[int]]): used in the load step to initialize a random example input
+        scale_search_space_size (int): the number of scales to search for
+        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values
+    """
+
+    base_config: AOBaseConfig
+    step: str
+    example_input_shape: Optional[List[int]] = None
+    scale_search_space_size: int = 20
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        OPTIONS = ["prepare", "convert", "load"]
+        assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}"
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    step = config.step
+    scale_search_space_size = config.scale_search_space_size
+    observed_linear = None
+    base_config = config.base_config
+
+    if step == "prepare":
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        return AWQObservedLinear.from_float(module, observer)
+    elif step == "load":
+        # loading from pre-quantized checkpoint
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        observed_linear = AWQObservedLinear.from_float(module, observer)
+        assert config.example_input_shape is not None, (
+            "When step is load, we expect example_input_shape to be specified as well"
+        )
+        example_input = torch.randn(
+            config.example_input_shape,
+            device=module.weight.device,
+            dtype=module.weight.dtype,
+        )
+        observed_linear(example_input)
+    else:
+        if not isinstance(module, AWQObservedLinear):
+            print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}")
+            return module
+        observed_linear = module
+
+    assert observed_linear is not None
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
+    linear.bias = observed_linear.bias
+    return linear
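Taken together with the docstring above, the intended flow for the new AWQConfig is prepare, calibrate, then convert, with "load" reserved for restoring pre-quantized checkpoints. A minimal usage sketch based on this diff; the base config (Int8DynamicActivationInt4WeightConfig), `model`, and `calibration_batches` are assumptions for illustration, not part of the commit:

import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.prototype.awq import AWQConfig

base = Int8DynamicActivationInt4WeightConfig(group_size=128)  # hypothetical 8da4w base config

# prepare: swap nn.Linear for AWQObservedLinear wrapping an AWQObserver2
quantize_(model, AWQConfig(base, step="prepare"))

# calibrate: forward passes let the observer search for equalization scales
for batch in calibration_batches:
    model(batch)

# convert: apply the base config to the equalized weight and attach the
# equalization scale as linear-activation scale metadata
quantize_(model, AWQConfig(base, step="convert"))

# load: rebuild the same quantized structure for a saved checkpoint, driving the
# observer once with a random example input of the given shape
# quantize_(model, AWQConfig(base, step="load", example_input_shape=[1, 512]))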
