diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py
index 0e073262..e139f9aa 100644
--- a/src/compressed_tensors/base.py
+++ b/src/compressed_tensors/base.py
@@ -18,3 +18,4 @@
 KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+TRANSFORMS_CONFIG = "transforms_config"
diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 624201c2..8e2dd2b9 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORMS_CONFIG,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
@@ -45,6 +46,7 @@
     is_module_quantized,
     iter_named_leaf_modules,
 )
+from compressed_tensors.transforms.transform_config import TransformationConfig
 from compressed_tensors.utils import (
     get_safetensors_folder,
     merge_names,
@@ -133,6 +135,8 @@ def from_compression_config(
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
 
+        transforms_config = cls.parse_transforms_config(compression_config)
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -144,8 +148,13 @@
         if quantization_config is not None:
             quantization_config = QuantizationConfig.model_validate(quantization_config)
 
+        if transforms_config is not None:
+            transforms_config = TransformationConfig.model_validate(transforms_config)
+
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transforms_config=transforms_config,
         )
 
     @classmethod
@@ -170,6 +179,10 @@ def from_pretrained_model(
             model, format=quantization_format
         )
 
+        # TODO: update to fetch from the pretrained model
+        # using the attached config for now
+        transforms_config = getattr(model, "transforms_config", None)
+
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
@@ -179,9 +192,25 @@
             return None
 
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transforms_config=transforms_config,
         )
 
+    @staticmethod
+    def parse_transforms_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+
+        if compression_config is None:
+            return None
+
+        if is_compressed_tensors_config(compression_config):
+            t_config = compression_config.transforms_config
+            return t_config.model_dump() if t_config is not None else None
+
+        return compression_config.get(TRANSFORMS_CONFIG, None)
+
     @staticmethod
     def parse_sparsity_config(
         compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
@@ -243,9 +272,11 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transforms_config: Optional[TransformationConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transforms_config = transforms_config
         self.sparsity_compressor = None
         self.quantization_compressor = None
@@ -434,7 +465,9 @@ def decompress(self, model_path: str, model: Module):
             self.quantization_config, QuantizationStatus.FROZEN
         ):
             names_to_scheme = apply_quantization_config(
-                model, self.quantization_config
+                model,
+                self.quantization_config,
+                transforms_config=self.transforms_config,
             )
             load_pretrained_quantization(model, model_path)
@@ -497,6 +530,12 @@ def update_config(self, save_directory: str):
                 SPARSITY_CONFIG_NAME
             ] = sparsity_config_data
 
+        if self.transforms_config is not None:
+            transforms_config_data = self.transforms_config.to_dict()
+            config_data[QUANTIZATION_CONFIG_NAME][
+                TRANSFORMS_CONFIG
+            ] = transforms_config_data
+
         with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)
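For context on what `update_config` above ends up writing: the transform settings are nested under the existing quantization entry of the model's `config.json`. The exact layout depends on `TransformationConfig.to_dict()`, which this diff does not show, so the fragment below is only a sketch assembled from the `TRANSFORMS_CONFIG` constant and the field names used in the integration tests at the end of this diff (`transform_groups`, `transform_type`, `groups`, `transform_creation_args`); the rendered values, such as how `ModuleTarget.WEIGHT` serializes, are assumptions.

```python
# Hypothetical config.json fragment after ModelCompressor.update_config() runs;
# everything below "transforms_config" is an assumed rendering, not captured output.
config_json_fragment = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        # ... existing quantization / sparsity entries ...
        "transforms_config": {
            "transform_groups": {
                "transform_0": {
                    "transform_type": "hadamard",
                    "transform_creation_args": {"size": 64},
                    "groups": [
                        {"targets": ["Linear"], "module_targets": ["weight"]},
                    ],
                },
            },
        },
    },
}
```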
diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index bba47d81..6b3e790f 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -126,6 +126,7 @@ def decompress_weight(
         :param quantization_args: quantization parameters for the weight
         :return: tensor of the decompressed weight
         """
+
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         zero_point = compressed_data.get("weight_zero_point", None)
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index ca8fa68a..c2c5a704 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -41,6 +41,9 @@
     iter_named_leaf_modules,
     iter_named_quantizable_modules,
 )
+from compressed_tensors.transforms import Transforms
+from compressed_tensors.transforms.transform_config import TransformationConfig
+from compressed_tensors.transforms.transform_data import TransformData
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
@@ -49,20 +52,50 @@
 __all__ = [
     "load_pretrained_quantization",
+    "load_transforms",
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
     "expand_target_names",
     "is_target",
+    "process_transforms_config",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
-from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+from compressed_tensors.utils.safetensors_load import (
+    get_quantization_state_dict,
+    get_weight_mappings,
+)
+from safetensors import safe_open
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
+def load_transforms(model: Module, model_name_or_path: str):
+    model_path = get_safetensors_folder(model_name_or_path)
+    weight_mappings = get_weight_mappings(model_path)
+
+    state_dict = {}
+    for weight_name, safe_path in weight_mappings.items():
+        if "transform" in weight_name:
+            with safe_open(safe_path, framework="pt", device="cpu") as f:
+                state_dict[weight_name] = f.get_tensor(weight_name)
+
+    for name, submodule in iter_named_leaf_modules(model):
+        transform_data = getattr(submodule, "transform_data", None)
+
+        if transform_data:
+            for transform_name, transform_values in transform_data.data.items():
+                full_name = f"{name}.{transform_name}"
+                transform_tensor = state_dict.get(full_name, None)
+                transform = transform_values.get("transform")
+                transform.register_to_module(name=transform_name, module=submodule)
+                transform.update_transform(
+                    module=submodule, data=transform_tensor, name=transform_name
+                )
+
+
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -104,8 +137,94 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )
 
 
+def process_transforms_config(
+    transforms_config: TransformationConfig,
+    model: torch.nn.Module,
+    quantization_status: Optional[QuantizationStatus] = QuantizationStatus.INITIALIZED,
+):
+    for _, group in transforms_config.transform_groups.items():
+        # Each group/scheme targets one type of transform
+        transform_type = group.transform_type
+        transform_creation_args = group.transform_creation_args
+
+        # Need a better name - too many groups
+        for transform_arg in group.groups:
+            module_targets = transform_arg.module_targets
+
+            for name, submodule in model.named_modules():
+                if len(transform_arg.ignore) > 0 and find_name_or_class_matches(
+                    name, submodule, transform_arg.ignore
+                ):
+                    # layer matches the ignore list, skip it
+                    continue
+
+                targets = find_name_or_class_matches(
+                    name, submodule, transform_arg.targets
+                )
+
+                if targets:
+                    # Every layer which matches gets its own transform;
+                    # the same transform type and creation args are used, however
+
+                    # attach the transform to the submodule; because a module can
+                    # carry more than one transform, each one needs its own key to
+                    # fetch it (alternatively they could live only in the dictionary,
+                    # with cpu-offloading handled separately)
+
+                    if hasattr(submodule, "transform_data"):
+                        idx = submodule.transform_data.idx + 1
+                    else:
+                        idx = 0
+                    # only support weight parameters for now, assume one value in
+                    # module targets
+                    transform_name = f"{module_targets[0]}_transform_{idx}"
+
+                    # create an empty tensor OR create a new transform
+                    dtype = getattr(submodule, module_targets[0]).dtype
+                    if quantization_status in [
+                        QuantizationStatus.COMPRESSED,
+                        QuantizationStatus.FROZEN,
+                    ]:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            empty=True,
+                            **transform_creation_args,
+                        )
+                    else:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            **transform_creation_args,
+                        )
+                    transform.register_to_module(
+                        name=transform_name, module=submodule
+                    )
+
+                    # add relevant transform data to the submodule as well
+                    data = {
+                        transform_name: {
+                            "transform": transform,
+                            "call_args": transform_arg.call_args,
+                        }
+                    }
+
+                    if hasattr(submodule, "transform_data"):
+                        submodule.transform_data.data.update(data)
+                        submodule.transform_data.idx = idx
+                    else:
+                        transform_data = TransformData(data=OrderedDict(data))
+                        submodule.transform_data = transform_data
+    return model
+
+
 def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+    model: Module,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+    transforms_config: Optional[TransformationConfig] = None,
 ) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config.
@@ -184,6 +303,12 @@ def apply_quantization_config(
             f"{set(config.ignore) - set(ignored_submodules)}"
         )
 
+    if transforms_config:
+        model.transforms_config = transforms_config
+        model = process_transforms_config(
+            transforms_config, model, config.quantization_status
+        )
+
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
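As a usage sketch (mirroring the fixtures in the new integration test below), a transform recipe is a `TransformationConfig` built from `TransformationScheme` and `TransformationArgs`; it can be run through `process_transforms_config` directly, or handed to `apply_quantization_config(..., transforms_config=...)` so it is processed alongside quantization. The toy `nn.Sequential` model here is an assumption for illustration; it only needs Linear weights whose size matches the Hadamard size.

```python
import torch.nn as nn
from compressed_tensors.quantization import process_transforms_config
from compressed_tensors.transforms.transform_args import ModuleTarget, TransformationArgs
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_scheme import TransformationScheme

# one scheme: a 64x64 Hadamard transform applied to the weights of every Linear layer
linear_args = TransformationArgs(targets=["Linear"], module_targets=[ModuleTarget.WEIGHT])
scheme = TransformationScheme(
    transform_type="hadamard",
    groups=[linear_args],
    transform_creation_args={"size": 64},
)
config = TransformationConfig(transform_groups={"transform_0": scheme})

# toy model whose Linear weights are 64x64, matching the transform size
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# attaches a transform parameter plus a TransformData record to every matched Linear
model = process_transforms_config(transforms_config=config, model=model)
```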
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index f4f93f27..510f7e29 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -28,6 +28,10 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
+from compressed_tensors.transforms.apply import (
+    apply_inverse_transforms_to_parameter,
+    apply_transforms_to_parameter,
+)
 from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
@@ -280,10 +284,25 @@ def wrapped_forward(self, *args, **kwargs):
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
+            transform_data = getattr(module, "transform_data", None)
+            if transform_data is not None:
+                apply_transforms_to_parameter(
+                    module=module,
+                    module_parameter=self.weight,
+                    transform_data=transform_data,
+                )
+
             self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
+            if transform_data is not None:
+                apply_inverse_transforms_to_parameter(
+                    module=module,
+                    module_parameter=self.weight,
+                    transform_data=transform_data,
+                )
+
         # perform wrapped forward call
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
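The ordering in `wrapped_forward` above, transform the weight, fake-quantize it in the transformed space, then invert the transform, is what keeps `self.weight` in its original basis after calibration. A minimal, self-contained sketch of that round trip in plain torch (not the library API; the crude symmetric quantizer only stands in for `forward_quantize`):

```python
import torch


def fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # crude symmetric round-to-grid quantizer, standing in for forward_quantize
    scale = x.abs().max() / (2 ** (num_bits - 1) - 1)
    return torch.round(x / scale) * scale


# 4x4 normalized Hadamard matrix: orthogonal, so its transpose is its inverse
hadamard = torch.tensor(
    [[1.0, 1.0, 1.0, 1.0],
     [1.0, -1.0, 1.0, -1.0],
     [1.0, 1.0, -1.0, -1.0],
     [1.0, -1.0, -1.0, 1.0]]
) / 2.0

weight = torch.randn(4, 4)

rotated = hadamard @ weight          # analogous to apply_transforms_to_parameter
quantized = fake_quantize(rotated)   # fake quantization happens in the rotated basis
restored = hadamard.T @ quantized    # analogous to apply_inverse_transforms_to_parameter

# back in the original basis, equal to the original weight up to quantization error
assert torch.allclose(restored, weight, atol=0.1)
```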
diff --git a/src/compressed_tensors/transforms/apply.py b/src/compressed_tensors/transforms/apply.py
new file mode 100644
index 00000000..b3a5d18b
--- /dev/null
+++ b/src/compressed_tensors/transforms/apply.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transforms import Transforms
+from compressed_tensors.transforms.transform_data import TransformData
+
+
+__all__ = ["apply_transforms_to_parameter", "apply_inverse_transforms_to_parameter"]
+
+
+def apply_transforms_to_parameter(
+    module: torch.nn.Module,
+    module_parameter: torch.nn.Parameter,
+    transform_data: TransformData,
+):
+    """
+    Apply all transforms relevant to a parameter using a module's
+    transform data. The parameter data is updated in-place.
+
+    :param module: torch.nn.Module
+    :param module_parameter: the torch.nn.Parameter to transform
+    :param transform_data: a module's TransformData
+
+    Only implemented for weight parameters thus far.
+    """
+
+    for transform_name, transform_values in transform_data.data.items():
+        transform = transform_values.get("transform")
+        call_args = transform_values.get("call_args")
+        transformed_param_data = transform.apply(
+            input_tensor=module_parameter, **call_args
+        )
+        module_parameter.data.copy_(transformed_param_data)
+
+
+def apply_inverse_transforms_to_parameter(
+    module: torch.nn.Module,
+    module_parameter: torch.nn.Parameter,
+    transform_data: TransformData,
+):
+    """
+    Apply all inverse transform operations relevant to a parameter using a module's
+    TransformData. The parameter data is updated in-place.
+
+    :param module: torch.nn.Module
+    :param module_parameter: the torch.nn.Parameter to transform
+    :param transform_data: a module's TransformData
+
+    Only implemented for weight parameters thus far.
+    """
+
+    for transform_name, transform_values in reversed(transform_data.data.items()):
+        transform = transform_values.get("transform")
+        call_args = transform_values.get("call_args")
+        transformed_param_data = transform.inverse_apply(
+            input_tensor=module_parameter, **call_args
+        )
+        module_parameter.data.copy_(transformed_param_data)
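One detail worth calling out: `apply_inverse_transforms_to_parameter` walks the entries with `reversed(...)`. When more than one transform has been folded onto the same parameter, the inverses must be undone last-to-first, or the parameter does not come back to its original values. A small self-contained illustration with two orthogonal matrices (plain torch, left-multiplication chosen only for the sketch):

```python
import torch


def random_orthogonal(n: int) -> torch.Tensor:
    # the Q factor of a random Gaussian matrix is orthogonal
    q, _ = torch.linalg.qr(torch.randn(n, n))
    return q


u1, u2 = random_orthogonal(4), random_orthogonal(4)
weight = torch.randn(4, 4)

# forward pass over the ordered transform entries: first u1, then u2
transformed = u2 @ (u1 @ weight)

# inverting in the same order does not recover the weight in general ...
same_order = u2.T @ (u1.T @ transformed)
# ... but inverting in reverse order (u2 first, then u1) does
reverse_order = u1.T @ (u2.T @ transformed)

assert not torch.allclose(same_order, weight, atol=1e-5)
assert torch.allclose(reverse_order, weight, atol=1e-5)
```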
diff --git a/tests/test_transforms/test_integration.py b/tests/test_transforms/test_integration.py
new file mode 100644
index 00000000..56765768
--- /dev/null
+++ b/tests/test_transforms/test_integration.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+import torch.nn as nn
+from compressed_tensors.quantization import process_transforms_config
+from compressed_tensors.transforms import Hadamard, RandomHadamard, Transforms
+from compressed_tensors.transforms.transform_args import (
+    ModuleTarget,
+    TransformationArgs,
+)
+from compressed_tensors.transforms.transform_config import TransformationConfig
+from compressed_tensors.transforms.transform_data import TransformData
+from compressed_tensors.transforms.transform_scheme import TransformationScheme
+
+
+@pytest.fixture
+def transform_recipe_basic():
+    targets = ["Linear"]
+    module_targets = [ModuleTarget.WEIGHT]
+    linear_layer_args = TransformationArgs(
+        targets=targets, module_targets=module_targets
+    )
+
+    scheme = TransformationScheme(
+        transform_type="hadamard",
+        groups=[linear_layer_args],
+        transform_creation_args={"size": 64},
+    )
+    config = TransformationConfig(
+        transform_groups={
+            "transform_0": scheme,
+        }
+    )
+    return config
+
+
+@pytest.fixture
+def transform_recipe_complex_multiple(transform_recipe_basic):
+    targets = ["Embedding"]
+    module_targets = [ModuleTarget.WEIGHT]
+    embedding_args = TransformationArgs(targets=targets, module_targets=module_targets)
+
+    scheme = TransformationScheme(
+        transform_type="hadamard",
+        groups=[embedding_args],
+        transform_creation_args={"size": 128},
+    )
+    transform_recipe_basic.transform_groups["transform_1"] = scheme
+    return transform_recipe_basic
+
+
+@pytest.fixture
+def transform_recipe_complex(transform_recipe_basic):
+    targets = ["Linear"]
+    module_targets = [ModuleTarget.OUTPUT_ACTIVATIONS]
+    linear_layer_args = TransformationArgs(
+        targets=targets, module_targets=module_targets
+    )
+
+    scheme = TransformationScheme(
+        transform_type="random-hadamard",
+        groups=[linear_layer_args],
+        transform_creation_args={"size": 64},
+    )
+    transform_recipe_basic.transform_groups["transform_1"] = scheme
+    return transform_recipe_basic
+
+
+@pytest.fixture
+def basic_model():
+    class BasicModel(nn.Module):
+        def __init__(self, vocab_size, embed_size, hidden_size, num_classes):
+            super(BasicModel, self).__init__()
+
+            self.embedding = nn.Embedding(vocab_size, embed_size)
+            self.block1 = nn.Sequential(
+                nn.Linear(embed_size, hidden_size), nn.ReLU(), nn.Dropout(0.2)
+            )
+            self.block2 = nn.Sequential(
+                nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.2)
+            )
+            self.fc = nn.Linear(hidden_size, num_classes)
+
+        def forward(self, x):
+            x = self.embedding(x)
+            x = x.mean(dim=1)
+            x = self.block1(x)
+            x = self.block2(x)
+            x = self.fc(x)
+
+            return x
+
+    vocab_size = 1000
+    embed_size = 128
+    hidden_size = 64
+    num_classes = 10
+    return BasicModel(vocab_size, embed_size, hidden_size, num_classes)
+
+
+def _verify_correct_data(layer: torch.nn.Module):
+    assert hasattr(layer, "transform_data")
+    assert isinstance(layer.transform_data, TransformData)
+
+    # data keys are all the different transforms relevant
+    # to the module
+    transform_data = layer.transform_data
+
+    for k, v in transform_data.data.items():
+        current_transform = getattr(layer, k)
+        assert isinstance(current_transform, torch.nn.Parameter)
+        assert "call_args" in v
+
+
+@pytest.mark.skip(reason="Skipping until activation transforms are supported")
+def test_recipe_complex(basic_model, transform_recipe_complex):
+    model = process_transforms_config(
+        model=basic_model, transforms_config=transform_recipe_complex
+    )
+
+    blocks = [model.block1, model.block2]
+    for block in blocks:
+        for layer in block:
+            if isinstance(layer, torch.nn.Linear):
+                _verify_correct_data(layer)
+
+
+def test_recipe_basic(basic_model, transform_recipe_basic):
+    model = process_transforms_config(
+        model=basic_model, transforms_config=transform_recipe_basic
+    )
+
+    blocks = [model.block1, model.block2]
+    for block in blocks:
+        for layer in block:
+            if isinstance(layer, torch.nn.Linear):
+                _verify_correct_data(layer)
+
+
+def test_recipe_complex_multiple(basic_model, transform_recipe_complex_multiple):
+    model = process_transforms_config(
+        model=basic_model, transforms_config=transform_recipe_complex_multiple
+    )
+
+    # Should have the following structure (transform names follow the
+    # f"{module_target}_transform_{idx}" convention used in process_transforms_config):
+    """
+    >> model.embedding.weight_transform_0
+    Parameter containing:
+    tensor([[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
+            [ 1., -1.,  1.,  ..., -1.,  1., -1.],
+            [ 1.,  1., -1.,  ...,  1., -1., -1.],
+            ...,
+            [ 1., -1.,  1.,  ..., -1.,  1., -1.],
+            [ 1.,  1., -1.,  ...,  1., -1., -1.],
+            [ 1., -1., -1.,  ..., -1., -1.,  1.]], dtype=torch.bfloat16)
+
+    >> model.embedding.transform_data
+    TransformData(data={'weight_transform_0':
+        {
+            'call_args': defaultdict()
+        }
+        }
+    )
+    """
+
+    # Verify Embedding layers and Linear Layers have the correct data attached to them
+    _verify_correct_data(model.embedding)
+
+    blocks = [model.block1, model.block2]
+    for block in blocks:
+        for layer in block:
+            if isinstance(layer, torch.nn.Linear):
+                _verify_correct_data(layer)
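A final note on the per-module bookkeeping that `process_transforms_config` performs (and that `load_transforms` relies on when reloading a checkpoint): each matched parameter gets a registered transform named `f"{module_targets[0]}_transform_{idx}"`, with `idx` bumped when a second scheme hits the same layer. The sketch below assumes that naming renders as `weight_transform_0` / `weight_transform_1`; the toy model and the checkpoint-key comment are likewise illustrative.

```python
import torch.nn as nn
from compressed_tensors.quantization import process_transforms_config
from compressed_tensors.transforms.transform_args import ModuleTarget, TransformationArgs
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_scheme import TransformationScheme


def linear_weight_scheme(transform_type: str) -> TransformationScheme:
    # helper for this sketch only: one scheme targeting Linear weights of size 64
    args = TransformationArgs(targets=["Linear"], module_targets=[ModuleTarget.WEIGHT])
    return TransformationScheme(
        transform_type=transform_type,
        groups=[args],
        transform_creation_args={"size": 64},
    )


config = TransformationConfig(
    transform_groups={
        "transform_0": linear_weight_scheme("hadamard"),
        "transform_1": linear_weight_scheme("random-hadamard"),
    }
)
model = process_transforms_config(
    transforms_config=config, model=nn.Sequential(nn.Linear(64, 64))
)

layer = model[0]
# two schemes matched the same layer, so it carries two entries and idx was bumped to 1
print(list(layer.transform_data.data.keys()))  # assumed: ['weight_transform_0', 'weight_transform_1']
print(layer.transform_data.idx)                # 1

# in a saved checkpoint these parameters appear under the module name, e.g.
# "0.weight_transform_0", which is what load_transforms() selects with the
# `"transform" in weight_name` filter before calling update_transform()
```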