[WIP] [Research] Attention quantization and transformation #374

Draft: wants to merge 5 commits into base: main

45 changes: 1 addition & 44 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -140,7 +140,6 @@ def apply_quantization_config(
# build mapping of targets to schemes for easier matching
# use ordered dict to preserve target ordering in config
target_to_scheme = OrderedDict()
config = process_quantization_config(config)
names_to_scheme = dict()
for scheme in config.config_groups.values():
for target in scheme.targets:
@@ -152,13 +151,7 @@
# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_quantizable_modules(
model,
include_children=True,
include_attn=True,
): # child modules and attention modules
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
for name, submodule in model.named_modules():
if matches := find_name_or_class_matches(name, submodule, config.ignore):
for match in matches:
ignored_submodules[match].append(name)
@@ -200,42 +193,6 @@ def apply_quantization_config(
return names_to_scheme


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
"""
Preprocess the raw QuantizationConfig

:param config: the raw QuantizationConfig
:return: the processed QuantizationConfig
"""
if config.kv_cache_scheme is not None:
config = process_kv_cache_config(config)

return config


def process_kv_cache_config(
config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
) -> QuantizationConfig:
"""
Reformulate the `config.kv_cache` as a `config_group`
and add it to the set of existing `config.groups`

:param config: the QuantizationConfig
:return: the QuantizationConfig with additional "kv_cache" group
"""
if targets == KV_CACHE_TARGETS:
_LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")

kv_cache_dict = config.kv_cache_scheme.model_dump()
kv_cache_scheme = QuantizationScheme(
output_activations=QuantizationArgs(**kv_cache_dict),
targets=targets,
)
kv_cache_group = dict(kv_cache=kv_cache_scheme)
config.config_groups.update(kv_cache_group)
return config


def apply_quantization_status(model: Module, status: QuantizationStatus):
"""
Applies in place the quantization lifecycle up to the given status
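With `process_quantization_config` and `process_kv_cache_config` removed, a `kv_cache_scheme` is no longer rewritten into a synthetic `kv_cache` config group; attention modules are instead targeted directly by a scheme's `targets`, and matching now runs over plain `model.named_modules()`. A minimal sketch of such a config, mirroring the `FP8_ATTN` preset added later in this PR (the top-level import path and field defaults are assumptions):

```python
# Sketch only: targeting attention modules directly instead of relying on the
# removed kv_cache_scheme -> "kv_cache" config-group rewrite.
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)

attn_scheme = QuantizationScheme(
    targets=["re:.*self_attn$"],  # matched against names from model.named_modules()
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=False,
        observer=None,
    ),
)

config = QuantizationConfig(config_groups={"attn": attn_scheme})
# names_to_scheme = apply_quantization_config(model, config)
```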
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -31,6 +31,7 @@
)
from compressed_tensors.utils import safe_permute
from torch.nn import Module
from transformers import AttentionInterface


__all__ = [
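The only change here is the new `AttentionInterface` import. For context, this is the typical transformers attention-interface registration pattern; the `quantized_attention_forward` body below is purely hypothetical and only illustrates where per-projection scales such as `q_scale`/`k_scale`/`v_scale` could be consumed:

```python
# Typical transformers AttentionInterface usage (sketch, not from this PR).
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def quantized_attention_forward(module, query, key, value, attention_mask, **kwargs):
    # Hypothetical: fake-quantize query/key/value here using the module's
    # q_scale / k_scale / v_scale parameters before delegating to SDPA.
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)

# Registered implementations can then be selected via the model's
# attn_implementation setting.
AttentionInterface.register("quantized_sdpa", quantized_attention_forward)
```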
133 changes: 58 additions & 75 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -13,10 +13,8 @@
# limitations under the License.


import logging
import math
from enum import Enum
from typing import List, Optional
from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.lifecycle.forward import (
@@ -30,7 +28,7 @@
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
from compressed_tensors.quantization.utils import is_fp4
from compressed_tensors.utils import (
disable_hf_hook,
get_execution_device,
@@ -42,18 +40,10 @@
__all__ = [
"initialize_module_for_quantization",
"is_attention_module",
"KVCacheScaleType",
"get_calibrated_locations",
]


_LOGGER = logging.getLogger(__name__)


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"


def initialize_module_for_quantization(
module: Module,
scheme: Optional[QuantizationScheme] = None,
@@ -78,78 +68,78 @@ def initialize_module_for_quantization(
# TODO: don't initialize parameters when running decompression
scheme = scheme or getattr(module, "quantization_scheme", None)
if scheme is None:
# no scheme passed and layer not targeted for quantization - skip
return

if is_attention_module(module):
# quantized actions based on calltime status
_initialize_attn_scales(module)
# initialize scheme and status
module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED

else:
input, weight, output = get_calibrated_locations(scheme)

if scheme.input_activations is not None:
_initialize_scale_zero_point(
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
if input:
_initialize_quantization_parameters(
module,
"input",
scheme.input_activations,
force_zero_point=force_zero_point,
scale_dtype=scale_dtype,
)

if scheme.weights is not None:
if hasattr(module, "weight"):
weight_shape = None
if isinstance(module, torch.nn.Linear):
weight_shape = module.weight.shape
_initialize_scale_zero_point(
module,
"weight",
scheme.weights,
weight_shape=weight_shape,
force_zero_point=force_zero_point,
scale_dtype=scale_dtype,
)
else:
_LOGGER.warning(
f"module type {type(module)} targeted for weight quantization but "
"has no attribute weight, skipping weight quantization "
f"for {type(module)}"
)

if scheme.output_activations is not None:
if not is_kv_cache_quant_scheme(scheme):
_initialize_scale_zero_point(
module, "output", scheme.output_activations, scale_dtype=scale_dtype
)

module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED
if weight:
_initialize_quantization_parameters(
module,
"weight",
scheme.weights,
force_zero_point=force_zero_point,
scale_dtype=scale_dtype,
)

if output:
_initialize_quantization_parameters(
module,
"output",
scheme.output_activations,
force_zero_point=force_zero_point,
scale_dtype=scale_dtype,
)

with disable_hf_hook(module):
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)

elif is_attention_module(module):
assert input and scheme.input_activations is not None
for base_name in ("q", "k", "v"):
_initialize_quantization_parameters(
module,
base_name,
scheme.input_activations,
force_zero_point=force_zero_point,
scale_dtype=scale_dtype,
)

else:
raise ValueError(f"Unsupported quantization target {type(module)}")


def is_attention_module(module: Module):
# can redefine to inspect source code for references to ALL_ATTENTION_FUNCTIONS
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj")
or hasattr(module, "v_proj")
or hasattr(module, "qkv_proj")
)


def _initialize_scale_zero_point(
def _initialize_quantization_parameters(
module: Module,
base_name: str,
quantization_args: QuantizationArgs,
weight_shape: Optional[torch.Size] = None,
force_zero_point: bool = True,
scale_dtype: Optional[torch.dtype] = None,
):
if quantization_args.dynamic is True:
return

# initialize on execution device to avoid performing quantized ops on cpu
device = get_execution_device(module)

@@ -170,7 +160,8 @@ def _initialize_scale_zero_point(
else:
expected_shape = 1

if base_name == "weight" and weight_shape is not None:
if base_name == "weight":
weight_shape = getattr(module, "weight").shape
if quantization_args.strategy == QuantizationStrategy.CHANNEL:
# (output_channels, 1)
expected_shape = (weight_shape[0], 1)
@@ -182,7 +173,13 @@
expected_shape = (weight_shape[0], max(num_groups, 1))

# 3. Identify quantization scale and zp dtype
scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
if scale_dtype is None:
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
scale_dtype = module.weight.dtype
elif is_attention_module(module):
scale_dtype = next(module.parameters()).dtype
else:
raise ValueError(f"Cannot infer scale dtype for unsupported module type {type(module)}")

if is_fp4(quantization_args=quantization_args):
scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -195,7 +192,7 @@

# 4. Initializes empty scale, zero point, and g_idx parameters for the module
# do not init scales for quantization_args.dynamic == DynamicType.LOCAL
if not quantization_args.dynamic:
if quantization_args.dynamic is False:
init_scale = Parameter(
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
@@ -220,23 +217,9 @@
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
"""Initlaize k_scale, v_scale for self_attn"""

expected_shape = 1 # per tensor

param = next(module.parameters())
scale_dtype = param.dtype
device = param.device
def get_calibrated_locations(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]:
input = scheme.input_activations and scheme.input_activations.dynamic is not True
weight = scheme.weights is not None
output = scheme.output_activations and scheme.output_activations.dynamic is not True

init_scale = Parameter(
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
)
register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)

init_scale = Parameter(
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
)
register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
return input, weight, output
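Both linear/embedding and attention modules now flow through the shared `_initialize_quantization_parameters` helper, with `get_calibrated_locations` deciding which of input, weight, and output need static parameters. A minimal sketch of the public entry point on a toy `Linear`; the module and scheme below are illustrative only:

```python
# Sketch: per-channel weight quantization parameters on a toy Linear layer.
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization

linear = torch.nn.Linear(64, 64)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True, strategy="channel"),
)

initialize_module_for_quantization(linear, scheme)

print(linear.weight_scale.shape)   # torch.Size([64, 1]) for the channel strategy
print(linear.quantization_status)  # QuantizationStatus.INITIALIZED
```

Attention modules take the dedicated branch instead and receive `q_scale`, `k_scale`, and `v_scale` parameters driven by `scheme.input_activations`.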
30 changes: 30 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -243,6 +243,33 @@ def is_preset_scheme(name: str) -> bool:
),
)

# FP8 attention quantization
FP8_ATTN = dict(
targets=["re:.*self_attn$"],
input_activations=QuantizationArgs(
num_bits=8,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.TOKEN,
symmetric=True,
dynamic=False,
observer=None,
),
)

# FP4 attention quantization
NVFP4_ATTN = dict(
targets=["re:.*self_attn$"],
input_activations=QuantizationArgs(
num_bits=4,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.TENSOR_GROUP,
symmetric=True,
dynamic=DynamicType.LOCAL,
group_size=16,
),
)


PRESET_SCHEMES = {
# Unquantized (no-op)
"UNQUANTIZED": UNQUANTIZED,
@@ -259,4 +286,7 @@ def is_preset_scheme(name: str) -> bool:
"FP8_DYNAMIC": FP8_DYNAMIC,
"NVFP4A16": NVFP4A16,
"NVFP4": NVFP4,
# Attention activation schemes
"FP8_ATTN": FP8_ATTN,
"NVFP4_ATTN": NVFP4_ATTN,
}
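The new presets should be resolvable by name like the existing ones. A short sketch, assuming the `preset_name_to_scheme` helper that lives alongside `is_preset_scheme` and `PRESET_SCHEMES`:

```python
# Sketch: resolving the new attention presets by name.
from compressed_tensors.quantization import is_preset_scheme, preset_name_to_scheme

assert is_preset_scheme("FP8_ATTN") and is_preset_scheme("NVFP4_ATTN")

# The targets argument replaces the preset default of ["re:.*self_attn$"].
scheme = preset_name_to_scheme("FP8_ATTN", ["re:.*self_attn$"])
print(scheme.input_activations.num_bits)  # 8
```

`NVFP4_ATTN` appears to mirror the input-activation settings of the existing `NVFP4` preset (FP4 tensor-group quantization with local dynamic scales and group size 16), applied to attention inputs rather than linear layers.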
9 changes: 5 additions & 4 deletions src/compressed_tensors/transform/transform_args.py
@@ -33,17 +33,18 @@ class TransformLocation(str, Enum):
| `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501
| `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501
| `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501
| `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501
| `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501
| `ATTN_Q` | online | query_states | `this.ATTN_K` | # noqa: E501
| `ATTN_K` | online | key_states | `this.ATTN_Q` | # noqa: E501
| -------------------------------------------------------------------------------------------------------- | # noqa: E501
"""

INPUT = "input"
WEIGHT_INPUT = "weight_input"
WEIGHT_OUTPUT = "weight_output"
OUTPUT = "output"
K_CACHE = "k_cache"
Q_ATTN = "q_attn"
ATTN_Q = "attn_q"
ATTN_K = "attn_k"
# ATTN_V = "attn_v"


class TransformArgs(BaseModel):
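The renamed `ATTN_Q`/`ATTN_K` locations describe transforms applied to query and key states inside the attention module. A minimal sketch of how they might be used, assuming the `TransformScheme` and `TransformConfig` models that already exist in this package (field names beyond what the diff shows are assumptions):

```python
# Sketch: a Hadamard rotation applied to query and key states, expressed with
# the renamed ATTN_Q / ATTN_K locations.
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

qk_rotation = TransformScheme(
    type="hadamard",
    apply=[
        TransformArgs(targets=["re:.*self_attn$"], location="attn_q"),
        TransformArgs(targets=["re:.*self_attn$"], location="attn_k"),
    ],
)

config = TransformConfig(config_groups={"qk_rotation": qk_rotation})
```

Value states are left out for now; the `ATTN_V` member is only sketched as a comment in this diff.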