diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 700c1769..6bd5421b 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -479,7 +479,7 @@ def decompress_model(self, model: Module):
             # remove any existing parameters
             exec_device = get_execution_device(module)
             offload_device = get_offloaded_device(module)
-            for name, _ in list(module.named_parameters()):
+            for name, _ in list(module.named_parameters(recurse=False)):
                 delete_offload_parameter(module, name)
 
             # replace with decompressed parameters
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 4c7b6c91..9db318a0 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -26,6 +26,7 @@
 )
 from compressed_tensors.utils import (
     align_module_device,
+    delete_offload_module,
     has_offloaded_params,
     patch_attr,
     register_offload_module,
@@ -99,7 +100,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
-        register_offload_module(module, transform_name, transform)  # (1)
+        register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
         if args.location == TransformLocation.INPUT:
@@ -128,6 +129,10 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
+            else:
+                # if we're not training, there's no reason to keep the transform
+                delete_offload_module(module, transform_name)
+
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -140,9 +145,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
-    # (1) even in the `weight` cases, this submodule attachment is needed in order
-    # to support saving in the frozen state
-
 
 class TransformBase(Module, ABC):
     """
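Note on the base.py hunks above: for `weight`-location transforms the transform is applied to the weight itself, so once training is ruled out the attached transform submodule is redundant and can be deleted, while the training path keeps it alive as a torch parametrization. A minimal sketch of the two paths in plain torch (the `Double` parametrization is a made-up stand-in, not the compressed-tensors API):

    import torch
    import torch.nn.utils.parametrize as P


    class Double(torch.nn.Module):
        def forward(self, weight):
            return weight * 2.0


    # frozen path: fuse the transform into the weight once; no extra
    # submodule needs to survive afterwards
    frozen = torch.nn.Linear(4, 4, bias=False)
    with torch.no_grad():
        frozen.weight.mul_(2.0)

    # training path: the parametrization recomputes the transformed weight
    # on every access, so the transform module must stay attached
    trainable = torch.nn.Linear(4, 4, bias=False)
    P.register_parametrization(trainable, "weight", Double())
    assert torch.allclose(
        trainable.weight, trainable.parametrizations.weight.original * 2.0
    )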
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index b1da88a3..b4d5f7de 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -56,24 +57,35 @@ def create_transform(self, module: Module, args: TransformArgs):
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
-        return HadamardTransform(weight, args)
+        perm = self.perms[weight] if self.scheme.randomize else None
+        return HadamardTransform(weight, perm, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = deterministic_hadamard_matrix(size)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
+    def _create_permutation(self, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0), generator=self.generator)
+        return Parameter(data, requires_grad=False)
+
 
 class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+    ):
         super().__init__()
         self.weight = weight
+        self.perm = perm
         self.args = args
 
     def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = weight[self.perm][:, self.perm]
+
+        if self.args.inverse:
+            weight = weight.T
 
         return apply_transform_weight(weight, value, self.args.location)
diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py
index 414c21e0..df178c42 100644
--- a/src/compressed_tensors/transform/transform_config.py
+++ b/src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
                     inverse=True,
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
         "u": TransformScheme(
             type="hadamard",
@@ -62,7 +62,7 @@
                     targets=["Linear"], location="output", inverse=True  # non-mergable
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
     }
 )
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 1335063c..64d646e0 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
         (see `Transforms.registered_names()`)
     :param apply: list of TransformationArgs containing the information about the
         modules that should be targeted by the specified transform
-    :param randomize_modules: True if unique transforms should be applied to each
-        unique module targeted by `apply`, otherwise reuse transform weights where
-        applicable
+    :param randomize: True if uniquely randomized transform weights should be used,
+        otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
-    randomize_modules: bool = Field(default=False)
+    randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
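The permutation introduced in hadamard.py randomizes a shared Hadamard weight without breaking invertibility: indexing rows and columns with the same permutation computes P @ H @ P.T, which is still a Hadamard matrix, so its transpose still serves as the (scaled) inverse used by `inverse=True`. A self-contained sketch, using a Sylvester construction as a stand-in for the library's `deterministic_hadamard_matrix`:

    import torch


    def sylvester_hadamard(n: int) -> torch.Tensor:
        # Sylvester construction; n must be a power of two
        H = torch.ones(1, 1)
        while H.size(0) < n:
            H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
        return H


    H = sylvester_hadamard(8)
    perm = torch.randperm(8)

    # same row and column permutation, as in HadamardTransform.forward
    W = H[perm][:, perm]

    # orthogonality (up to the scale n) is preserved, so W.T still inverts W
    assert torch.equal(W @ W.T, 8 * torch.eye(8))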
diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index c1a46e3c..7cc9f33c 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -467,6 +467,7 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
 def delete_offload_module(base: torch.nn.Module, name: str):
     """
     Delete a submodule from a model which may contain offloading
+
     :param base: parent module to delete submodule from
     :param name: name of submodule on parent
     """
diff --git a/tests/test_transform/factory/test_compression.py b/tests/test_transform/factory/test_compression.py
new file mode 100644
index 00000000..6f763fcf
--- /dev/null
+++ b/tests/test_transform/factory/test_compression.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors import ModelCompressor
+from compressed_tensors.quantization import QuantizationStatus
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformBase,
+    TransformFactory,
+    TransformLocation,
+    TransformScheme,
+)
+
+
+class TransformableModel(torch.nn.Module):
+    def __init__(self, *sizes):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList([])
+        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
+        for index in range(1, len(sizes) - 1):
+            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+
+    def forward(self, x):
+        for layer in self.fcs:
+            x = layer(x)
+        return x
+
+
+def test_frozen_reload():
+    # TODO: test applying and reloading a transformers model
+    pass
+
+
+def test_compressed_keys():
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+
+    scheme = TransformScheme(type="hadamard")
+    scheme.apply = [
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+    factory = TransformFactory.from_scheme(scheme, name="")
+
+    input = torch.rand((17, model.fcs[0].in_features))
+    true_output = model(input)
+
+    factory.apply_to_model(model)
+
+    compressor = ModelCompressor()
+    compressor.compress_model(model)
+
+    keys = {
+        "fcs.0.weight",
+        "fcs.1.weight",
+        "fcs.1._input.weight",
+        "fcs.1._output.weight",
+        "fcs.2.weight",
+        "fcs.2._output.weight",
+        "fcs.3.weight",
+        "fcs.3._input.weight",
+        "fcs.4.weight",
+    }
+    assert model.state_dict().keys() == keys
+
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-7, rtol=0.0)
+
+
+def test_compress_decompress():
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+
+    scheme = TransformScheme(type="hadamard")
+    scheme.apply = [
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+    factory = TransformFactory.from_scheme(scheme, name="")
+
+    input = torch.rand((17, model.fcs[0].in_features))
+    true_output = model(input)
+
+    factory.apply_to_model(model)
+
+    compressor = ModelCompressor()
+    compressor.compress_model(model)
+    compressor.decompress_model(model)
+
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-7, rtol=0.0)
+
+    keys = {
+        "fcs.0.weight",
+        "fcs.1.weight",
+        "fcs.1._input.weight",
+        "fcs.1._output.weight",
+        "fcs.2.weight",
+        "fcs.2._output.weight",
+        "fcs.3.weight",
+        "fcs.3._input.weight",
+        "fcs.4.weight",
+    }
+    assert model.state_dict().keys() == keys
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index bab1117e..1745281f 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -19,10 +19,17 @@
     TransformFactory,
     TransformScheme,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import force_cpu_offload
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
+def all_schemes():
+    all_types = TransformFactory.registered_names()
+    base = [TransformScheme(type=type) for type in all_types]
+    randomized = [TransformScheme(type=type, randomize=True) for type in all_types]
+    return base + randomized
+
+
 class TransformableModel(torch.nn.Module):
     def __init__(self, *sizes):
         super().__init__()
@@ -37,10 +44,7 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_correctness_linear(scheme):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
@@ -67,10 +71,7 @@ def test_correctness_linear(scheme):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_correctness_model(scheme, offload=False):
     # load model
     model = TransformableModel(2, 4, 8, 16, 32, 64)
@@ -108,9 +109,6 @@ def test_correctness_model(scheme, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_correctness_model_offload(scheme):
     test_correctness_model(scheme, offload=True)
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index 49e882e4..15d72b9b 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -26,6 +26,13 @@
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
+def all_schemes():
+    all_types = TransformFactory.registered_names()
+    base = [TransformScheme(type=type) for type in all_types]
+    randomized = [TransformScheme(type=type, randomize=True) for type in all_types]
+    return base + randomized
+
+
 class TransformableModel(torch.nn.Module):
     def __init__(self, *sizes):
         super().__init__()
@@ -40,10 +47,7 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_memory_sharing(scheme, offload=False):
     # load scheme and factory
     scheme = TransformScheme(
@@ -93,20 +97,12 @@ def test_memory_sharing(scheme, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_memory_sharing_offload(scheme):
     test_memory_sharing(scheme, offload=True)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
+@pytest.mark.parametrize("scheme", all_schemes())
 def test_memory_sharing_training(scheme):
+    scheme.requires_grad = True
     test_memory_sharing(scheme, offload=False)
diff --git a/tests/test_transform/test_transform_scheme.py b/tests/test_transform/test_transform_scheme.py
index ad851762..839ab46a 100644
--- a/tests/test_transform/test_transform_scheme.py
+++ b/tests/test_transform/test_transform_scheme.py
@@ -24,7 +24,7 @@ def test_basic_scheme():
         type="hadamard",
         apply=[basic_args],
     )
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 1
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -43,10 +43,10 @@ def test_multiple_groups_global():
     scheme = TransformScheme(
         type="hadamard",
         apply=[embedding_args, linear_args],
-        randomize_modules=True,
+        randomize=True,
     )
 
-    assert scheme.randomize_modules
+    assert scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 2
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -69,6 +69,6 @@ def test_multiple_groups():
         apply=apply,
     )
 
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20
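For reference, a minimal end-to-end sketch of the renamed `randomize` flag, mirroring the new tests above (`model` is assumed to be a `TransformableModel` as defined in test_compression.py):

    import torch
    from compressed_tensors import ModelCompressor
    from compressed_tensors.transform import (
        TransformArgs,
        TransformFactory,
        TransformScheme,
    )

    # `randomize` (formerly `randomize_modules`) requests uniquely randomized
    # transform weights; for hadamard, a random permutation per shared weight
    scheme = TransformScheme(
        type="hadamard",
        randomize=True,
        apply=[
            TransformArgs(targets="fcs.0", location="weight_output"),
            TransformArgs(targets="fcs.1", location="input", inverse=True),
        ],
    )
    factory = TransformFactory.from_scheme(scheme, name="")
    factory.apply_to_model(model)

    # transforms survive a compress/decompress round trip
    compressor = ModelCompressor()
    compressor.compress_model(model)
    compressor.decompress_model(model)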