[WIP] [Research] Attention quantization and transformation #1612

Draft · wants to merge 5 commits into base: main (changes from 1 commit)

63 changes: 63 additions & 0 deletions src/llmcompressor/modifiers/quantization/attention.py
@@ -0,0 +1,63 @@
from typing import Optional

import torch
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationStatus,
    calibrate_activations,
    forward_quantize,
)
from compressed_tensors.transform import TransformBase, TransformLocation
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface


def calibrated_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # 1. apply transforms
    for submodule in module.children():
        if isinstance(submodule, TransformBase):
            if submodule.args.location == TransformLocation.ATTN_Q:
                query = submodule(query)

            if submodule.args.location == TransformLocation.ATTN_K:
                key = submodule(key)

            # if submodule.args.location == TransformLocation.ATTN_V:
            #     value = submodule(value)

    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
    status: Optional[QuantizationStatus] = getattr(module, "quantization_status", None)
    if scheme is not None:
        if scheme.input_activations is not None:
            # 2. calibrate quantization
            if status == QuantizationStatus.CALIBRATION:
                calibrate_activations(module, value=query, base_name="q")
                calibrate_activations(module, value=key, base_name="k")
                calibrate_activations(module, value=value, base_name="v")

            # 3. apply quantization
            if status in (QuantizationStatus.CALIBRATION, QuantizationStatus.FROZEN):
                query = forward_quantize(module, query, "q", scheme.input_activations)
                key = forward_quantize(module, key, "k", scheme.input_activations)
                value = forward_quantize(module, value, "v", scheme.input_activations)

        if scheme.weights is not None:
            raise ValueError("Attention modules do not support weight quantization schemes")

        if scheme.output_activations is not None:
            raise NotImplementedError("Attention output quantization is not yet supported")

    return ALL_ATTENTION_FUNCTIONS["eager"](
        module, query, key, value, attention_mask, scaling, dropout, **kwargs
    )


AttentionInterface.register("calibrated_attention", calibrated_attention)
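
Note: a minimal usage sketch (not part of this diff), assuming the standard transformers attention-dispatch mechanism; the model name below is a placeholder:

# Route a model's attention through the implementation registered above.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder

# attention modules read this config field at forward time to select an
# implementation from ALL_ATTENTION_FUNCTIONS / AttentionInterface
model.config._attn_implementation = "calibrated_attention"

# ... calibration forward passes now flow through calibrated_attention,
# which applies transforms and observes/quantizes q, k, and v ...
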
55 changes: 1 addition & 54 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,15 +1,12 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional

import torch
from compressed_tensors.quantization import (
    DynamicType,
    KVCacheScaleType,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.utils import align_module_device, update_parameter_data
from loguru import logger
from torch.nn import Module
@@ -29,9 +26,6 @@
"update_weight_zp_scale",
"calibrate_input_hook",
"calibrate_output_hook",
"calibrate_kv_cache_input_hook",
"calibrate_kv_cache_output_hook",
"initialize_quantized_kv_cache",
"freeze_module_quantization",
"apply_calibration_status",
"reset_quantization_status",
@@ -235,53 +229,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    return output


def calibrate_kv_cache_input_hook(
    module: Module, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Hook to update inputs to attention layers when running
    kv_cache quantization. Will update the passed in
    kv_cache to singleton QuantizedKVParameterCache.
    """
    kv_cache = getattr(module, "kv_cache")
    kwargs["past_key_value"] = kv_cache
    kwargs["use_cache"] = False
    return args, kwargs


def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]
    update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


def initialize_quantized_kv_cache(module: Module):
    """
    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
    When a config specifying kv_cache quantization is applied to a model, the kv_cache
    args are redefined as the output_activations targeting attention modules.

    This function should be called on attention modules with output_activations
    """
    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
    existing_kv_cache = getattr(module, "kv_cache", None)

    if (
        scheme is None
        or not is_kv_cache_quant_scheme(scheme)
        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
    ):
        return

    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
    setattr(module, "kv_cache", quantized_kv_cache)


def apply_calibration_status(module: Module):
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme:
109 changes: 54 additions & 55 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -1,8 +1,7 @@
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
from compressed_tensors.quantization import (
    DynamicType,
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
@@ -20,12 +19,9 @@
from llmcompressor.modifiers.quantization.calibration import (
    apply_calibration_status,
    calibrate_input_hook,
    calibrate_kv_cache_input_hook,
    calibrate_kv_cache_output_hook,
    calibrate_output_hook,
    freeze_module_quantization,
    initialize_observer,
    initialize_quantized_kv_cache,
    reset_quantization_status,
)
from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -138,7 +134,9 @@ def start_calibration(self, model: torch.nn.Module):

        :param model: model to prepare for calibration
        """
        self._calibration_hooks = self._initialize_hooks(model)
        self._calibration_hooks = set()
        for module in model.modules():
            self._calibration_hooks |= self._initialize_hooks(module)
        model.apply(apply_calibration_status)
        model.apply(enable_quantization)  # quantize at the same time as calibrate

@@ -212,71 +210,72 @@ def _initialize_observers(self, module: torch.nn.Module):
        if not hasattr(module, "quantization_scheme"):
            return

        scheme: QuantizationScheme = module.quantization_scheme
        input = scheme.input_activations and scheme.input_activations.dynamic in (
            False,
            DynamicType.LOCAL,
        )
        weight = scheme.weights is not None
        output = scheme.output_activations and not scheme.output_activations.dynamic
        is_attention = is_attention_module(module)
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            input, weight, output = _get_observer_targets(module.quantization_scheme)

        # input activations
        if input:
            initialize_observer(module, base_name="input")
            # input activations
            if input:
                initialize_observer(module, base_name="input")

        # weight observers (used by `update_weight_zp_scale` or child modifier)
        if weight:
            initialize_observer(module, base_name="weight")

            # weight observers (used by `update_weight_zp_scale` or child modifier)
            if weight:
                initialize_observer(module, base_name="weight")
            # output activations
            if output:
                initialize_observer(module, base_name="output")

        # kv_cache activations. Within `apply_quantization_config`, the config is
        # modified to use attention output quantization if a kv_cache_scheme exists
        if is_attention and output:
            initialize_quantized_kv_cache(module)
        elif is_attention_module(module):
            # attention observers
            initialize_observer(module, base_name="q")
            initialize_observer(module, base_name="k")
            initialize_observer(module, base_name="v")

        # output activations
        elif output:
            initialize_observer(module, base_name="output")
        else:
            raise ValueError(f"Unsupported quantization target {type(module)}")

    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
    def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
        hooks = set()
        for module in model.modules():
            if not hasattr(module, "quantization_scheme"):
                continue
        if not hasattr(module, "quantization_scheme"):
            return hooks

            scheme: QuantizationScheme = module.quantization_scheme
            input = scheme.input_activations and scheme.input_activations.dynamic in (
                False,
                DynamicType.LOCAL,
            )
            output = scheme.output_activations and not scheme.output_activations.dynamic
            is_attention = is_attention_module(module)
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            input, _, output = _get_observer_targets(module.quantization_scheme)

            # input activations
            if input:
                hooks.add(
                    self.register_hook(module, calibrate_input_hook, "forward_pre")
                )

            # kv_cache activations. Within `apply_quantization_config`, the config is
            # modified to use attention output quantization if a kv_cache_scheme exists
            if is_attention and output:
                hooks.add(
                    self.register_hook(
                        module,
                        calibrate_kv_cache_input_hook,
                        "forward_pre",
                        with_kwargs=True,
                    )
                )
                hooks.add(
                    self.register_hook(
                        module, calibrate_kv_cache_output_hook, "forward"
                    )
                )

            # output activations
            elif output:
                hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))

        elif is_attention_module(module):
            # wrap attention interface
            tmp = None

            def forward_pre(self, *args, **kwargs):
                nonlocal tmp
                tmp = self.config._attn_implementation
                self.config._attn_implementation = "calibrated_attention"

            def forward(self, *args, **kwargs):
                self.config._attn_implementation = tmp

            hooks.add(self.register_hook(module, forward_pre, "forward_pre"))
            hooks.add(self.register_hook(module, forward, "forward"))

        else:
            raise ValueError(f"Unsupported quantization target {type(module)}")

        return hooks


def _get_observer_targets(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]:
    input = scheme.input_activations and scheme.input_activations.dynamic is not True
    weight = scheme.weights is not None
    output = scheme.output_activations and scheme.output_activations.dynamic is not True

    return input, weight, output
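
A small illustration of the `_get_observer_targets` helper above, using a hypothetical scheme (values chosen for illustration only, not part of this diff):

# Hypothetical scheme: 4-bit weights, static 8-bit input activations, no output quantization.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
    input_activations=QuantizationArgs(num_bits=8, dynamic=False),
    output_activations=None,
)

input, weight, output = _get_observer_targets(scheme)
# input  -> True: static input activations need an input observer
# weight -> True: weights need a weight observer
# output -> None (falsy): no output observer is initialized
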
94 changes: 0 additions & 94 deletions tests/llmcompressor/modifiers/calibration/test_kv_cache.py

This file was deleted.
