[WIP] Make AWQ more general #2400

Open · wants to merge 1 commit into base: main
21 changes: 15 additions & 6 deletions torchao/_models/_eval.py
@@ -57,8 +57,13 @@ def _model_call(self, inps):

         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits

     def run_eval(self, tasks, limit):
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]

     @property
     def max_length(self):
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device

-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded

     def tok_encode(self, string: str, **kwargs):
@@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs):
             tokens = [self.tokenizer.bos_id] + tokens
         return tokens

-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
+    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    #     super()._model_generate(context, max_length, stop, **generation_kwargs)


 class LMEvalInputRecorder(TransformerEvalWrapper):
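Taken together, the `hasattr(self._model, "setup_caches")` guard and the `CausalLMOutputWithPast` unwrapping let the eval wrapper drive Hugging Face-style models, which have no `setup_caches` and whose forward returns a `ModelOutput` rather than a raw logits tensor. A small illustration of that return type, assuming `transformers` is installed; `facebook/opt-125m` is just an arbitrary small public checkpoint used for the example:

```python
# Illustration only: many HF causal LMs return a CausalLMOutputWithPast, not a bare tensor.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

out = model(**tok("hello world", return_tensors="pt"))
if isinstance(out, CausalLMOutputWithPast):  # the same check the wrapper now performs
    out = out.logits
print(out.shape)  # (batch, seq_len, vocab_size)
```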
81 changes: 81 additions & 0 deletions torchao/_models/llama/eval.py
@@ -237,6 +237,87 @@ def run_evaluation(
             quantize_(
                 model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
             )
+        elif quantization.startswith("awq-uintx"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )
+
+        elif quantization.startswith("awq-8da4w"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )

     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
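Both new branches drive the same observer-based flow: insert observers, calibrate by running the eval harness once over wikitext, then quantize only the observed linears. A condensed sketch of that flow outside the eval harness, mirroring the calls above; `model` and `calibration_batches` are placeholders, and the uint4 dtype and group size of 64 are example values:

```python
# Sketch only: the insert-observer -> calibrate -> quantize_ sequence used above.
import torch
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
from torchao.quantization import quantize_

# 1. Replace nn.Linear modules with AWQObservedLinear so activations get recorded
#    (1 validation example, calibration sequence length 256, as in the diff).
insert_awq_observer_(model, 1, 256, quant_dtype=torch.uint4, group_size=64)

# 2. Calibrate by running representative batches through the model.
for batch in calibration_batches:
    model(batch)

# 3. Quantize only the observed linears to AWQ uintx weights.
quantize_(
    model,
    awq_uintx(quant_dtype=torch.uint4, group_size=64),
    lambda m, fqn: isinstance(m, AWQObservedLinear),
)
```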
3 changes: 2 additions & 1 deletion torchao/prototype/awq/__init__.py
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear

 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
140 changes: 123 additions & 17 deletions torchao/prototype/awq/api.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional

 import torch

@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule

 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )

 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,30 @@

     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)

     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)

Review comment on the new AWQObserver2 branch:

Contributor: can you not add kwargs to the AWQObserver and just check 'base_config' in kwargs?
Contributor Author: yes, this is temporary; I think we can deprecate the old one in the end.
@@ -194,3 +206,97 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
+    """
+    Configuration for quantizing linear layers when passed into quantize_()
+
+    Args:
+        base_config (AOBaseConfig): The quantization config that we can apply awq on top of, e.g. 8da4w, int4 weight only
+        step (str): a string of "prepare", "convert" or "load" indicating the step of AWQ process
+            prepare: insert AWQ Observers to linear
+            convert: convert the observed linear modules to linear modules with awq quantized weights
+            load: convert the floating point model to a dummy awq quantized model
+        example_input_shape (Optional[List[int]]): This is used for load step to initialize a random example input
+        scale_search_space_size (int): the number of scales to search for
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
+    """
+
+    base_config: AOBaseConfig
+    step: str
+    example_input_shape: Optional[List[int]] = None
+    scale_search_space_size: int = 20
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        OPTIONS = ["prepare", "convert", "load"]
+        assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}"
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    step = config.step
+    scale_search_space_size = config.scale_search_space_size
+    observed_linear = None
+    base_config = config.base_config
+
+    if step == "prepare":
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        return AWQObservedLinear.from_float(module, observer)
+    elif step == "load":
+        # loading from pre-quantized checkpoint
+        observer = AWQObserver2(
+            module.weight,
+            module.bias,
+            base_config,
+            scale_search_space_size,
+        )
+        observed_linear = AWQObservedLinear.from_float(module, observer)
+        assert config.example_input_shape is not None, (
+            "When step is load, we expect example_input_shape to be specified as well"
+        )
+        example_input = torch.randn(
+            config.example_input_shape,
+            device=module.weight.device,
+            dtype=module.weight.dtype,
+        )
+        observed_linear(example_input)
+    else:
+        if not isinstance(module, AWQObservedLinear):
+            print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}")
+            return module
+        observed_linear = module
+
+    assert observed_linear is not None
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias != None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
+    linear.bias = observed_linear.bias
+    return linear

Review comments on this hunk:

On class AWQConfig, Contributor: Ok, this is consolidating with quantize_ api's config-based design?

On lines +287 to +288 (dummy_mod / quant_mod), Contributor: I am not sure what's happening here? Isn't module already nn.Module?
Contributor Author: this is just trying to quantize the weight with the quantization type specified by config.base_config.
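For reference, a minimal sketch of how the new config might be driven end to end through `quantize_()`. It assumes a base config such as torchao's `Int4WeightOnlyConfig`; `model` and `calibration_data` are placeholders, and the exact arguments may change while this PR is WIP:

```python
# Sketch only: prepare -> calibrate -> convert with the new AWQConfig.
from torchao.prototype.awq import AWQConfig, AWQObservedLinear
from torchao.quantization import Int4WeightOnlyConfig, quantize_

base_config = Int4WeightOnlyConfig(group_size=128)  # quantization applied on top of the AWQ scales

# "prepare": swap nn.Linear for AWQObservedLinear backed by AWQObserver2.
quantize_(model, AWQConfig(base_config, step="prepare"))

# Calibrate: run representative data so the observers can search for equalization scales.
for example in calibration_data:
    model(example)

# "convert": replace observed linears with linears holding AWQ-scaled, base-config-quantized weights.
quantize_(
    model,
    AWQConfig(base_config, step="convert"),
    lambda m, fqn: isinstance(m, AWQObservedLinear),
)

# "load" (restoring a pre-quantized checkpoint) would instead pass
# AWQConfig(base_config, step="load", example_input_shape=[2, 512]), where
# [2, 512] is a hypothetical [batch, in_features] example input shape.
```

Note that the convert handler itself skips modules that are not `AWQObservedLinear` (it prints and returns them unchanged), so the filter in the last call mainly keeps `quantize_` from visiting unrelated modules.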