[Transforms][WIP] Update wrapped_forward to use hooks; apply transforms to activations #286

Draft · wants to merge 6 commits into base: matrix_registry

32 changes: 22 additions & 10 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -108,6 +108,9 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
:param model: model to load pretrained quantization parameters to
:param model_name_or_path: Hugging Face stub or local folder containing a quantized
model, which is used to load quantization parameters

Note: this currently does not process/support shared transforms, i.e. transforms
with identical permutation matrices
"""
model_path = get_safetensors_folder(model_name_or_path)
state_dict = get_quantization_state_dict(model_path)
@@ -192,6 +195,7 @@ def process_transforms_config(
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
]:
# empty tensor to load the parameter from disk
transform = Transforms.load_from_registry(
transform_type,
dtype=dtype,
@@ -201,17 +205,18 @@
**transform_creation_args,
)
else:
# should mean we have identical permutation matrices for all shared submodules
transform = Transforms.load_from_registry(
transform_type,
dtype=dtype,
transform_name=transform_name,
permutation_name=permutation_name,
device=next(submodule.parameters()).device,
**transform_creation_args,
)

transform.transform_name = transform_name
transform.permutation_name = permutation_name
transform.register_to_module(module=submodule)

# add relevant transform data to the submodule as well
data = {
transform_name: {
@@ -226,15 +231,16 @@
else:
transform_data = TransformData(data=OrderedDict(data))
submodule.transform_data = transform_data
# ~10358 MiB for now; roughly 1/3 of that memory if caching/sharing parameter data

return model


def apply_quantization_config(
model: Module,
config: Union[QuantizationConfig, None],
run_compressed: bool = False,
transforms_config=None,
transforms_config: Optional[TransformationConfig] = None,
delay_forward_quantize: Optional[bool] = False,
) -> OrderedDict:
"""
Initializes the model for quantization in-place based on the given config.
@@ -320,7 +326,9 @@ def apply_quantization_config(
)

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
apply_quantization_status(
model, config.quantization_status, delay_forward_quantize=delay_forward_quantize
)
return names_to_scheme


@@ -360,7 +368,9 @@ def process_kv_cache_config(
return config


def apply_quantization_status(model: Module, status: QuantizationStatus):
def apply_quantization_status(
model: Module, status: QuantizationStatus, delay_forward_quantize: bool
):
"""
Applies in place the quantization lifecycle up to the given status

@@ -374,7 +384,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
force_zero_point_init = status != QuantizationStatus.COMPRESSED
model.apply(
lambda module: initialize_module_for_quantization(
module, force_zero_point=force_zero_point_init
module,
force_zero_point=force_zero_point_init,
delay_forward_quantize=delay_forward_quantize,
)
)

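For context on how these apply.py changes would be driven, here is a minimal, hypothetical sketch of calling apply_quantization_config with the new transforms_config and delay_forward_quantize arguments. The QuantizationConfig/QuantizationScheme/QuantizationArgs construction below assumes the field names currently exposed by compressed-tensors and is not taken from this diff:

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# toy model; any nn.Module whose submodules match the scheme targets works
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

# assumed: a simple 8-bit symmetric weight/input scheme targeting Linear layers
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
            input_activations=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)

names_to_scheme = apply_quantization_config(
    model,
    config,
    transforms_config=None,        # or a TransformationConfig describing the transforms to attach
    delay_forward_quantize=True,   # new flag in this PR: skip hook registration during initialization
)
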
131 changes: 66 additions & 65 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -14,7 +14,7 @@

from functools import wraps
from math import ceil
from typing import Optional
from typing import Any, Optional

import torch
from compressed_tensors.quantization.quant_args import (
@@ -29,8 +29,8 @@
compute_dynamic_scales_and_zp,
)
from compressed_tensors.transforms.apply import (
apply_inverse_transforms_to_parameter,
apply_transforms_to_parameter,
apply_inverse_transforms_to_activations_or_parameter,
apply_transforms_to_activations_or_parameter,
)
from compressed_tensors.utils import safe_permute
from torch.nn import Module
@@ -40,8 +40,9 @@
"quantize",
"dequantize",
"fake_quantize",
"wrap_module_forward_quantized",
"forward_quantize",
"pre_forward_quantize",
"post_forward_quantize",
]


@@ -258,77 +259,77 @@ def _process_quantization(
return output


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
# expects a module already initialized and injected with the parameters in
# initialize_module_for_quantization
if hasattr(module.forward, "__func__"):
forward_func_orig = module.forward.__func__
else:
forward_func_orig = module.forward.func

@wraps(forward_func_orig) # ensures docstring, names, etc are propagated
def wrapped_forward(self, *args, **kwargs):
if not getattr(module, "quantization_enabled", True):
# quantization is disabled on forward passes, return baseline
# forward call
return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

input_ = args[0]

compressed = module.quantization_status == QuantizationStatus.COMPRESSED

if scheme.input_activations is not None:
# prehook should calibrate activations before forward call
input_ = forward_quantize(module, input_, "input", scheme.input_activations)

if scheme.weights is not None and not compressed:
# calibrate and (fake) quantize weights when applicable
unquantized_weight = self.weight.data.clone()
transform_data = getattr(module, "transform_data", None)
if transform_data is not None:
apply_transforms_to_parameter(
module=module,
module_parameter=self.weight,
transform_data=transform_data,
)
def pre_forward_quantize(module: Module, input: Any):
if not getattr(module, "quantization_enabled", True):
return input

input_ = input[0]
scheme = module.quantization_scheme
compressed = module.quantization_status == QuantizationStatus.COMPRESSED

self.weight.data = forward_quantize(
module, self.weight, "weight", scheme.weights
transform_data = getattr(module, "transform_data", None)

# Input Activations
# TODO: break into their own func/hook
if scheme.input_activations is not None:
if transform_data is not None:
input_ = apply_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=input_,
transform_data=transform_data,
update_in_place=False,
)

if transform_data is not None:
apply_inverse_transforms_to_parameter(
module=module,
module_parameter=self.weight,
transform_data=transform_data,
)
input_ = forward_quantize(module, input_, "input", scheme.input_activations)

# perform wrapped forward call
output = forward_func_orig.__get__(module, module.__class__)(
input_, *args[1:], **kwargs
)
if transform_data is not None:
input_ = apply_inverse_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=input_,
transform_data=transform_data,
update_in_place=False,
)

# restore back to unquantized_value
if scheme.weights is not None and not compressed:
self.weight.data = unquantized_weight
# Weights
# TODO: break into their own func/hook
if scheme.weights is not None and not compressed:
setattr(module, "unquantized_weight", module.weight.data.clone())

if scheme.output_activations is not None:
# forward-hook should calibrate/forward_quantize
if (
module.quantization_status == QuantizationStatus.CALIBRATION
and not scheme.output_activations.dynamic
):
return output
if transform_data is not None:
apply_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=module.weight,
transform_data=transform_data,
)

output = forward_quantize(
module, output, "output", scheme.output_activations
module.weight.data = forward_quantize(
module, module.weight, "weight", scheme.weights
)

if transform_data is not None:
apply_inverse_transforms_to_activations_or_parameter(
module=module,
module_activation_or_parameter=module.weight,
transform_data=transform_data,
)

return (input_,)


def post_forward_quantize(module: Module, input: Any, output: torch.Tensor):
if not getattr(module, "quantization_enabled", True):
return output

# bind wrapped forward to module class so reference to `self` is correct
bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
# set forward to wrapped forward
setattr(module, "forward", bound_wrapped_forward)
scheme = module.quantization_scheme
compressed = module.quantization_status == QuantizationStatus.COMPRESSED

if scheme.weights is not None and not compressed:
module.weight.data = getattr(module, "unquantized_weight")

if scheme.output_activations is not None:
output = forward_quantize(module, output, "output", scheme.output_activations)

return output


def forward_quantize(
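Since this file replaces the wrapped forward with standard PyTorch hooks, here is a small self-contained sketch (plain PyTorch, no compressed-tensors specifics) of the mechanism pre_forward_quantize and post_forward_quantize rely on: a forward pre-hook that returns a tuple replaces the module's positional inputs, and a forward hook that returns a value replaces its output:

import torch

linear = torch.nn.Linear(4, 4)

def pre_hook(module, inputs):
    # returning a tuple overrides the positional args passed to forward()
    (x,) = inputs
    return (x * 2.0,)

def post_hook(module, inputs, output):
    # returning a value overrides the module's output
    return output + 1.0

linear.register_forward_pre_hook(pre_hook)
linear.register_forward_hook(post_hook)

y = linear(torch.randn(1, 4))  # runs pre_hook -> forward -> post_hook
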
24 changes: 19 additions & 5 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -19,7 +19,8 @@

import torch
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
post_forward_quantize,
pre_forward_quantize,
)
from compressed_tensors.quantization.quant_args import (
ActivationOrdering,
@@ -41,6 +42,7 @@
"initialize_module_for_quantization",
"is_attention_module",
"KVCacheScaleType",
"register_quantization_hooks",
]


@@ -52,10 +54,22 @@ class KVCacheScaleType(Enum):
VALUE = "v_scale"


def register_quantization_hooks(module: Module):
# TODO: some of these checks may be redundant
quantization_scheme = getattr(module, "quantization_scheme", None)
if not quantization_scheme:
return

if not is_attention_module(module):
module.register_forward_pre_hook(pre_forward_quantize)
module.register_forward_hook(post_forward_quantize)


def initialize_module_for_quantization(
module: Module,
scheme: Optional[QuantizationScheme] = None,
force_zero_point: bool = True,
delay_forward_quantize: bool = False,
):
"""
attaches appropriate scales, zero points, and observers to a layer
@@ -116,10 +130,10 @@ def initialize_module_for_quantization(
module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED

with disable_hf_hook(module):
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)
# TODO: shouldn't need this anymore as we're no longer wrapping?
if not delay_forward_quantize:
with disable_hf_hook(module):
register_quantization_hooks(module)


def is_attention_module(module: Module):
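Finally, a minimal sketch of the deferred path the new delay_forward_quantize flag enables, assuming register_quantization_hooks stays exported as shown in this diff and that QuantizationScheme/QuantizationArgs keep their current field names (both assumptions, not part of the diff):

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
    register_quantization_hooks,
)

layer = torch.nn.Linear(64, 64)

# assumed: an 8-bit symmetric weight-only scheme targeting Linear modules
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

# initialize scales/zero points but defer hook registration (new in this PR)
initialize_module_for_quantization(layer, scheme=scheme, delay_forward_quantize=True)

# later, e.g. after transform_data has been attached, register the pre/post forward hooks
register_quantization_hooks(layer)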