diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index faf059c400..fca5b7f0af 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -57,8 +57,13 @@ def _model_call(self, inps): max_seq_length = min(max(inps.size()), self.max_length) with torch.device(self._device): - self._model.setup_caches(self.batch_size, max_seq_length) + if hasattr(self._model, "setup_caches"): + self._model.setup_caches(self.batch_size, max_seq_length) logits = self._model(*input) + from transformers.modeling_outputs import CausalLMOutputWithPast + + if isinstance(logits, CausalLMOutputWithPast): + logits = logits.logits return logits def run_eval(self, tasks, limit): @@ -84,7 +89,11 @@ def eot_token_id(self): try: return self.tokenizer.eos_id() except: - return self.tokenizer.eos_id + try: + return self.tokenizer.eos_id + except: + idx = self.tokenizer.all_special_tokens.index("<|endoftext|>") + return self.tokenizer.all_special_ids[idx] @property def max_length(self): @@ -102,8 +111,8 @@ def batch_size(self): def device(self): return self._device - def tok_decode(self, tokens): - decoded = self.tokenizer.decode(tokens) + def tok_decode(self, tokens, **kwargs): + decoded = self.tokenizer.decode(tokens, **kwargs) return decoded def tok_encode(self, string: str, **kwargs): @@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs): tokens = [self.tokenizer.bos_id] + tokens return tokens - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") + # def _model_generate(self, context, max_length, stop, **generation_kwargs): + # super()._model_generate(context, max_length, stop, **generation_kwargs) class LMEvalInputRecorder(TransformerEvalWrapper): diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 8ee15f1fd3..49e46c3d48 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -237,6 +237,87 @@ def run_evaluation( quantize_( model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) + elif quantization.startswith("awq-uintx"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) + + elif quantization.startswith("awq-8da4w"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + 
quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 570b0821d4..4f34d5375a 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,8 +1,9 @@ -from .api import awq_uintx, insert_awq_observer_ +from .api import AWQConfig, awq_uintx, insert_awq_observer_ from .core import AWQObservedLinear __all__ = [ "awq_uintx", "insert_awq_observer_", "AWQObservedLinear", + "AWQConfig", ] diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 5806c29ce6..51df030e58 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import types from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import torch @@ -30,12 +30,15 @@ ZeroPointDomain, ) from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.utils import DummyModule from .core import ( AWQObservedLinear, AWQObserver, + AWQObserver2, ) assert len(_DTYPE_TO_BIT_WIDTH) > 0, ( @@ -50,6 +53,7 @@ def insert_awq_observer_( quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128, + base_config: Optional[AOBaseConfig] = None, ): """ Inserts AWQObserver into Linear layers of a given model. 
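Reviewer note (not part of the patch): with the new base_config argument, insert_awq_observer_ builds an AWQObserver2 driven by an AOBaseConfig instead of the original AWQObserver. Below is a minimal sketch of that path, condensed from the awq-uint branch of example2.py added later in this diff; the model id, calibration text, and group size are illustrative, and step="convert" is passed explicitly since AWQConfig.step has no default.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.dtypes import QDQLayout
from torchao.prototype.awq import AWQConfig, AWQObservedLinear, insert_awq_observer_
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

model_id = "facebook/opt-125m"  # illustrative; any causal LM built from nn.Linear layers works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval()

base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(128),  # illustrative group size
    weight_scale_dtype=torch.bfloat16,
    layout=QDQLayout(),
)

# with base_config set, replace_with_observer builds AWQObserver2 instead of AWQObserver;
# the validation-example counts (1, 256) are required positionally but only used by the legacy path
insert_awq_observer_(model, 1, 256, base_config=base_config)

# calibration: each AWQObservedLinear records the inputs it sees
with torch.no_grad():
    ids = tokenizer("Representative calibration text goes here.", return_tensors="pt").input_ids
    model(ids)

# convert only the observed linears, reusing the same base_config for the weight quantization
quantize_(
    model,
    AWQConfig(base_config, step="convert"),
    lambda m, fqn: isinstance(m, AWQObservedLinear),
)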
@@ -80,22 +84,30 @@ def insert_awq_observer_( def replace_with_observer(layer): # creates observer and replaces linear layers with AWQObservedLinear layers - observer = AWQObserver( - layer.weight, - layer.bias, - quantization_granularity, - mapping_type, - quant_dtype, - n_validation_examples, - validation_sequence_len, - scale_search_space_size, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - zero_point_dtype=zero_point_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) + if base_config is None: + observer = AWQObserver( + layer.weight, + layer.bias, + quantization_granularity, + mapping_type, + quant_dtype, + n_validation_examples, + validation_sequence_len, + scale_search_space_size, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + zero_point_dtype=zero_point_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + else: + observer = AWQObserver2( + layer.weight, + layer.bias, + base_config, + scale_search_space_size, + ) return AWQObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) @@ -194,3 +206,97 @@ def _awq_uintx_transform( linear.extra_repr = types.MethodType(_linear_extra_repr, module) linear.bias = observed_linear.bias return linear + + +@dataclass +class AWQConfig(AOBaseConfig): + """ + Configuration for quantizing linear layers when passed into quantize_() + + Args: + base_config (AOBaseConfig): The quantization config that we can apply awq on top of, e.g. 8da4w, int4 weight only + step (str): a string of "prepare", "convert" or "load" indicating the step of AWQ process + prepare: insert AWQ Observers to linear + convert: convert the observed linear modules to linear modules with awq quantized weights + load: convert the floating point model to a dummy awq quantized model + example_input_shape (Optional[List[int]])): This is used for load step to initialize a random example input + scale_search_space_size (int): the number of scales to search for + set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. 
+ """ + + base_config: AOBaseConfig + step: str + example_input_shape: Optional[List[int]] = None + scale_search_space_size: int = 20 + set_inductor_config: bool = True + + def __post_init__(self): + OPTIONS = ["prepare", "convert", "load"] + assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}" + + +@register_quantize_module_handler(AWQConfig) +def _awq_transform( + module: torch.nn.Module, + config: AWQUIntXConfig, +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + step = config.step + scale_search_space_size = config.scale_search_space_size + observed_linear = None + base_config = config.base_config + + if step == "prepare": + observer = AWQObserver2( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + return AWQObservedLinear.from_float(module, observer) + elif step == "load": + # loading from pre-quantized checkpoint + observer = AWQObserver2( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + observed_linear = AWQObservedLinear.from_float(module, observer) + assert config.example_input_shape is not None, ( + "When step is load, we expect example_input_shape to be specified as well" + ) + example_input = torch.randn( + config.example_input_shape, + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) + else: + if not isinstance(module, AWQObservedLinear): + print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}") + return module + observed_linear = module + + assert observed_linear is not None + equalization_scale = observed_linear.act_obs.calculate_qparams() + + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(observed_linear.weight * equalization_scale) + quant_mod = base_config_handler(dummy_mod, config.base_config) + qw = quant_mod.weight + qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias != None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = torch.nn.Parameter(qw, requires_grad=False) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + linear.bias = observed_linear.bias + return linear diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index e5ee96fea2..61a21820e1 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -8,8 +8,10 @@ import torch import torch.nn.functional as F +from torchao.core.config import AOBaseConfig from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.uintx_layout import UintxLayout +from torchao.quantization import Int8DynamicActivationIntxWeightConfig from torchao.quantization.granularity import Granularity from torchao.quantization.observer import ( AffineQuantizedObserverBase, @@ -18,6 +20,15 @@ MappingType, ZeroPointDomain, ) +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, +) +from torchao.utils import DummyModule + + +@torch.no_grad() +def get_act_scale(x): + return x.abs().view(-1, x.shape[-1]).mean(0) class AWQObserver(AffineQuantizedObserverBase): @@ -145,6 +156,216 @@ def calculate_qparams(self): return best_scales.detach() +class AWQObserver2(torch.nn.Module): + def __init__( + self, + weight: torch.Tensor, + bias: torch.Tensor, + config: AOBaseConfig, + 
scale_search_space_size: int = 20, + ): + """ + A custom observer for Activation aware Weight Quantization (AWQ) + + Args: + weight: The weight tensor to be observed. + bias: The bias tensor to be observed. + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point + input_dtype: The data type of the input tensor. + mapping_type: Always set to asymmetric + target_dtype: The target data type of the quantized tensor + scale_search_space_size: The number of scales to search for. + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + preserve_zero: A flag to indicate whether we need zero to be exactly + representable or not. + zero_point_domain: The domain of the zero point. + """ + super().__init__() + self.config = config + self.weight = weight + self.bias = bias + # self.calibration_token_count = 0 + self.inputs = [] + # self.outputs = [] + self.scale_options = scale_search_space_size + self.device = self.weight.device + # self.acc = torch.zeros((1, weight.shape[1]), device=self.device) + if self.bias is not None: + self.bias.to(self.device) + + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + # import pdb + # pdb.set_trace() + # print(input.shape, input.abs().sum(1).shape, self.average.shape) + self.inputs.append(input.to("cpu")) + # self.outputs.append(output.to("cpu")) + # i = input.view(-1, input.shape[-1]) + # self.calibration_token_count += i.shape[0] + # self.acc += i.abs().sum(0) + + def calculate_qparams(self): + # import pdb + # pdb.set_trace() + assert self.inputs != None, ( + "calibrate observer first by running model on exemplar data" + ) + print("len inputs:", len(self.inputs)) + print("shape:", self.inputs[0].shape) + print("weight shape:", self.weight.shape) + + for i in range(len(self.inputs)): + self.inputs[i] = self.inputs[i].to(self.device) + # self.outputs[i] = self.outputs[i].to(self.device) + if self.bias is not None: + self.bias = self.bias.to(self.device) + + acc = torch.cat(self.inputs, dim=-2) + x_max = get_act_scale(acc) + + best_loss = float("inf") + best_scales = None + for i in range(self.scale_options): + ratio = i * 1 / self.scale_options + scales = x_max.pow(ratio).to(self.weight.dtype).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = config_handler(dummy_mod, self.config) + w = quant_mod.weight + + # for j in range(len(self.inputs)): + # q_out = F.linear(self.inputs[j] / scales, w, self.bias) + # loss += (self.outputs[j] - q_out).pow(2).mean().item() + orig_out = F.linear(acc, self.weight, self.bias) + q_out = F.linear(acc / scales, w, self.bias) + loss = (orig_out - q_out).pow(2).mean().item() + # print("scale:", scales, " loss:", loss, "scale shape:", scales.shape) + if loss < best_loss: + best_scales = scales + best_loss = loss + + # print("best_scale:", best_scales, " best loss:", best_loss) + + # for i in range(len(self.inputs)): + # self.inputs[i] = self.inputs[i].to("cpu") + # self.outputs[i] = self.outputs[i].to("cpu") + # if self.bias is not None: + # self.bias = self.bias.to("cpu") + + print("best scale shape:", best_scales.shape) + return best_scales.detach() + + +class AWQObserver3(AffineQuantizedObserverBase): + def __init__( + self, + weight: 
torch.Tensor, + bias: torch.Tensor, + config: AOBaseConfig, + scale_search_space_size: int = 20, + ): + """ + A custom observer for Activation aware Weight Quantization (AWQ) + + Args: + weight: The weight tensor to be observed. + bias: The bias tensor to be observed. + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point + input_dtype: The data type of the input tensor. + mapping_type: Always set to asymmetric + target_dtype: The target data type of the quantized tensor + scale_search_space_size: The number of scales to search for. + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + preserve_zero: A flag to indicate whether we need zero to be exactly + representable or not. + zero_point_domain: The domain of the zero point. + """ + self.config = config + quant_min = getattr(config, "quant_min", None) + quant_max = getattr(config, "quant_max", None) + + assert isinstance(config, Int8DynamicActivationIntxWeightConfig), ( + f"Got: {config}" + ) + # TODO: + quantization_granularity = config.weight_granularity + target_dtype = config.weight_dtype + mapping_type = config.weight_mapping_type + + # TODO: + super().__init__( + mapping_type, + target_dtype, + quantization_granularity, + quant_min=quant_min, + quant_max=quant_max, + ) + self.quantization_granularity = quantization_granularity + self.weight = weight + self.bias = bias + self.calibration_token_count = 0 + self.inputs = [] + self.outputs = [] + self.scale_options = scale_search_space_size + self.device = self.weight.device + self.acc = torch.zeros((1, weight.shape[1]), device=self.device) + if self.bias is not None: + self.bias.to(self.device) + + self.best_loss = float("inf") + self.best_scales = None + + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + # import pdb + # pdb.set_trace() + # print(input.shape, input.abs().sum(1).shape, self.average.shape) + # self.inputs.append(input.to("cpu")) + # self.outputs.append(output.to("cpu")) + # self.calibration_token_count += input.shape[-2] + # self.acc += input.abs().sum(-2) + # average = self.acc / self.calibration_token_count + + x_max = get_act_scale(input) + + # for i in range(len(self.inputs)): + # self.inputs[i] = self.inputs[i].to(self.device) + # self.outputs[i] = self.outputs[i].to(self.device) + + for i in range(self.scale_options): + ratio = i * 1 / self.scale_options + scales = x_max.pow(ratio).to(self.weight.dtype) + scales = scales / (scales.max() * scales.min()).sqrt() + config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = config_handler(dummy_mod, self.config) + w = quant_mod.weight + + loss = 0 + q_out = F.linear(input / scales, w, self.bias) + loss += (output - q_out).pow(2).mean().item() + if loss < self.best_loss: + self.best_scales = scales + self.best_loss = loss + + return output + + def calculate_qparams(self): + assert self.best_loss != None, ( + "calibrate observer first by running model on exemplar data" + ) + return self.best_scales.detach() + + class AWQObservedLinear(torch.nn.Linear): def __init__( self, diff --git a/torchao/prototype/awq/example2.py b/torchao/prototype/awq/example2.py new file mode 100644 index 0000000000..a0a8bc750c --- /dev/null +++ b/torchao/prototype/awq/example2.py @@ -0,0 +1,605 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import time + +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from torchao.dtypes import Int4XPULayout, QDQLayout +from torchao.prototype.awq import ( + AWQConfig, + AWQObservedLinear, + insert_awq_observer_, +) +from torchao.quantization import ( + GemliteUIntXWeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ModuleFqnToConfig, + PerAxis, + int4_weight_only, + quantize_, +) +from torchao.quantization.granularity import PerGroup + + +# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 +def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): + dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + samples = [] + n_tokens = n_samples * block_size + n_run = n_tokens + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run -= len(line_encoded) + if n_run <= n_samples: + break + + cat_samples = torch.cat(samples, dim=1) + return [ + cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples) + ] + + +# from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py +def wiki2_eval( + model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda" +): + model.eval() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") + + encodings["input_ids"] = encodings["input_ids"].to(device) + + lls, t = [], [] + for i in tqdm( + range(0, encodings["input_ids"].size(1), stride), disable=not verbose + ): + begin_loc = max(i + stride - sequence_length, 0) + end_loc = min(i + stride, encodings["input_ids"].size(1)) + trg_len = end_loc - i + input_ids = encodings["input_ids"][:, begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 # ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + if device.startswith("cuda"): + torch.cuda.synchronize() + if device.startswith("xpu"): + torch.xpu.synchronize() + t2 = time.time() + t.append((t2 - t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) + pred_time = sum(t) / len(t) + if verbose: + print("perplexity", ppl) + print("time", str(pred_time) + " sec") + + return {"perplexity": ppl, "prediction_time": pred_time} + + +# adapted from Hicham Badri (@mobicham) +def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): + import lm_eval + import numpy as np + + model.eval() + model.config.use_cache = False + try: + lm_eval.tasks.initialize_tasks() + except: + pass + model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) + eval_batch_size = 1 # 8 + if tasks is None: + tasks = [ + "PPL", + "truthfulqa_mc2", + "winogrande", + "arc_challenge", + "hellaswag", + "gsm8k", + "mmlu", + 
"bbh", + ] + results = {} + if "PPL" in tasks: + results["perplexity"] = wiki2_eval( + model, tokenizer, 512, verbose=True, device=device + ) + ############################################ + if "truthfulqa_mc2" in tasks: + for task in [("truthfulqa_mc2", 0)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + if "winogrande" in tasks: + for task in [("winogrande", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + if "arc_challenge" in tasks: + for task in [("arc_challenge", 25)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + + # ############################################ + if "hellaswag" in tasks: + for task in [("hellaswag", 10)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + if "gsm8k" in tasks: + for task in [("gsm8k", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + # ############################################ + + if "mmlu" in tasks: + # MMLU + results_mmlu = {} + for task in [("mmlu", 5)]: + tag, fewshot = task + results_mmlu[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results_mmlu[tag]) + + mmlu_list = 
"hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" + mmlu_list = [l.replace("hendrycksTest-", "") for l in mmlu_list.split(",")] + results_mmlu = results_mmlu["mmlu"] + + k = [] + for r in results_mmlu: + if np.any([(l in r) for l in mmlu_list]): + k.append(results_mmlu[r]["acc,none"]) + + assert len(k) == 57 + print("MMLU avg acc", np.mean(k)) + + results["mmlu"] = np.mean(k) + if "bbh" in tasks: + for task in [("leaderboard_bbh", 3)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + results["bbh"] = results[tag] + + return results + + +def wikitext2_ppl( + repo_id: str, + quant: str, + tasks: list[str], + calibration_limit: int, + validation_size: int, + device: str, + precision: torch.dtype, + sequence_length: int, + compile: bool, + model_save_path: str, + model_save_hf_hub_path: str, +): + print(f"Loading model on {device}...") + torch.manual_seed(34) + t0 = time.time() + # load any model with torch.nn.linear layers + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = ( + AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision) + .eval() + .to(device) + ) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + if quant.startswith("awq-uint"): + quant_dtype = quant.split("-")[1] + group_size = int(quant.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + print(f"running {quant_dtype} calibration") + t0 = time.time() + # insert observers to find 
average magnitude and calculate scales + insert_awq_observer_( + model, validation_size, sequence_length, base_config=base_config + ) + calibration_data = get_calib_dataset( + tokenizer=tokenizer, n_samples=calibration_limit, block_size=sequence_length + ) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + print(f"time for calibration: {time.time() - t0:.02f} seconds") + + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quant + print(f"running {quant_dtype} quantization") + t0 = time.time() + # awq_uintx_config = awq_uintx( + # quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + # ) + if "xpu" in device: + base_config.layout = Int4XPULayout() + awq_config = AWQConfig(base_config) + quantize_( + model, + awq_config, + is_observed_linear, + ) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + elif quant.startswith("awq-8da4w"): + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + + group_size = int(quant.split("-")[2]) + quant_dtype = torch.int4 + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + print(f"running {quant_dtype} prepare and calibrate") + t0 = time.time() + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = ModuleFqnToConfig( + {"_default": awq_config, "model.embed_tokens": embedding_config} + ) + quantize_( + model, + quant_config, + ) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=2048, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + + # calibration_seq_length = 2048 + # calibration_tasks = ["wikitext"] + # calibration_limit = 10 + # inputs = ( + # LMEvalInputRecorder( + # tokenizer, + # calibration_seq_length, + # vocab_size=model.config.vocab_size, + # device="cpu", + # ) + # .record_inputs( + # calibration_tasks, + # calibration_limit, + # ) + # .get_recorded_inputs() + # ) + + # for batch in inputs: + # model(batch.to(device)) + # batch.to("cpu") + + # calibration + # calibration_data = get_calib_dataset( + # tokenizer=tokenizer, n_samples=calibration_limit, block_size=sequence_length + # ) + # for batch in calibration_data: + # model(batch.to(device)) + # batch.to("cpu") + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + + print(f"running {quant_dtype} convert") + t0 = time.time() + awq_config = AWQConfig(base_config, step="convert") + quant_config = ModuleFqnToConfig( + {"_default": awq_config, "model.embed_tokens": None} + ) + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + awq_load_config = AWQConfig(base_config, step="load") + quant_config = ModuleFqnToConfig( + {"_default": awq_load_config, "model.embed_tokens": None} + ) + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + 
tokenizer.push_to_hub(model_save_hf_hub_path) + + elif quant.startswith("8da4w"): + group_size = int(quant.split("-")[1]) + quant_dtype = torch.int4 + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + + quant_config = ModuleFqnToConfig( + {"_default": base_config, "model.embed_tokens": embedding_config} + ) + quantize_(model, quant_config) + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + elif quant.startswith("awq-int4wo"): + group_size = int(quant.split("-")[2]) + use_hqq = True + print(f"running {quant} quantization with group size {group_size}") + base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + + print(f"running {quant} prepare and calibrate") + t0 = time.time() + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = awq_config + quantize_( + model, + quant_config, + ) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=2048, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") + t0 = time.time() + awq_config = AWQConfig(base_config, step="convert") + quant_config = awq_config + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + quant_config = AWQConfig(base_config, step="load") + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + elif quant.startswith("awq-gemlite"): + group_size = int(quant.split("-")[2]) + use_hqq = True + print(f"running {quant} quantization with group size {group_size}") + base_config = GemliteUIntXWeightOnlyConfig(group_size=group_size) + + print(f"running {quant} prepare and calibrate") + t0 = time.time() + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = awq_config + quantize_( + model, + quant_config, + ) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=2048, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") + t0 = time.time() + awq_config = AWQConfig(base_config, step="convert") + quant_config = awq_config + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", 
model_save_hf_hub_path) + quant_config = AWQConfig(base_config, step="load") + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + elif quant.startswith("gemlite"): + group_size = int(quant.split("-")[1]) + use_hqq = True + print(f"running {quant} quantization with group size {group_size}") + config = GemliteUIntXWeightOnlyConfig(group_size=group_size) + + print(f"running {quant} prepare and calibrate") + t0 = time.time() + quantize_(model, config) + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + quant_config = AWQConfig(config, step="load") + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + elif quant.startswith("int4wo"): + group_size = int(quant.split("-")[1]) + use_hqq = "hqq" in quant + print(f"running {quant} quantization with group size {group_size}") + int4_weight_only_config = int4_weight_only( + group_size=group_size, use_hqq=use_hqq + ) + if "xpu" in device: + int4_weight_only_config.layout = Int4XPULayout() + quantize_(model, int4_weight_only_config) + + if compile: + model = torch.compile(model) + + return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluate a model with the specified parameters." + ) + + # Optional arguments with default values + parser.add_argument("--repo", type=str, help="Repository ID of the model.") + parser.add_argument( + "--quant", + type=str, + help="Quantization method. Options are either awq-uint- for x =[1..8], awq-8da4w-, int4wo-, or int4wo--hqq.", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + help="Task to benchmark model on. Either PPL or QA", + default=["PPL"], + ) + parser.add_argument( + "--calibration_limit", + type=int, + default=10, + help="Number of samples to use for calibration. Default is 10.", + ) + parser.add_argument( + "--validation_size", type=int, default=1, help="Validation size. Default is 1." + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run the evaluation on. Default is 'cuda'.", + ) + parser.add_argument( + "--precision", + type=str, + default="bfloat16", + help="Precision type. Default is 'bfloat16'.", + ) + parser.add_argument( + "--seq_len", + type=int, + default=512, + help="Length of examples to calibrate and evaluate model on. 
Default is 512", + ) + parser.add_argument( + "--compile", + action="store_true", + help="Flag to indicate if compilation is required.", + ) + parser.add_argument( + "--model_save_path", + type=str, + default=None, + help="Path to store the scale values.", + ) + parser.add_argument( + "--model_save_hf_hub_path", + type=str, + default=None, + help="Huggingface hub path to store the quantized model and tokenizer.", + ) + + args = parser.parse_args() + + # Convert precision argument to torch dtype + precision_dtype = getattr(torch, args.precision, torch.bfloat16) + ppl = wikitext2_ppl( + args.repo, + args.quant, + args.tasks, + args.calibration_limit, + args.validation_size, + args.device, + args.precision, + args.seq_len, + args.compile, + args.model_save_path, + args.model_save_hf_hub_path, + ) + + print(f"{args.quant} Results: {ppl}") diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py index 0e75de2ee4..28291afdf4 100644 --- a/torchao/prototype/moe_quant/utils.py +++ b/torchao/prototype/moe_quant/utils.py @@ -20,18 +20,7 @@ dataclass, register_quantize_module_handler, ) -from torchao.utils import fill_defaults - - -class DummyModule(torch.nn.Module): - """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a - DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. - """ - - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = weight - self.bias = bias +from torchao.utils import DummyModule, fill_defaults class FakeExtraDimTensor(torch.Tensor): diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 6c433844a6..a794ae8289 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -73,6 +73,11 @@ def __tensor_unflatten__( tensor_data_dict["scale"], ) + def _quantization_type(self): + return ( + f"original_weight_tensor={self.original_weight_tensor}, scale={self.scale}" + ) + @staticmethod def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor @@ -126,7 +131,7 @@ def _(func, types, args, kwargs): ) -@implements(aten.detach.default) +@implements([aten.detach.default, aten.alias.default]) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index d692b52bdc..e7ed3a5030 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -545,10 +545,10 @@ def _quantization_type(weight: torch.Tensor): if hasattr(weight, "_quantization_type"): return f"{weight.__class__.__name__}({weight._quantization_type()})" - if type(weight) is torch.Tensor: - return "not quantized" + if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): + return f"Tensor: {type(weight)}" - return "not recognized" + return f"not recognized: {type(weight)}" def _linear_extra_repr(self): diff --git a/torchao/utils.py b/torchao/utils.py index c56b607b7b..0a06a8ef4f 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -11,7 +11,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable +from typing import Any, Callable, Optional import torch import 
torch.nn.utils.parametrize as parametrize
@@ -42,6 +42,7 @@
     "is_sm_at_least_89",
     "is_sm_at_least_90",
     "is_package_at_least",
+    "DummyModule",
 ]
 
 
@@ -732,3 +733,13 @@ def _is_fbgemm_genai_gpu_available():
         return False
 
     return True
+
+
+class DummyModule(torch.nn.Module):
+    """TorchAO quantization transforms generally operate on modules; to transform a bare
+    tensor, wrap it in a DummyModule, apply the transform to the module, then extract the tensor.
+    """
+
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight = weight
+        self.bias = bias
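End-to-end, the quantize_()-driven flow added in this patch is prepare, then calibrate, then convert. The sketch below condenses the awq-int4wo and awq-8da4w branches of example2.py into one minimal, CPU-friendly path; the model id, calibration sentences, and group size are illustrative, and Int8DynamicActivationIntxWeightConfig with QDQLayout is used as the base config exactly as in example2.py.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.dtypes import QDQLayout
from torchao.prototype.awq import AWQConfig
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

model_id = "facebook/opt-125m"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval()

base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(128),  # illustrative group size
    weight_scale_dtype=torch.bfloat16,
    layout=QDQLayout(),
)

# prepare: every nn.Linear becomes an AWQObservedLinear that records its inputs
quantize_(model, AWQConfig(base_config, step="prepare"))

# calibrate: run representative data through the observed model
calibration_texts = [
    "AWQ searches per-channel equalization scales on calibration activations.",
    "Calibration inputs should resemble what the model will see at deployment.",
]
with torch.no_grad():
    for text in calibration_texts:
        model(tokenizer(text, return_tensors="pt").input_ids)

# convert: pick the best scales found during calibration and quantize the weights
quantize_(model, AWQConfig(base_config, step="convert"))

The same three steps back the AWQConfig(base_config, step="load") path, which re-creates the observers and feeds a random tensor of example_input_shape through them so a pre-quantized checkpoint can be deserialized into the matching module structure.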