diff --git a/src/compressed_tensors/transform/__init__.py b/src/compressed_tensors/transform/__init__.py
index f6d656dd..e7546d62 100644
--- a/src/compressed_tensors/transform/__init__.py
+++ b/src/compressed_tensors/transform/__init__.py
@@ -23,3 +23,4 @@
 from .factory.hadamard import *
 from .factory.matrix_multiply import *
 from .factory.random_hadamard import *
+from .apply import *
diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py
new file mode 100644
index 00000000..a5d4c8c2
--- /dev/null
+++ b/src/compressed_tensors/transform/apply.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.transform import TransformConfig, TransformFactory
+
+
+__all__ = ["apply_transform_config"]
+
+
+def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
+    """
+    Apply a transform config to a model. Weight transforms are fused into weights, while
+    activation transforms are attached as submodules and triggered via pytorch hooks.
+
+    :param model: model to apply config to
+    :param config: transform config to apply
+    """
+    for name, scheme in config.config_groups.items():
+        factory = TransformFactory.from_scheme(scheme, name=name)
+        factory.apply_to_model(model)
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 62b9ddbd..e5a1e05c 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -27,6 +27,7 @@
 )
 from compressed_tensors.utils import (
     align_module_device,
+    delete_offload_module,
     has_offloaded_params,
     patch_attr,
     register_offload_module,
@@ -100,7 +101,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
-        register_offload_module(module, transform_name, transform)  # (1)
+        register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
         if args.location == TransformLocation.INPUT:
@@ -119,6 +120,7 @@ def input_hook(_, args):
             assert isinstance(module, torch.nn.Linear)
             assert module.bias is None
 
+            # fuse transform into weight
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
 
@@ -129,6 +131,9 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
+            # transform is no longer needed (unfusing is not supported)
+            delete_offload_module(module, transform_name)
+
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
 
@@ -141,9 +146,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
-        # (1) even in the `weight` cases, this submodule attachment is needed in order
-        # to support saving in the frozen state
-
 
 class TransformBase(InternalModule, ABC):
     """
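With the changes above, applying a transform config no longer requires iterating over schemes and factories by hand. A minimal usage sketch follows; the scheme and args values mirror those used in the tests below, while the toy model itself is a hypothetical stand-in for any `torch.nn.Module` containing `Linear` layers:

```python
import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# hypothetical toy model; any module tree containing Linear layers works
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 64, bias=False),
)

config = TransformConfig(
    config_groups={
        "": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets="Linear", location="input"),
                TransformArgs(targets="Linear", location="output"),
            ],
        )
    }
)

# one call replaces the old per-scheme factory loop: activation transforms
# are attached as hooked submodules; weight transforms are fused eagerly and
# their submodules then deleted, since unfusing is not supported
apply_transform_config(model, config)
```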
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index 19d04896..b34ca51a 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -19,23 +19,18 @@
     TransformConfig,
     TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-def scheme_kwargs():
-    all_types = TransformFactory.registered_names()
-    base = [{"type": type} for type in all_types]
-    randomized = [{"type": type, "randomize": True} for type in all_types]
-    return base + randomized
-
-
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_correctness_linear(scheme_kwargs):
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+def test_correctness_linear(type, randomized):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
-    scheme = TransformScheme(**scheme_kwargs)
+    scheme = TransformScheme(type=type, randomized=randomized)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -59,8 +54,9 @@ def test_correctness_linear(scheme_kwargs):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_correctness_model(scheme_kwargs, model_apply, offload=False):
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+def test_correctness_model(type, randomized, model_apply, offload=False):
     # load model
     model = model_apply[0]
     if offload:
@@ -75,15 +71,10 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
     # apply transforms
     config = TransformConfig(
         config_groups={
-            "": TransformScheme(
-                **scheme_kwargs,
-                apply=model_apply[1],
-            )
+            "": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
         }
     )
-    for name, scheme in config.config_groups.items():
-        factory = TransformFactory.from_scheme(scheme, name=name)
-        factory.apply_to_model(model)
+    apply_transform_config(model, config)
 
     # compare outputs
     output = model(input)
@@ -92,6 +83,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_correctness_model_offload(scheme_kwargs, model_apply):
-    test_correctness_model(scheme_kwargs, model_apply, offload=True)
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+def test_correctness_model_offload(type, randomized, model_apply):
+    test_correctness_model(type, randomized, model_apply, offload=True)
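As background for the exact-match assertion above (`atol=1e-5, rtol=0.0`): a normalized Hadamard matrix is orthogonal, so rotating the input and counter-rotating the weight leaves a linear layer's output unchanged. A self-contained sketch of that invariant in plain torch, using a Sylvester construction rather than the library's own Hadamard utilities:

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    # Sylvester construction for n a power of two, scaled so that H @ H.T == I
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (n ** 0.5)

x = torch.randn(3, 8)  # activations
W = torch.randn(4, 8)  # Linear weight (out_features x in_features)
H = hadamard(8)

# (x @ H) @ (W @ H).T == x @ H @ H.T @ W.T == x @ W.T
true_output = x @ W.T
rotated_output = (x @ H) @ (W @ H).T
assert torch.allclose(true_output, rotated_output, atol=1e-5)
```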
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index ec6f0bdf..fcca33d4 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -20,23 +20,18 @@
     TransformArgs,
     TransformBase,
     TransformConfig,
-    TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
 from compressed_tensors.utils import align_modules, offloaded_dispatch
 from tests.test_transform.conftest import TransformableModel
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-def scheme_kwargs():
-    all_types = TransformFactory.registered_names()
-    base = [{"type": type} for type in all_types]
-    randomized = [{"type": type, "randomize": True} for type in all_types]
-    return base + randomized
-
-
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_memory_sharing(scheme_kwargs, offload=False):
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_memory_sharing(type, randomized, requires_grad, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
@@ -46,7 +41,9 @@ def test_memory_sharing(scheme_kwargs, offload=False):
     config = TransformConfig(
         config_groups={
             "": TransformScheme(
-                **scheme_kwargs,
+                type=type,
+                randomized=randomized,
+                requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
                     TransformArgs(targets="Linear", location="output"),
@@ -54,9 +51,7 @@ def test_memory_sharing(scheme_kwargs, offload=False):
             )
         }
     )
-    for name, scheme in config.config_groups.items():
-        factory = TransformFactory.from_scheme(scheme, name=name)
-        factory.apply_to_model(model)
+    apply_transform_config(model, config)
 
     # check that memory is shared when onloaded
     with align_modules(model.modules()):
@@ -88,12 +83,7 @@ def test_memory_sharing(scheme_kwargs, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_memory_sharing_offload(scheme_kwargs):
-    test_memory_sharing(scheme_kwargs, offload=True)
-
-
-@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
-def test_memory_sharing_training(scheme_kwargs):
-    scheme_kwargs["requires_grad"] = True
-    test_memory_sharing(scheme_kwargs, offload=False)
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+def test_memory_sharing_offload(type, randomized):
+    test_memory_sharing(type, randomized, requires_grad=False, offload=True)
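The memory test retained above asserts that transforms created for different modules share their tensors once onloaded, rather than each holding a private copy. Stripped of the library machinery, the property under test reduces to storage-pointer equality, roughly as in this plain-torch sketch (the two `Linear` modules here are illustrative, not taken from the test):

```python
import torch

# one Parameter shared by two modules is allocated only once
shared = torch.nn.Parameter(torch.randn(8, 8))
a = torch.nn.Linear(8, 8, bias=False)
b = torch.nn.Linear(8, 8, bias=False)
a.weight = shared
b.weight = shared

# identical storage pointers mean no duplicated transform weights
assert a.weight.data_ptr() == b.weight.data_ptr()
```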