[WIP] [Research] Attention quantization and transformation #1612

Draft · wants to merge 5 commits into main
53 changes: 53 additions & 0 deletions examples/attention_quantization/llama3_example.py
@@ -0,0 +1,53 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Llama-3.2-1B-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
# * quantize the attention query, key, and value inputs to 8-bit float (fp8)
recipe = QuantizationModifier(
    config_groups={
        "attention_quant": QuantizationScheme(
            targets=["re:.*self_attn$"],
            input_activations=QuantizationArgs(num_bits=8, type="float"),
        ),
    },
    ignore=["lm_head"],
)

# Apply algorithms.
oneshot(
    model=model,
    dataset="ultrachat_200k",
    splits={"calibration": f"test_sft[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-FP8-attention"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
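
Not part of the diff: a minimal sanity-check sketch that could be appended to the example above to confirm calibration attached quantization scales to the attention modules matched by "re:.*self_attn$". The attribute names q_scale / k_scale / v_scale are an assumption inferred from the base_name arguments ("q", "k", "v") used in calibrated_attention below, not a documented API, and may differ in practice.

# Hypothetical check: inspect attention modules for calibrated q/k/v scales.
# Assumes compressed-tensors stores scales as "<base_name>_scale" attributes.
for name, submodule in model.named_modules():
    if name.endswith("self_attn"):
        for base in ("q", "k", "v"):
            scale = getattr(submodule, f"{base}_scale", None)
            if scale is not None:
                print(f"{name}.{base}_scale shape: {tuple(scale.shape)}")
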
1 change: 0 additions & 1 deletion src/llmcompressor/modifiers/quantization/__init__.py
@@ -1,5 +1,4 @@
# flake8: noqa

from .cache import *
from .gptq import *
from .quantization import *
58 changes: 58 additions & 0 deletions src/llmcompressor/modifiers/quantization/attention.py
@@ -0,0 +1,58 @@
from typing import Optional

import torch
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationStatus,
    forward_quantize,
)
from compressed_tensors.transform import TransformBase, TransformLocation
from transformers.modeling_utils import AttentionInterface
from transformers.models.llama.modeling_llama import eager_attention_forward

from llmcompressor.modifiers.quantization.calibration import calibrate_activations


def calibrated_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    for submodule in module.children():
        if isinstance(submodule, TransformBase):
            # 1. apply transforms attached to this attention module
            if submodule.args.location == TransformLocation.ATTN_Q:
                query = submodule(query)

            if submodule.args.location == TransformLocation.ATTN_K:
                key = submodule(key)

            # if submodule.args.location == TransformLocation.ATTN_V:
            #     value = submodule(value)

    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
    status: Optional[QuantizationStatus] = getattr(module, "quantization_status", None)
    if getattr(scheme, "input_activations", None) is not None:
        # 2. calibrate quantization observers on the attention inputs
        if status == QuantizationStatus.CALIBRATION:
            calibrate_activations(module, value=query, base_name="q")
            calibrate_activations(module, value=key, base_name="k")
            calibrate_activations(module, value=value, base_name="v")

        # 3. apply quantization to query, key, and value
        if status in (QuantizationStatus.CALIBRATION, QuantizationStatus.FROZEN):
            query = forward_quantize(module, query, "q", scheme.input_activations)
            key = forward_quantize(module, key, "k", scheme.input_activations)
            value = forward_quantize(module, value, "v", scheme.input_activations)

    return eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout, **kwargs
    )


AttentionInterface.register("calibrated_attention", calibrated_attention)
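
Not part of the diff: this section does not show how the QuantizationModifier switches a model onto the newly registered implementation. One plausible wiring, using only the public transformers AttentionInterface mechanism that the registration above relies on, is sketched here; treat it as an assumption about intended usage rather than what the modifier actually does.

# Hedged sketch: route a Llama model's attention through "calibrated_attention".
# Models that dispatch via ALL_ATTENTION_FUNCTIONS (e.g. Llama) look up the
# attention function by the name stored on the config, so pointing the config
# at the registered name makes attention calls go through calibrated_attention.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-instruct", torch_dtype="auto"
)
model.config._attn_implementation = "calibrated_attention"
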
208 changes: 0 additions & 208 deletions src/llmcompressor/modifiers/quantization/cache.py

This file was deleted.
