[Transforms] Enable transforms to be applied to weights during quantization #1243

Draft · wants to merge 11 commits into base: main
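
The change, roughly: a transform attached to a module (for example a Hadamard-style rotation) is applied to the weight only for calibration, the observer computes scale and zero-point from the transformed values, and the original weight is then restored so the saved checkpoint stays untransformed. The diff wires this into update_weight_zp_scale and calibrate_activations below. A minimal, self-contained sketch of the pattern, using plain torch and hypothetical helper names rather than the compressed-tensors API:

import torch

def hadamard_like_transform(weight: torch.Tensor) -> torch.Tensor:
    # stand-in for the real transform (e.g. a Hadamard/random rotation)
    rotation, _ = torch.linalg.qr(torch.randn(weight.shape[-1], weight.shape[-1]))
    return weight @ rotation

def minmax_scale_zero_point(weight: torch.Tensor, num_bits: int = 4):
    # toy per-tensor min-max observer
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (weight.max() - weight.min()) / (qmax - qmin)
    zero_point = torch.round(-weight.min() / scale).clamp(qmin, qmax)
    return scale, zero_point

weight = torch.randn(64, 64)
untransformed = weight.clone()                   # keep the original values
weight.copy_(hadamard_like_transform(weight))    # calibrate against the transformed weight
scale, zero_point = minmax_scale_zero_point(weight)
weight.copy_(untransformed)                      # restore; the stored weight stays untransformed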
111 changes: 111 additions & 0 deletions load_transform_model.py
@@ -0,0 +1,111 @@
from pathlib import Path

from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig

MODEL_ID = "/home/dsikka/Llama-3.2-1B-Instruct-W4A16-uncompressed-hadamard-random-debug"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
# breakpoint()  # debug stop; remove for non-interactive runs
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))

import lm_eval

results = lm_eval.simple_evaluate(
model="hf",
model_args={
"pretrained": MODEL_ID,
"add_bos_token": True,
"quantization_config": CompressedTensorsConfig(run_compressed=False),
},
tasks=["gsm8k"],
num_fewshot=8,
limit=1000,
device="cuda:0",
batch_size=100,
)
print(results["results"])
"""
For: Llama-3.2-1B-Instruct

Dense:
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.379,
'exact_match_stderr,strict-match': 0.015349091002225352,
'exact_match,flexible-extract': 0.381,
'exact_match_stderr,flexible-extract': 0.015364734787007436}}

----------------------------MINMAX ---------------------------:

QuantModifier - NO TRANSFORMS
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.177,
'exact_match_stderr,strict-match': 0.011743632866916145,
'exact_match,flexible-extract': 0.179,
'exact_match_stderr,flexible-extract': 0.0117721103708122}}

QuantModifier - TRANSFORMS (random)
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.231,
'exact_match_stderr,strict-match': 0.012997843819031815,
'exact_match,flexible-extract': 0.236,
'exact_match_stderr,flexible-extract': 0.01301973553930782}}

GPTQ
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.243,
'exact_match_stderr,strict-match': 0.013569640199177434,
'exact_match,flexible-extract': 0.244,
'exact_match_stderr,flexible-extract': 0.013588548437881431}}


---------------------------MSE-----------------------------------:
QuantModifier - No Transforms
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.195,
'exact_match_stderr,strict-match': 0.012535235623319334,
'exact_match,flexible-extract': 0.195,
'exact_match_stderr,flexible-extract': 0.012535235623319334}}

QuantModifier - With Transforms (random)
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.243,
'exact_match_stderr,strict-match': 0.013569640199177457,
'exact_match,flexible-extract': 0.244,
'exact_match_stderr,flexible-extract': 0.013588548437881412}}

QuantModifier - With Transforms (not random, not normalized)
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.261,
'exact_match_stderr,strict-match': 0.013895037677965126,
'exact_match,flexible-extract': 0.262,
'exact_match_stderr,flexible-extract': 0.013912208651021352}}

QuantModifier - With Transforms (not random, normalized)
{'gsm8k': {'alias': 'gsm8k',
'exact_match,strict-match': 0.27,
'exact_match_stderr,strict-match': 0.014046255632633915,
'exact_match,flexible-extract': 0.27,
'exact_match_stderr,flexible-extract': 0.014046255632633915}}

GPTQ:
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.285,
'exact_match_stderr,strict-match': 0.014282120955200484,
'exact_match,flexible-extract': 0.286,
'exact_match_stderr,flexible-extract': 0.01429714686251791}}

---------------------8bit----------------------------------:
QuantModifier - with Transforms (not random, normalized)
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.371,
'exact_match_stderr,strict-match': 0.015283736211823187,
'exact_match,flexible-extract': 0.372,
'exact_match_stderr,flexible-extract': 0.015292149942040577}}

GPTQ
{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.364,
'exact_match_stderr,strict-match': 0.01522286884052202,
'exact_match,flexible-extract': 0.365,
'exact_match_stderr,flexible-extract': 0.015231776226264903}}
"""
33 changes: 25 additions & 8 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -2,8 +2,8 @@

import torch
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.transforms.apply import apply_transforms_to_activations_or_parameter
from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
from loguru import logger
from torch.nn import Module
@@ -120,8 +120,21 @@ def update_weight_zp_scale(module: Module):

if module.quantization_scheme.weights is not None:
# set weight scale and zero_point up front, calibration data doesn't affect it

transform_data = getattr(module, "transform_data", None)
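# if a transform is attached, observe the transformed weight but restore the original afterwards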
if transform_data is not None:
untransformed_weight = module.weight.data.clone()
apply_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=module.weight,
transform_data=transform_data,
)

call_observer(module=module, base_name="weight")

if transform_data is not None:
module.weight.data.copy_(untransformed_weight)


def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
"""
@@ -138,11 +151,22 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
if value.numel() == 0:
return

transform_data = getattr(module, "transform_data", None)
if transform_data is not None:
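# update_in_place=False returns a transformed copy for observation only; the original activation is not modified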
value = apply_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=value,
transform_data=transform_data,
update_in_place=False
)

call_observer(
module=module,
base_name=base_name,
value=value,
)
# breakpoint()  # debug stop
# validate value is correct


def calibrate_input_hook(module: Module, args: Any):
@@ -166,12 +190,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
)
return output


@@ -197,7 +215,6 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")


def set_unset_kv_cache(module: Module):
"""
Set or unset singleton QuantizedKVParameterCache for each
54 changes: 43 additions & 11 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Set

from compressed_tensors.quantization import (
QuantizationArgs,
@@ -10,6 +12,12 @@
is_preset_scheme,
preset_name_to_scheme,
)
from compressed_tensors.quantization.lifecycle import (
post_forward_quantize,
pre_forward_quantize,
register_quantization_hooks,
)
from compressed_tensors.transforms.transform_config import TransformationConfig
from loguru import logger
from pydantic import Field, field_validator
from torch.nn import Module
@@ -74,6 +80,7 @@ class QuantizationModifier(Modifier):
"""

config_groups: Optional[Dict[str, QuantizationScheme]] = None
transforms_config: Optional[TransformationConfig] = None
ignore: List[str] = Field(default_factory=list)
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
scheme: Optional[Union[str, Dict[str, Any]]] = None
@@ -83,6 +90,7 @@

calibration_dataloader_: Any = None
calibration_function_: Any = None
_handles: Set = set()

@field_validator("targets", mode="before")
def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
@@ -210,7 +218,12 @@ def _check_calibration_data(self, config: QuantizationConfig):
def _apply_modifier_to_model(self, model: Module):
modifier_as_config = self.create_init_config()
# Add step to attach kv_cache to the model, if present within the config
apply_quantization_config(model, modifier_as_config)
apply_quantization_config(
model,
modifier_as_config,
transforms_config=self.transforms_config,
delay_forward_quantize=True,
)
model.apply(set_unset_kv_cache)
return modifier_as_config

@@ -258,14 +271,17 @@ def _calibrate_if_possible(self, module: Module):
)

elif not self.calibration_dataloader_:
# TODO: should just use HooksMixin
# hooks should have been delayed
module.apply(lambda model: register_quantization_hooks(model))
return

module.apply(lambda model: initialize_observer(model, base_name="input"))
module.apply(lambda model: initialize_observer(model, base_name="output"))
module.apply(self.register_calibration_hooks)
self._calibrate(module)
module.apply(set_unset_kv_cache)
self.remove_hooks()
self.remove_hooks(self._handles)

def register_calibration_hooks(self, module: Module):
"""
@@ -285,23 +301,39 @@ def register_calibration_hooks(self, module: Module):

# Calibrate inputs if an input_quant is provided and not running dynamic quant
if calibrate_inputs:
self.register_hook(module, calibrate_input_hook, "forward_pre")
self._handles.add(
self.register_hook(module, calibrate_input_hook, "forward_pre")
)

if not is_attention_module_:
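# forward quantization was delayed when the config was applied; attach the pre/post forward-quantize hooks explicitly for non-attention modules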
self.register_hook(module, pre_forward_quantize, "forward_pre")

if output_quant:
# hooks for attn modules if running kv_cache quant
if is_attention_module_:
self.register_hook(
module,
calibrate_kv_cache_input_hook,
"forward_pre",
with_kwargs=True,
self._handles.add(
self.register_hook(
module,
calibrate_kv_cache_input_hook,
"forward_pre",
with_kwargs=True,
)
)

self.register_hook(module, calibrate_kv_cache_output_hook, "forward")
self._handles.add(
self.register_hook(
module, calibrate_kv_cache_output_hook, "forward"
)
)

# hooks for output quant if not running dynamic quant
elif not output_quant.dynamic:
self.register_hook(module, calibrate_output_hook, "forward")
self._handles.add(
self.register_hook(module, calibrate_output_hook, "forward")
)

if not is_attention_module_:
self.register_hook(module, post_forward_quantize, "forward")

def _calibrate(self, module: Module):
class_name = self.__class__.__name__.replace("PyTorch", "")
1 change: 0 additions & 1 deletion src/llmcompressor/utils/helpers.py
@@ -1125,7 +1125,6 @@ def calibration_forward_context(model: PreTrainedModel):
with (
torch.no_grad(),
DisableKVCache(model),
DisableQuantization(model),
eval_context(model),
):
yield