|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from abc import ABC, abstractmethod |
| 16 | +from typing import Optional |
| 17 | + |
| 18 | +import torch |
| 19 | +import torch.nn.utils.parametrize as P |
| 20 | +from compressed_tensors.quantization.lifecycle import is_target # TODO: move to utils |
| 21 | +from compressed_tensors.registry.registry import RegistryMixin, T |
| 22 | +from compressed_tensors.transform import ( |
| 23 | + TransformArgs, |
| 24 | + TransformLocation, |
| 25 | + TransformScheme, |
| 26 | +) |
| 27 | +from compressed_tensors.utils import ( |
| 28 | + align_module_device, |
| 29 | + has_offloaded_params, |
| 30 | + patch_attr, |
| 31 | + register_offload_module, |
| 32 | + update_offload_parameter, |
| 33 | +) |
| 34 | +from torch import Tensor |
| 35 | +from torch.nn import Module, Parameter |
| 36 | + |
| 37 | + |
| 38 | +__all__ = ["TransformFactory", "TransformBase"] |
| 39 | + |
| 40 | + |
| 41 | +class TransformFactory(RegistryMixin, ABC): |
| 42 | + """ |
| 43 | + Abstract factory base used to create and apply transforms to a model |
| 44 | +
|
| 45 | + :param name: name associated with transform scheme |
| 46 | + :param scheme: transform scheme which defines how transforms should be created |
| 47 | + :param seed: random seed used to transform weight randomization |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): |
| 51 | + self.name = name |
| 52 | + self.scheme = scheme |
| 53 | + self.generator = torch.Generator() |
| 54 | + if seed is not None: |
| 55 | + self.generator.manual_seed(seed) |
| 56 | + |
| 57 | + @classmethod |
| 58 | + def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T: |
| 59 | + """ |
| 60 | + Create a transform factory from a scheme |
| 61 | +
|
| 62 | + :param scheme: defines how transforms should be created |
| 63 | + :param kwargs: TransformFactory constructor arguments |
| 64 | + :return: subclass of `TransformFactory` corresponding to the scheme type |
| 65 | + """ |
| 66 | + constructor = cls.get_value_from_registry(name=scheme.type) |
| 67 | + return constructor(scheme=scheme, **kwargs) |
| 68 | + |
| 69 | + @abstractmethod |
| 70 | + def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase": |
| 71 | + """ |
| 72 | + Abstract method which defines how a transform should be created. May utilize |
| 73 | + caching to maximize shared memory |
| 74 | +
|
| 75 | + :param module: parent module that transform will be applied to |
| 76 | + :param args: defines how the transform will be applied to the module |
| 77 | + :return: instance of TransformBase |
| 78 | + """ |
| 79 | + raise NotImplementedError() |
| 80 | + |
| 81 | + def apply_to_model(self, model: Module): |
| 82 | + """ |
| 83 | + Create transforms and apply them to the model |
| 84 | +
|
| 85 | + :param model: module to apply transforms to |
| 86 | + """ |
| 87 | + for arg in self.scheme.apply: |
| 88 | + for name, module in list(model.named_modules()): |
| 89 | + if is_target(name, module, arg.targets, arg.ignore): |
| 90 | + self._apply_to_module(module, arg) |
| 91 | + |
| 92 | + def _apply_to_module(self, module: Module, args: TransformArgs): |
| 93 | + """ |
| 94 | + Create transforms and apply them to the module |
| 95 | +
|
| 96 | + :param module: target module to apply transforms to |
| 97 | + :param args: defines how the transform will be applied to the target module |
| 98 | + """ |
| 99 | + # create transform as submodule |
| 100 | + transform_name = f"{self.name}_{args.location}" |
| 101 | + transform = self.create_transform(module, args) |
| 102 | + register_offload_module(module, transform_name, transform) # (1) |
| 103 | + |
| 104 | + # register input transformation hook |
| 105 | + if args.location == TransformLocation.INPUT: |
| 106 | + |
| 107 | + def input_hook(_, args): |
| 108 | + input = args[0] |
| 109 | + return transform(input) |
| 110 | + |
| 111 | + module.register_forward_pre_hook(input_hook, prepend=True) |
| 112 | + |
| 113 | + # eagerly apply transformation to weight |
| 114 | + elif args.location in ( |
| 115 | + TransformLocation.WEIGHT_INPUT, |
| 116 | + TransformLocation.WEIGHT_OUTPUT, |
| 117 | + ): |
| 118 | + assert isinstance(module, torch.nn.Linear) |
| 119 | + assert module.bias is None |
| 120 | + |
| 121 | + with torch.no_grad(), align_module_device(module): |
| 122 | + update_offload_parameter(module, "weight", transform(module.weight)) |
| 123 | + |
| 124 | + if self.scheme.requires_grad: |
| 125 | + # for training, the weight changes with every forward pass |
| 126 | + # so we can leverage parametrization to propagate the gradient |
| 127 | + if has_offloaded_params(module): |
| 128 | + raise ValueError("Offloaded training is not supported") |
| 129 | + P.register_parametrization(module, "weight", transform) |
| 130 | + |
| 131 | + # register output transformation hook |
| 132 | + elif args.location == TransformLocation.OUTPUT: |
| 133 | + |
| 134 | + def output_hook(_, _input, output): |
| 135 | + return transform(output) |
| 136 | + |
| 137 | + module.register_forward_hook(output_hook) |
| 138 | + |
| 139 | + # other locations such as q_attn and k_attn have not been implemented |
| 140 | + else: |
| 141 | + raise NotImplementedError() |
| 142 | + |
| 143 | + # (1) even in the `weight` cases, this submodule attachment is needed in order |
| 144 | + # to support saving in the frozen state |
| 145 | + |
| 146 | + |
| 147 | +class TransformBase(Module, ABC): |
| 148 | + """ |
| 149 | + Represents the application of a transform accord to TransformArgs |
| 150 | + """ |
| 151 | + |
| 152 | + args: TransformArgs |
| 153 | + weight: Parameter |
| 154 | + |
| 155 | + @abstractmethod |
| 156 | + def forward(self, value: Tensor) -> Tensor: |
| 157 | + raise NotImplementedError() |
| 158 | + |
| 159 | + def right_inverse(self, value: Tensor) -> Tensor: |
| 160 | + with patch_attr(self.args, "inverse", not self.args.inverse): |
| 161 | + return self.forward(value) |
| 162 | + |
| 163 | + def __repr__(self): |
| 164 | + return f"{self.__class__.__name__}(inverse={self.args.inverse})" |
0 commit comments