From fcc89b96c7dbea8b7ded1880ec7a84c65281c6ab Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 27 Mar 2025 21:31:09 +0000 Subject: [PATCH 1/6] update forward --- .../quantization/lifecycle/forward.py | 32 ++++++++++++++++++- .../quantization/lifecycle/initialize.py | 12 +++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 510f7e29..b5bd2397 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -14,7 +14,7 @@ from functools import wraps from math import ceil -from typing import Optional +from typing import Any, Optional import torch from compressed_tensors.quantization.quant_args import ( @@ -42,6 +42,8 @@ "fake_quantize", "wrap_module_forward_quantized", "forward_quantize", + "pre_forward_quantize", + "post_forward_quantize", ] @@ -258,6 +260,34 @@ def _process_quantization( return output +def pre_forward_quantize(module: Module, args: Any): + scheme = module.quantization_scheme + compressed = module.quantization_status == QuantizationStatus.COMPRESSED + + input_ = args[0] + if scheme.input_activations is not None: + # prehook should calibrate activations before forward call + input_ = forward_quantize(module, input_, "input", scheme.input_activations) + + if scheme.weights is not None and not compressed: + setattr(module, "unquantized_weight", module.weight.data.clone()) + module.weight.data = forward_quantize( + module, module.weight, "weight", scheme.weights + ) + + return input_ + + +def post_forward_quantize(module: Module, _args: Any, output: torch.Tensor): + scheme = module.quantization_scheme + compressed = module.quantization_status == QuantizationStatus.COMPRESSED + + if scheme.weights is not None and not compressed: + module.weight.data = module.getattr("unquantized_weight") + + return output + + def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): # expects a module already initialized and injected with the parameters in # initialize_module_for_quantization diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 6886423a..f8775a23 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -18,8 +18,9 @@ from typing import Optional import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, +from compressed_tensors.quantization.lifecycle.forward import ( # wrap_module_forward_quantized, + post_forward_quantize, + pre_forward_quantize, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -119,7 +120,12 @@ def initialize_module_for_quantization( with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) + post_forward_wrapper = lambda module, input, output: post_forward_quantize( + input, output, scheme=scheme + ) + module.register_forward_pre_hook(pre_forward_quantize) + module.register_forward_hook(post_forward_wrapper) + # wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): From 3879da8f878d07c9e92fc34078a0ead6bc6d4576 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 28 Mar 2025 22:07:45 +0000 Subject: [PATCH 2/6] update forward pass --- 
.../quantization/lifecycle/forward.py | 21 +++++++++++++++++-- .../quantization/lifecycle/initialize.py | 9 ++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b5bd2397..5c71405d 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -261,10 +261,15 @@ def _process_quantization( def pre_forward_quantize(module: Module, args: Any): + print("pre forward") + if not getattr(module, "quantization_enabled", True): + return args[0] + scheme = module.quantization_scheme compressed = module.quantization_status == QuantizationStatus.COMPRESSED input_ = args[0] + breakpoint() if scheme.input_activations is not None: # prehook should calibrate activations before forward call input_ = forward_quantize(module, input_, "input", scheme.input_activations) @@ -274,17 +279,29 @@ def pre_forward_quantize(module: Module, args: Any): module.weight.data = forward_quantize( module, module.weight, "weight", scheme.weights ) - return input_ def post_forward_quantize(module: Module, _args: Any, output: torch.Tensor): + print("post forward") + if not getattr(module, "quantization_enabled", True): + return output + scheme = module.quantization_scheme compressed = module.quantization_status == QuantizationStatus.COMPRESSED if scheme.weights is not None and not compressed: - module.weight.data = module.getattr("unquantized_weight") + module.weight.data = getattr(module, "unquantized_weight") + + if scheme.output_activations is not None: + # forward-hook should calibrate/forward_quantize right afer this + if ( + module.quantization_status == QuantizationStatus.CALIBRATION + and not scheme.output_activations.dynamic + ): + return output + output = forward_quantize(module, output, "output", scheme.output_activations) return output diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index f8775a23..73bb52aa 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -18,9 +18,10 @@ from typing import Optional import torch -from compressed_tensors.quantization.lifecycle.forward import ( # wrap_module_forward_quantized, +from compressed_tensors.quantization.lifecycle.forward import ( post_forward_quantize, pre_forward_quantize, + wrap_module_forward_quantized, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -117,14 +118,12 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED + # TODO: shouldn't need this anymore as we're no longer wrapping? 
with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status - post_forward_wrapper = lambda module, input, output: post_forward_quantize( - input, output, scheme=scheme - ) module.register_forward_pre_hook(pre_forward_quantize) - module.register_forward_hook(post_forward_wrapper) + module.register_forward_hook(post_forward_quantize) # wrap_module_forward_quantized(module, scheme) From 75e96f20db3b1df50c9141377eac9263291c0003 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 31 Mar 2025 18:26:20 +0000 Subject: [PATCH 3/6] apply transforms to activations --- .../quantization/lifecycle/forward.py | 47 +++++++++++------ .../quantization/lifecycle/initialize.py | 6 +-- src/compressed_tensors/transforms/apply.py | 50 +++++++++++-------- 3 files changed, 63 insertions(+), 40 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 5c71405d..df459a41 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -29,8 +29,8 @@ compute_dynamic_scales_and_zp, ) from compressed_tensors.transforms.apply import ( - apply_inverse_transforms_to_parameter, - apply_transforms_to_parameter, + apply_inverse_transforms_to_activations_or_parameter, + apply_transforms_to_activations_or_parameter, ) from compressed_tensors.utils import safe_permute from torch.nn import Module @@ -260,30 +260,30 @@ def _process_quantization( return output -def pre_forward_quantize(module: Module, args: Any): - print("pre forward") +def pre_forward_quantize(module: Module, input: Any): if not getattr(module, "quantization_enabled", True): - return args[0] + return input + input = input[0] scheme = module.quantization_scheme compressed = module.quantization_status == QuantizationStatus.COMPRESSED - input_ = args[0] - breakpoint() if scheme.input_activations is not None: # prehook should calibrate activations before forward call - input_ = forward_quantize(module, input_, "input", scheme.input_activations) + breakpoint() + input = forward_quantize(module, input, "input", scheme.input_activations) + breakpoint() if scheme.weights is not None and not compressed: setattr(module, "unquantized_weight", module.weight.data.clone()) module.weight.data = forward_quantize( module, module.weight, "weight", scheme.weights ) - return input_ + breakpoint() + return (input,) -def post_forward_quantize(module: Module, _args: Any, output: torch.Tensor): - print("post forward") +def post_forward_quantize(module: Module, input: Any, output: torch.Tensor): if not getattr(module, "quantization_enabled", True): return output @@ -323,19 +323,33 @@ def wrapped_forward(self, *args, **kwargs): input_ = args[0] compressed = module.quantization_status == QuantizationStatus.COMPRESSED + transform_data = getattr(module, "transform_data", None) if scheme.input_activations is not None: # prehook should calibrate activations before forward call + if transform_data is not None: + input_ = apply_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=input_, + transform_data=transform_data, + update_in_place=False, + ) input_ = forward_quantize(module, input_, "input", scheme.input_activations) + if transform_data is not None: + input_ = apply_inverse_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=input_, + transform_data=transform_data, + update_in_place=False, + ) if 
scheme.weights is not None and not compressed: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() - transform_data = getattr(module, "transform_data", None) if transform_data is not None: - apply_transforms_to_parameter( + apply_transforms_to_activations_or_parameter( module=module, - module_parameter=self.weight, + module_activation_or_parameter=self.weight, transform_data=transform_data, ) @@ -344,9 +358,9 @@ def wrapped_forward(self, *args, **kwargs): ) if transform_data is not None: - apply_inverse_transforms_to_parameter( + apply_inverse_transforms_to_activations_or_parameter( module=module, - module_parameter=self.weight, + module_activation_or_parameter=self.weight, transform_data=transform_data, ) @@ -405,6 +419,7 @@ def forward_quantize( scale = getattr(module, f"{base_name}_scale") zero_point = getattr(module, f"{base_name}_zero_point", None) + breakpoint() return fake_quantize( x=value, scale=scale, diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 73bb52aa..e853a870 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -122,9 +122,9 @@ def initialize_module_for_quantization( with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status - module.register_forward_pre_hook(pre_forward_quantize) - module.register_forward_hook(post_forward_quantize) - # wrap_module_forward_quantized(module, scheme) + # module.register_forward_pre_hook(pre_forward_quantize) + # module.register_forward_hook(post_forward_quantize) + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): diff --git a/src/compressed_tensors/transforms/apply.py b/src/compressed_tensors/transforms/apply.py index b3a5d18b..d39d8bd0 100644 --- a/src/compressed_tensors/transforms/apply.py +++ b/src/compressed_tensors/transforms/apply.py @@ -17,58 +17,66 @@ import torch from compressed_tensors.transforms import Transforms from compressed_tensors.transforms.transform_data import TransformData +from compressed_tensors.utils import update_parameter_data -__all__ = ["apply_transforms_to_parameter", "apply_inverse_transforms_to_parameter"] +__all__ = [ + "apply_transforms_to_activations_or_parameter", + "apply_inverse_transforms_to_activations_or_parameter", +] -def apply_transforms_to_parameter( +def apply_transforms_to_activations_or_parameter( module: torch.nn.Module, - module_parameter: torch.nn.Parameter, + module_activation_or_parameter: torch.Tensor, transform_data: TransformData, -): + update_in_place: Optional[bool] = True, + base_name: Optional[str] = "weight", +) -> Optional[torch.Tensor]: """ Apply all transforms relevant to a parameter using a module's transform data. The parameter data is updated in-place. :param module: torch.nn.Moudle - :param module_parameter: the torch.nn.Parameter to transform + :param module_activation_or_parameter: module Parameter or activations to transform :param transform_data: a module's TransformData - - Only implemented for weight parameters thus far. 
- """ for transform_name, transform_values in transform_data.data.items(): transform = transform_values.get("transform") call_args = transform_values.get("call_args") - transformed_param_data = transform.apply( - input_tensor=module_parameter, **call_args + transformed_output_data = transform.apply( + input_tensor=module_activation_or_parameter, **call_args ) - module_parameter.data.copy_(transformed_param_data) + if not update_in_place: + return transformed_output_data + + update_parameter_data(module, transformed_output_data, base_name) -def apply_inverse_transforms_to_parameter( +def apply_inverse_transforms_to_activations_or_parameter( module: torch.nn.Module, - module_parameter: torch.nn.Parameter, + module_activation_or_parameter: torch.nn.Parameter, transform_data: TransformData, -): + update_in_place: Optional[bool] = True, + base_name: Optional[str] = "weight", +) -> Optional[torch.Tensor]: """ Apply all inverse transform operations relevant to a parameter using a module's TransformData. The parameter data is updated in-place. :param module: torch.nn.Moudle - :param module_parameter: the torch.nn.Parameter to transform + :param module_activation_or_parameter: module Parameter or activations to transform :param transform_data: a module's TransformData - - Only implemented for weight parameters thus far. - """ for transform_name, transform_values in reversed(transform_data.data.items()): transform = transform_values.get("transform") call_args = transform_values.get("call_args") - transformed_param_data = transform.inverse_apply( - input_tensor=module_parameter, **call_args + transformed_output_data = transform.inverse_apply( + input_tensor=module_activation_or_parameter, **call_args ) - module_parameter.data.copy_(transformed_param_data) + if not update_in_place: + return transformed_output_data + + update_parameter_data(module, transformed_output_data, base_name) From 06d4cab40dc2faf51efb49a4739ab580b0179338 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 3 Apr 2025 21:38:46 +0000 Subject: [PATCH 4/6] update --- .../quantization/lifecycle/apply.py | 16 +- .../quantization/lifecycle/forward.py | 145 +++++------------- .../quantization/lifecycle/initialize.py | 23 ++- .../transforms/transform_data.py | 5 +- 4 files changed, 73 insertions(+), 116 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 255754df..4e7509be 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -227,6 +227,7 @@ def process_transforms_config( transform_data = TransformData(data=OrderedDict(data)) submodule.transform_data = transform_data # 10358 for now mib; 1/3 of memory if caching/sharing parameter data + breakpoint() # memory should not go up with inputs, same transform return model @@ -234,7 +235,8 @@ def apply_quantization_config( model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False, - transforms_config=None, + transforms_config: Optional[TransformationConfig] = None, + delay_forward_quantize: Optional[bool] = False, ) -> OrderedDict: """ Initializes the model for quantization in-place based on the given config. 
@@ -320,7 +322,9 @@ def apply_quantization_config( ) # apply current quantization status across all targeted layers - apply_quantization_status(model, config.quantization_status) + apply_quantization_status( + model, config.quantization_status, delay_forward_quantize=delay_forward_quantize + ) return names_to_scheme @@ -360,7 +364,9 @@ def process_kv_cache_config( return config -def apply_quantization_status(model: Module, status: QuantizationStatus): +def apply_quantization_status( + model: Module, status: QuantizationStatus, delay_forward_quantize: bool +): """ Applies in place the quantization lifecycle up to the given status @@ -374,7 +380,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): force_zero_point_init = status != QuantizationStatus.COMPRESSED model.apply( lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init + module, + force_zero_point=force_zero_point_init, + delay_forward_quantize=delay_forward_quantize, ) ) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index df459a41..001b6bf6 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -40,7 +40,6 @@ "quantize", "dequantize", "fake_quantize", - "wrap_module_forward_quantized", "forward_quantize", "pre_forward_quantize", "post_forward_quantize", @@ -264,23 +263,57 @@ def pre_forward_quantize(module: Module, input: Any): if not getattr(module, "quantization_enabled", True): return input - input = input[0] + input_ = input[0] scheme = module.quantization_scheme compressed = module.quantization_status == QuantizationStatus.COMPRESSED + transform_data = getattr(module, "transform_data", None) + + # Input Activations + # TODO: break into their own func/hook if scheme.input_activations is not None: - # prehook should calibrate activations before forward call - breakpoint() - input = forward_quantize(module, input, "input", scheme.input_activations) - breakpoint() + if transform_data is not None: + input_ = apply_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=input_, + transform_data=transform_data, + update_in_place=False, + ) + + input_ = forward_quantize(module, input_, "input", scheme.input_activations) + + if transform_data is not None: + input_ = apply_inverse_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=input_, + transform_data=transform_data, + update_in_place=False, + ) + # Weights + # TODO: break into their own func/hook if scheme.weights is not None and not compressed: setattr(module, "unquantized_weight", module.weight.data.clone()) + + if transform_data is not None: + apply_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=module.weight, + transform_data=transform_data, + ) + module.weight.data = forward_quantize( module, module.weight, "weight", scheme.weights ) - breakpoint() - return (input,) + + if transform_data is not None: + apply_inverse_transforms_to_activations_or_parameter( + module=module, + module_activation_or_parameter=module.weight, + transform_data=transform_data, + ) + + return (input_,) def post_forward_quantize(module: Module, input: Any, output: torch.Tensor): @@ -294,102 +327,9 @@ def post_forward_quantize(module: Module, input: Any, output: torch.Tensor): module.weight.data = getattr(module, "unquantized_weight") if 
scheme.output_activations is not None: - # forward-hook should calibrate/forward_quantize right afer this - if ( - module.quantization_status == QuantizationStatus.CALIBRATION - and not scheme.output_activations.dynamic - ): - return output - output = forward_quantize(module, output, "output", scheme.output_activations) - return output - - -def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): - # expects a module already initialized and injected with the parameters in - # initialize_module_for_quantization - if hasattr(module.forward, "__func__"): - forward_func_orig = module.forward.__func__ - else: - forward_func_orig = module.forward.func - - @wraps(forward_func_orig) # ensures docstring, names, etc are propagated - def wrapped_forward(self, *args, **kwargs): - if not getattr(module, "quantization_enabled", True): - # quantization is disabled on forward passes, return baseline - # forward call - return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) - - input_ = args[0] - - compressed = module.quantization_status == QuantizationStatus.COMPRESSED - transform_data = getattr(module, "transform_data", None) - - if scheme.input_activations is not None: - # prehook should calibrate activations before forward call - if transform_data is not None: - input_ = apply_transforms_to_activations_or_parameter( - module=module, - module_activation_or_parameter=input_, - transform_data=transform_data, - update_in_place=False, - ) - input_ = forward_quantize(module, input_, "input", scheme.input_activations) - if transform_data is not None: - input_ = apply_inverse_transforms_to_activations_or_parameter( - module=module, - module_activation_or_parameter=input_, - transform_data=transform_data, - update_in_place=False, - ) - - if scheme.weights is not None and not compressed: - # calibrate and (fake) quantize weights when applicable - unquantized_weight = self.weight.data.clone() - if transform_data is not None: - apply_transforms_to_activations_or_parameter( - module=module, - module_activation_or_parameter=self.weight, - transform_data=transform_data, - ) - - self.weight.data = forward_quantize( - module, self.weight, "weight", scheme.weights - ) - - if transform_data is not None: - apply_inverse_transforms_to_activations_or_parameter( - module=module, - module_activation_or_parameter=self.weight, - transform_data=transform_data, - ) - # perform wrapped forward call - output = forward_func_orig.__get__(module, module.__class__)( - input_, *args[1:], **kwargs - ) - - # restore back to unquantized_value - if scheme.weights is not None and not compressed: - self.weight.data = unquantized_weight - - if scheme.output_activations is not None: - # forward-hook should calibrate/forward_quantize - if ( - module.quantization_status == QuantizationStatus.CALIBRATION - and not scheme.output_activations.dynamic - ): - return output - - output = forward_quantize( - module, output, "output", scheme.output_activations - ) - return output - - # bind wrapped forward to module class so reference to `self` is correct - bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) - # set forward to wrapped forward - setattr(module, "forward", bound_wrapped_forward) + return output def forward_quantize( @@ -419,7 +359,6 @@ def forward_quantize( scale = getattr(module, f"{base_name}_scale") zero_point = getattr(module, f"{base_name}_zero_point", None) - breakpoint() return fake_quantize( x=value, scale=scale, diff --git 
a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index e853a870..adce599c 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -21,7 +21,6 @@ from compressed_tensors.quantization.lifecycle.forward import ( post_forward_quantize, pre_forward_quantize, - wrap_module_forward_quantized, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -43,6 +42,7 @@ "initialize_module_for_quantization", "is_attention_module", "KVCacheScaleType", + "register_quantization_hooks", ] @@ -54,10 +54,22 @@ class KVCacheScaleType(Enum): VALUE = "v_scale" +def register_quantization_hooks(module: Module): + # TODO: some of these checks may be redundant + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + return + + if not is_attention_module(module): + module.register_forward_pre_hook(pre_forward_quantize) + module.register_forward_hook(post_forward_quantize) + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, + delay_forward_quantize: bool = False, ): """ attaches appropriate scales, zero points, and observers to a layer @@ -119,12 +131,9 @@ def initialize_module_for_quantization( module.quantization_status = QuantizationStatus.INITIALIZED # TODO: shouldn't need this anymore as we're no longer wrapping? - with disable_hf_hook(module): - # wrap forward call of module to perform - # quantized actions based on calltime status - # module.register_forward_pre_hook(pre_forward_quantize) - # module.register_forward_hook(post_forward_quantize) - wrap_module_forward_quantized(module, scheme) + if not delay_forward_quantize: + with disable_hf_hook(module): + register_quantization_hooks(module) def is_attention_module(module: Module): diff --git a/src/compressed_tensors/transforms/transform_data.py b/src/compressed_tensors/transforms/transform_data.py index 3b9b6e0e..35cf5d16 100644 --- a/src/compressed_tensors/transforms/transform_data.py +++ b/src/compressed_tensors/transforms/transform_data.py @@ -31,5 +31,6 @@ class TransformData: transform_data = TransformData(data=data) """ - data: Dict - idx: int = 0 + weight_transforms: Dict + input_transforms: Dict + output_transforms: Dict From dc14d212a0c457389bc847e308d67985cad08f27 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 7 Apr 2025 19:49:37 +0000 Subject: [PATCH 5/6] update --- .../quantization/lifecycle/apply.py | 28 ++++++++++++------- src/compressed_tensors/transforms/base.py | 2 +- src/compressed_tensors/transforms/hadamard.py | 2 +- .../transforms/random_hadamard.py | 21 +++++++------- .../transforms/transform_data.py | 5 ++-- 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 4e7509be..00cf65d0 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -150,6 +150,8 @@ def process_transforms_config( # Each group/scheme targets one type of transform transform_type = group.transform_type transform_creation_args = group.transform_creation_args + shared = group.shared + transform = None # Need a better name - too many groups for transform_arg in group.groups: @@ -201,17 +203,19 @@ def process_transforms_config( **transform_creation_args, ) 
else: - transform = Transforms.load_from_registry( - transform_type, - dtype=dtype, - transform_name=transform_name, - permutation_name=permutation_name, - device=next(submodule.parameters()).device, - **transform_creation_args, - ) - + # should mean we have identical permuation matrices for all shared submodules + if transform is None: + transform = Transforms.load_from_registry( + transform_type, + dtype=dtype, + device=next(submodule.parameters()).device, + **transform_creation_args, + ) + + transform.transform_name = transform_name + transform.permutation_name = permutation_name transform.register_to_module(module=submodule) - + # add relevant transform data to the submodule as well data = { transform_name: { @@ -226,6 +230,10 @@ def process_transforms_config( else: transform_data = TransformData(data=OrderedDict(data)) submodule.transform_data = transform_data + + if not shared: + transform = None + # 10358 for now mib; 1/3 of memory if caching/sharing parameter data breakpoint() # memory should not go up with inputs, same transform return model diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index 715c0625..0c7a9d77 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -33,7 +33,7 @@ class Transforms(RegistryMixin): def __init__( self, transform: torch.Tensor, - transform_name: str, + transform_name: Optional[str] = None, permutation: Optional[torch.Tensor] = None, permutation_name: Optional[str] = None, ): diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index 7f280e55..0b1d19a6 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -28,7 +28,7 @@ class Hadamard(Transforms): def __init__( self, size: int, - transform_name: str, + transform_name: Optional[str] = None, empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index 20e52fda..ffaf2998 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -31,8 +31,8 @@ class RandomHadamard(Transforms): def __init__( self, size: int, - transform_name: str, - permutation_name: str, + transform_name: Optional[str] = None, + permutation_name: Optional[str] = None, empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, @@ -72,7 +72,7 @@ def __init__( transform = torch.empty((size, size)).to(dtype) permutation = torch.empty((size)).to(dtype).to(device) else: - transform = self.fetch().to(dtype).to(device) + transform = self.fetch(dtype, device) permutation = ( (torch.randint(low=0, high=2, size=(self.size,)) * 2 - 1) .to(dtype) @@ -86,13 +86,11 @@ def __init__( permutation_name=permutation_name, ) - if not self.matrix_registry.contains(size): - self.matrix_registry.set_matrix(self.size, self.transform) - - def fetch(self): + def fetch(self, dtype, device): transform = self.matrix_registry.get_matrix(self.size) if transform is None: - transform = random_hadamard_matrix(size=self.size) + transform = random_hadamard_matrix(size=self.size).to(dtype).to(device) + self.matrix_registry.set_matrix(self.size, transform) return transform def apply( @@ -101,8 +99,10 @@ def apply( transpose: bool = 
False, first: bool = True, ) -> torch.Tensor: + + # Too slow? return apply_matrix_transform( - transform=(self.transform * self.permutation) / self.normalized_size, + transform=(self.transform.to(input_tensor.device) * self.permutation.to(input_tensor.device)) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, @@ -124,10 +124,9 @@ def inverse_apply( :param first: if the transform matrix will be the first or second matrix to be multiplied """ - transpose = not transpose return apply_matrix_transform( - transform=(self.transform * self.permutation) / self.normalized_size, + transform=(self.transform.to(input_tensor.device) * self.permutation.to(input_tensor.device)) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/transform_data.py b/src/compressed_tensors/transforms/transform_data.py index 35cf5d16..3b9b6e0e 100644 --- a/src/compressed_tensors/transforms/transform_data.py +++ b/src/compressed_tensors/transforms/transform_data.py @@ -31,6 +31,5 @@ class TransformData: transform_data = TransformData(data=data) """ - weight_transforms: Dict - input_transforms: Dict - output_transforms: Dict + data: Dict + idx: int = 0 From db6dd28885c9cb7784983213451ed88e49dd658f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 5 May 2025 21:08:19 +0000 Subject: [PATCH 6/6] update ct --- .../quantization/lifecycle/apply.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 00cf65d0..2f341ea7 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -108,6 +108,9 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str): :param model: model to load pretrained quantization parameters to :param model_name_or_path: Hugging Face stub or local folder containing a quantized model, which is used to load quantization parameters + + Note: this currently does not process/support shared transforms i.e transforms with + identical permutation """ model_path = get_safetensors_folder(model_name_or_path) state_dict = get_quantization_state_dict(model_path) @@ -150,8 +153,6 @@ def process_transforms_config( # Each group/scheme targets one type of transform transform_type = group.transform_type transform_creation_args = group.transform_creation_args - shared = group.shared - transform = None # Need a better name - too many groups for transform_arg in group.groups: @@ -194,6 +195,7 @@ def process_transforms_config( QuantizationStatus.COMPRESSED, QuantizationStatus.FROZEN, ]: + # empty tensor to load the parameter from disk transform = Transforms.load_from_registry( transform_type, dtype=dtype, @@ -204,13 +206,12 @@ def process_transforms_config( ) else: # should mean we have identical permuation matrices for all shared submodules - if transform is None: - transform = Transforms.load_from_registry( - transform_type, - dtype=dtype, - device=next(submodule.parameters()).device, - **transform_creation_args, - ) + transform = Transforms.load_from_registry( + transform_type, + dtype=dtype, + device=next(submodule.parameters()).device, + **transform_creation_args, + ) transform.transform_name = transform_name transform.permutation_name = permutation_name @@ -231,11 +232,6 @@ def process_transforms_config( transform_data = TransformData(data=OrderedDict(data)) 
submodule.transform_data = transform_data - if not shared: - transform = None - - # 10358 for now mib; 1/3 of memory if caching/sharing parameter data - breakpoint() # memory should not go up with inputs, same transform return model
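
Editor's note: patches 1, 2, and 4 above move quantization out of a monkey-patched forward (`wrap_module_forward_quantized`) and into `register_forward_pre_hook` / `register_forward_hook`. The snippet below is a minimal, self-contained sketch of that hook pattern on a plain `nn.Linear`; `toy_fake_quantize`, the hook names, and the `_unquantized_weight` attribute are illustrative stand-ins, not the compressed-tensors API.

import torch
from torch import nn

def toy_fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # symmetric per-tensor fake quantization: quantize, then immediately dequantize
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

def pre_hook(module: nn.Module, args):
    # quantize the incoming activation and temporarily swap in a fake-quantized weight
    (x,) = args
    module._unquantized_weight = module.weight.data.clone()
    module.weight.data = toy_fake_quantize(module.weight.data)
    return (toy_fake_quantize(x),)  # returned tuple replaces the forward args

def post_hook(module: nn.Module, args, output):
    # restore the original weight after the (quantized) forward call
    module.weight.data = module._unquantized_weight
    del module._unquantized_weight
    return output

layer = nn.Linear(16, 16)
layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(post_hook)
out = layer(torch.randn(2, 16))  # forward runs with fake-quantized weight and input
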
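Editor's note: patches 3 and 4 sandwich `forward_quantize` between `apply_transforms_to_activations_or_parameter` and `apply_inverse_transforms_to_activations_or_parameter`. The sketch below shows the same rotate, fake-quantize, rotate-back pattern using a generic orthogonal matrix in place of the library's (randomized) Hadamard transforms; every name here is hypothetical and the error comparison at the end is only a sanity check, not a claim about the library's numerics.

import torch

def random_orthogonal(size: int, seed: int = 0) -> torch.Tensor:
    # any orthogonal matrix works for the sketch; the patches use Hadamard-based transforms
    g = torch.Generator().manual_seed(seed)
    q, _ = torch.linalg.qr(torch.randn(size, size, generator=g))
    return q

def fake_quantize(x: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

weight = torch.randn(64, 64)
r = random_orthogonal(64)

rotated = weight @ r            # "apply" the transform
quantized = fake_quantize(rotated)
restored = quantized @ r.T      # "inverse_apply" (orthogonal, so the inverse is the transpose)

plain_err = (fake_quantize(weight) - weight).pow(2).mean().item()
rotated_err = (restored - weight).pow(2).mean().item()
print(f"plain quant MSE: {plain_err:.6f}  rotated quant MSE: {rotated_err:.6f}")
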
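Editor's note: patches 4 through 6 also touch on sharing a single transform matrix across modules (the `fetch` / `matrix_registry` changes in `random_hadamard.py` and the memory comment in `process_transforms_config`). A toy cache that hands every module of a given size the same tensor object is sketched below; `MatrixCache` is a hypothetical stand-in for the library's matrix registry, not its actual interface.

import torch

class MatrixCache:
    # hypothetical registry: at most one matrix per size, shared by all callers
    def __init__(self):
        self._matrices = {}

    def get(self, size: int, dtype=torch.bfloat16, device="cpu") -> torch.Tensor:
        if size not in self._matrices:
            q, _ = torch.linalg.qr(torch.randn(size, size))
            self._matrices[size] = q.to(dtype=dtype, device=device)
        return self._matrices[size]

cache = MatrixCache()
a = cache.get(256)
b = cache.get(256)
assert a is b  # same tensor object, so additional modules of this width add no extra memory
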