[Transform] Online Rotations #1651

Draft · wants to merge 4 commits into base: bdellabe/transform-modifier
17 changes: 11 additions & 6 deletions examples/transform/spinquant_example.py
@@ -13,7 +13,7 @@
MODEL_ID,
torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, attn_implementation="eager")

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
@@ -58,8 +58,10 @@ def tokenize(sample):
# * apply spinquant transforms to model in order to make quantization easier
# * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = [
SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
SpinQuantModifier(
rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
),
# QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
@@ -75,9 +77,12 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
from llmcompressor.utils import calibration_forward_context

with calibration_forward_context(model):
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
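For orientation, here is a minimal sketch of how the expanded rotation recipe from this example could be applied on its own (the quantization step is commented out in the diff above). The `SpinQuantModifier` import path and the data-free `oneshot` call are assumptions based on other llm-compressor examples, not part of this PR.

```python
# Sketch only, not part of the diff: apply all four SpinQuant rotations without
# a quantization step. Import paths and the data-free oneshot call are assumed.
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import SpinQuantModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

recipe = [
    SpinQuantModifier(
        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
    ),
]

# Hadamard rotations are data-free when not learnable, so no calibration
# dataset is passed here (an assumption about this configuration).
oneshot(model=model, recipe=recipe)
```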
39 changes: 36 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Tuple
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple

import torch
from compressed_tensors.quantization import (
@@ -13,18 +14,26 @@
from compressed_tensors.utils import align_module_device, update_parameter_data
from loguru import logger
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
from llmcompressor.observers import Observer
from llmcompressor.utils.helpers import getattr_chain

if TYPE_CHECKING:
from compressed_tensors.modeling.attention import CompressedAttentionImpl

from llmcompressor.modifiers.utils.hooks import HooksMixin


DEFAULT_MAXSHRINK = 0.20
DEFAULT_PATIENCE = 5
DEFAULT_AVERAGING_CONSTANT = 0.01
DEFAULT_GRID = 100.0
DEFAULT_NORM = 2.4

__all__ = [
"register_calibrate_attn_hooks",
"initialize_observer",
"update_weight_zp_scale",
"calibrate_input_hook",
@@ -205,14 +214,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
)


def calibrate_input_hook(module: Module, args: Any):
def register_calibrate_attn_hooks(
modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
) -> Set[RemovableHandle]:
return {
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="q"), "query"
),
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="k"), "key"
),
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="v"), "value"
),
}


def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
"""
Hook to calibrate input activations.
Will call the observers to update the scales/zp before applying
input QDQ in the module's forward pass.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")
calibrate_activations(module, value=args, base_name=base_name)


def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -282,6 +307,14 @@ def initialize_quantized_kv_cache(module: Module):
setattr(module, "kv_cache", quantized_kv_cache)


def initialize_attention_observers(module: Module):
input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
if input_args is not None:
initialize_observer(module, "q", input_args)
initialize_observer(module, "k", input_args)
initialize_observer(module, "v", input_args)


def apply_calibration_status(module: Module):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
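A rough usage sketch of the new `register_calibrate_attn_hooks` helper, mirroring the commented-out TODO in `_initialize_hooks` further down. `CompressedAttentionImpl.from_module` appears only in that TODO, and the attention-module check below is a placeholder, so treat this wiring as an assumption rather than the PR's final design.

```python
# Sketch only: wiring query/key/value calibration hooks from a modifier that
# uses HooksMixin. CompressedAttentionImpl.from_module is taken from the TODO
# comments in this PR and is assumed, not verified.
from compressed_tensors.modeling.attention import CompressedAttentionImpl

from llmcompressor.modifiers.quantization.calibration import (
    register_calibrate_attn_hooks,
)


def attach_attention_calibration(modifier, model):
    handles = set()
    for module in model.modules():
        if getattr(module, "quantization_scheme", None) is None:
            continue
        # Placeholder attention check; the mixin likely uses its own helper.
        if module.__class__.__name__.endswith("Attention"):
            attention_impl = CompressedAttentionImpl.from_module(module)
            handles |= register_calibrate_attn_hooks(modifier, attention_impl)
    return handles
```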
@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
# kv_cache activations. Within `apply_quantization_config`, the config is
# modified to use attention output quantization if a kv_cache_scheme exists
if is_attention and output:
# initialize_attention_observers(module) # TODO: attnq
initialize_quantized_kv_cache(module)

# output activations
@@ -240,6 +241,7 @@ def _initialize_observers(self, module: torch.nn.Module):

def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
hooks = set()

for module in model.modules():
if not hasattr(module, "quantization_scheme"):
continue
@@ -258,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
self.register_hook(module, calibrate_input_hook, "forward_pre")
)

# TODO: attnq
# if is_attention:
# attention_impl = CompressedAttentionImpl.from_module(module)
# hooks |= register_calibrate_attn_hooks(self, attention_impl)

# kv_cache activations. Within `apply_quantization_config`, the config is
# modified to use attention output quantization if a kv_cache_scheme exists
if is_attention and output:
47 changes: 43 additions & 4 deletions src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -109,7 +109,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
config_groups["R2"] = self._create_r2_scheme(state.model)

if SpinquantRotation.R3 in self.rotations:
config_groups["R3"] = self._create_r3_scheme()
config_groups["R3"] = self._create_r3_scheme(state.model)

if SpinquantRotation.R4 in self.rotations:
config_groups["R4"] = self._create_r4_scheme()
@@ -214,8 +214,47 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
],
)

def _create_r3_scheme(self) -> TransformScheme:
raise NotImplementedError()
def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme:
config = model.config

if hasattr(config, "head_dim"):
head_dim = config.head_dim
elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
head_dim = config.hidden_size // config.num_attention_heads
else:
raise NotImplementedError()

return TransformScheme(
type=self.transform_type,
randomize=self.randomize,
requires_grad=self.learnable,
head_dim=head_dim,
apply=[
TransformArgs(
targets=[self.mappings.attn],
location="attn_q",
),
TransformArgs(
targets=[self.mappings.attn],
location="attn_k",
),
],
)

def _create_r4_scheme(self) -> TransformScheme:
raise NotImplementedError()
return TransformScheme(
type=self.transform_type,
randomize=self.randomize,
requires_grad=self.learnable,
apply=[
TransformArgs(
targets=[*self.mappings.mlp_out],
location="input",
),
TransformArgs(
targets=[*self.mappings.mlp_out],
location="weight_input",
inverse=True,
),
],
)
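As a quick sanity check on the `head_dim` fallback added to `_create_r3_scheme`, the arithmetic for a config that exposes `hidden_size` and `num_attention_heads` but no `head_dim` attribute (the numbers below are illustrative, not taken from this PR):

```python
# Illustrative numbers only: the fallback branch in _create_r3_scheme.
hidden_size = 4096             # e.g. a 7B Llama-style model
num_attention_heads = 32
head_dim = hidden_size // num_attention_heads
assert head_dim == 128         # R3 rotates each 128-dim head at attn_q/attn_k
```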
2 changes: 2 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -10,6 +10,7 @@
class SpinQuantMapping(BaseModel):
embedding: str

attn: str
attn_q: str
attn_k: str
attn_v: str
@@ -31,6 +32,7 @@ def cast_to_list(cls, value):

_default_mappings = SpinQuantMapping(
embedding="re:.*embed_tokens$",
attn="re:.*self_attn$",
attn_q="re:.*q_proj$",
attn_k="re:.*k_proj$",
attn_v="re:.*v_proj$",
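A small sketch of what the new `attn` mapping is expected to match relative to the existing projection mappings; the module paths follow a typical Llama-style layout and are assumptions, not taken from this PR.

```python
# Sketch only: the "re:" prefix marks these mappings as regex patterns.
import re

attn = re.compile(r".*self_attn$")
attn_q = re.compile(r".*q_proj$")

assert attn.match("model.layers.0.self_attn")             # whole attention block
assert attn_q.match("model.layers.0.self_attn.q_proj")    # single projection
assert not attn.match("model.layers.0.self_attn.q_proj")  # R3 targets the block
```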