diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index e5a1e05c..fcfda173 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -91,6 +95,8 @@ def apply_to_model(self, model: Module):
                 if is_target(name, module, arg.targets, arg.ignore):
                     self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -98,9 +104,17 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -131,8 +145,9 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
-            # transform is no longer needed (unfusing is not supported)
-            delete_offload_module(module, transform_name)
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -146,6 +161,34 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # avoid issues with this method being called twice
+        for transform in self.transforms:
+            transform._dynamic_tied_weights_keys = list()
+
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.append(name)
+                    setattr(transform, name, tensor)
 
 class TransformBase(InternalModule, ABC):
     """
@@ -154,6 +197,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: List[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = list()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
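A note on the mechanism above: `_update_tied_weights` identifies tied parameters purely by storage identity, since two parameters are tied exactly when they alias one allocation, which `Tensor.data_ptr()` exposes. A minimal, self-contained sketch of that grouping idea (plain torch, not part of the diff):

```python
# Sketch of the detection idea behind `_update_tied_weights`: parameters that
# alias one storage report the same `data_ptr()`, so grouping by pointer
# recovers every set of tied weights.
from collections import defaultdict

import torch

a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie both modules to one Parameter

ptr_to_keys = defaultdict(list)
for module in (a, b):
    for name, param in module.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((module, name))

shared = [keys for keys in ptr_to_keys.values() if len(keys) > 1]
assert len(shared) == 1      # one shared storage ...
assert len(shared[0]) == 2   # ... referenced by two (module, name) keys
```

The factory version adds one wrinkle: for offloaded transforms it looks the parameter up in `_hf_hook.weights_map`, presumably because the on-device placeholder tensor would not report a meaningful pointer.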
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index 47e9bcbb..33cfa736 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -70,6 +70,7 @@ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)
 
diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py
index 2067a647..92269954 100644
--- a/tests/test_transform/conftest.py
+++ b/tests/test_transform/conftest.py
@@ -14,12 +14,13 @@
 
 import pytest
 import torch
-from compressed_tensors.transform import TransformArgs
+from compressed_tensors.transform import TransformArgs, TransformFactory
+from transformers import PretrainedConfig, PreTrainedModel
 
 
-class TransformableModel(torch.nn.Module):
+class TransformableModel(PreTrainedModel):
     def __init__(self, *sizes):
-        super().__init__()
+        super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(
             [
                 torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index b34ca51a..ea82f73b 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -26,11 +26,11 @@
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_linear(type, randomized):
+@pytest.mark.parametrize("randomize", (True, False))
+def test_correctness_linear(type, randomize):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
-    scheme = TransformScheme(type=type, randomized=randomized)
+    scheme = TransformScheme(type=type, randomize=randomize)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -55,8 +55,8 @@
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_model(type, randomized, model_apply, offload=False):
+@pytest.mark.parametrize("randomize", (True, False))
+def test_correctness_model(type, randomize, model_apply, offload=False):
     # load model
     model = model_apply[0]
     if offload:
@@ -71,7 +71,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
     # apply transforms
     config = TransformConfig(
         config_groups={
-            "": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
+            "": TransformScheme(type=type, randomize=randomize, apply=model_apply[1])
         }
     )
     apply_transform_config(model, config)
@@ -84,6 +84,6 @@
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_model_offload(type, randomized, model_apply):
-    test_correctness_model(type, randomized, model_apply, offload=True)
+@pytest.mark.parametrize("randomize", (True, False))
+def test_correctness_model_offload(type, randomize, model_apply):
+    test_correctness_model(type, randomize, model_apply, offload=True)
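On the `_create_inverse` change in `matrix_multiply.py` above: the `.contiguous()` call guards serialization because safetensors rejects non-contiguous tensors, and an inverse derived from a strided view keeps that view's layout. A small illustration of the effect; the transpose here merely stands in for whatever non-contiguity `high_precision_invert` can produce:

```python
# Why copying into a contiguous layout matters before wrapping the result in
# a Parameter: transposed results are non-contiguous views, and safetensors
# (used by `save_pretrained(..., safe_serialization=True)`) refuses to
# serialize non-contiguous tensors.
import torch

inverse = torch.linalg.inv(torch.randn(8, 8, dtype=torch.float64)).T
assert not inverse.is_contiguous()  # transpose is a strided view

inverse = inverse.contiguous()      # materialize a row-major copy
assert inverse.is_contiguous()      # now safe to hand to safetensors
```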
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index fcca33d4..7fc3c914 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -29,9 +29,9 @@
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("requires_grad", (True, False))
-def test_memory_sharing(type, randomized, requires_grad, offload=False):
+def test_memory_sharing(type, randomize, requires_grad, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomzied=randomized,
+                randomize=randomize,
                 requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
@@ -84,9 +84,9 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 def test_memory_sharing_offload(
     type,
-    randomized,
+    randomize,
 ):
-    test_memory_sharing(type, randomized, requires_grad=False, offload=True)
+    test_memory_sharing(type, randomize, requires_grad=False, offload=True)
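For context, the property `test_memory_sharing` guards is that same-shape transforms reference one tensor rather than holding per-module copies. A condensed sketch of that check, under two assumptions: that `TransformBase` is exported from `compressed_tensors.transform`, and that the factory caches transform weights by shape, dtype, and device:

```python
# Condensed version of the memory-sharing property under test: an input and
# an output transform on the same square Linear should reference a single
# underlying weight tensor rather than two copies.
import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformBase,  # assumed to be exported from the package root
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8, bias=False))
config = TransformConfig(
    config_groups={
        "": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets="Linear", location="input"),
                TransformArgs(targets="Linear", location="output"),
            ],
        )
    }
)
apply_transform_config(model, config)

# both transforms have size-8 weights, so the factory's cache should hand
# them the same Parameter, visible as a single unique data pointer
transforms = [m for m in model.modules() if isinstance(m, TransformBase)]
ptrs = {t.weight.data_ptr() for t in transforms}
assert len(transforms) == 2 and len(ptrs) == 1  # two modules, one storage
```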
diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py
new file mode 100644
index 00000000..a688c2cf
--- /dev/null
+++ b/tests/test_transform/factory/test_serialization.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import (
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from compressed_tensors.utils import offloaded_dispatch
+from tests.testing_utils import requires_accelerate, requires_gpu
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
+    # get model, maybe offload
+    model, apply = model_apply
+    if offload:
+        offloaded_dispatch(model, torch.device("cuda"))
+
+    # apply transforms to model
+    config = TransformConfig(
+        config_groups={"": TransformScheme(type=type, randomize=randomize, apply=apply)}
+    )
+    apply_transform_config(model, config)
+
+    # save model
+    model.save_pretrained(tmp_path)
+
+    # TODO: reload model
+
+
+@pytest.mark.skip(reason="Requires changes in upstream transformers")
+# https://github.com/huggingface/transformers/pull/39280
+# https://github.com/huggingface/transformers/pull/39263
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization_offload(type, randomize, model_apply, tmp_path):
+    test_serialization(type, randomize, model_apply, tmp_path, offload=True)
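The tied-weights bookkeeping earlier in this diff exists because safetensors refuses to write two state-dict keys that alias one storage, which is exactly what shared transform weights produce in a naive `state_dict`. A standalone demonstration of that failure mode (not part of the diff):

```python
# The failure mode the `_dynamic_tied_weights_keys` bookkeeping prevents:
# safetensors rejects state dicts in which two keys alias the same storage,
# so aliases must be declared tied (and deduplicated) before saving.
import torch
from safetensors.torch import save_file

weight = torch.randn(4, 4)
state_dict = {"a.weight": weight, "b.weight": weight}  # two keys, one tensor

try:
    save_file(state_dict, "shared.safetensors")
except RuntimeError as err:
    print(f"rejected as expected: {err}")
```

With `_dynamic_tied_weights_keys` populated, transformers can drop all but one alias at save time and re-tie them on load, which is presumably what the skipped offload test will exercise once the two linked upstream PRs land.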