diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index c8dbeced..03d936dc 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -152,11 +152,7 @@ def apply_quantization_config(
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
+    for name, submodule in model.named_modules():
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
diff --git a/src/compressed_tensors/transform/__init__.py b/src/compressed_tensors/transform/__init__.py
index f6d656dd..e7546d62 100644
--- a/src/compressed_tensors/transform/__init__.py
+++ b/src/compressed_tensors/transform/__init__.py
@@ -23,3 +23,4 @@ from .factory.hadamard import *
 from .factory.matrix_multiply import *
 from .factory.random_hadamard import *
+from .apply import *
diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py
new file mode 100644
index 00000000..fb745b91
--- /dev/null
+++ b/src/compressed_tensors/transform/apply.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.transform import TransformConfig, TransformFactory
+
+
+__all__ = ["apply_transform_config"]
+
+
+def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
+    for name, scheme in config.config_groups.items():
+        factory = TransformFactory.from_scheme(scheme, name=name)
+        factory.apply_to_model(model)
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 7448c604..a2751f5a 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -99,10 +99,10 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
-        register_offload_module(module, transform_name, transform)  # (1)
 
         # register input transformation hook
         if args.location == TransformLocation.INPUT:
+            register_offload_module(module, transform_name, transform)
 
             def input_hook(_, args):
                 input = args[0]
@@ -115,8 +115,8 @@ def input_hook(_, args):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
+            assert isinstance(module, (torch.nn.Linear, torch.nn.Embedding))
+            assert getattr(module, "bias", None) is None
 
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
@@ -130,6 +130,7 @@ def input_hook(_, args):
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
+            register_offload_module(module, transform_name, transform)
 
             def output_hook(_, _input, output):
                 return transform(output)
@@ -140,9 +141,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
-    # (1) even in the `weight` cases, this submodule attachment is needed in order
-    # to support saving in the frozen state
-
 
 class TransformBase(Module, ABC):
     """
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index c14da51f..f2ab3bc7 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Type
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -22,10 +22,10 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -50,30 +51,70 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)
 
-        weight = self.weights[size, dtype, device]
-        return HadamardTransform(weight, args)
+        weight = self.weights.get(size, dtype, device, construct_device=exec_device)
+        perm = self.perms[weight] if self.scheme.randomize else None
+        return HadamardTransform(weight, perm, args, type(module))
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype, device)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        if self.scheme.num_heads is None or self.scheme.num_heads <= 1:
+            data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        else:
+            # block-diagonal: one hadamard of size `head_dim` per head
+            assert size % self.scheme.num_heads == 0
+            head_dim = size // self.scheme.num_heads
+            data = torch.kron(
+                torch.eye(self.scheme.num_heads, dtype=dtype, device=construct_device),
+                deterministic_hadamard_matrix(head_dim, dtype, construct_device),
+            )
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
+    def _create_permutation(self, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0), generator=self.generator)
+        return Parameter(data, requires_grad=False)
+
 
 class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Parameter,
+        perm: Optional[Parameter],
+        args: TransformArgs,
+        module_type: Type,
+    ):
         super().__init__()
         self.weight = weight
+        self.perm = perm
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = weight[self.perm][:, self.perm]
+
+        if self.args.inverse:
+            weight = weight.T
 
-        return apply_transform_weight(weight, value, self.args.location)
+        # NOTE: SpinQuant code up-converts to float64 for application
+        # of transform, then down-converts
+        return apply_transform_weight(
+            weight.to(torch.float64),
+            value.to(torch.float64),
+            self.args.location,
+            self.module_type,
+        ).to(weight.dtype)
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index e551fc5f..25b141e3 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Type
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = torch.rand(
@@ -73,17 +73,22 @@ def _create_inverse(self, weight: Parameter) -> Parameter:
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(self, weight: Tensor, args: TransformArgs, module_type: Type):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py
index 78fb6975..1d67ab0f 100644
--- a/src/compressed_tensors/transform/factory/random_hadamard.py
+++ b/src/compressed_tensors/transform/factory/random_hadamard.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 from compressed_tensors.transform import HadamardFactory, TransformFactory
 from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
 from torch import device, dtype
@@ -28,7 +29,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype, device, self.generator)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py
index 414c21e0..df178c42 100644
--- a/src/compressed_tensors/transform/transform_config.py
+++ b/src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
                     inverse=True,
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
         "u": TransformScheme(
             type="hadamard",
@@ -62,7 +62,7 @@ class TransformConfig(BaseModel):
                     targets=["Linear"], location="output", inverse=True  # non-mergable
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
     }
 )
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 1335063c..18885d68 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field
@@ -31,13 +31,14 @@ class TransformScheme(BaseModel):
         (see `Transforms.registered_names()`)
     :param apply: list of TransformationArgs containing the information about the
         modules that should be targeted by the specified transform
-    :param randomize_modules: True if unique transforms should be applied to each
-        unique module targeted by `apply`, otherwise reuse transform weights where
-        applicable
+    :param randomize: True if uniquely randomized transform weights should be used,
+        otherwise identical transform weights are reused where applicable
     :param requires_grad: True if weights include gradients for training
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
-    randomize_modules: bool = Field(default=False)
+    randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    # TODO: infer num_heads
+    num_heads: Optional[int] = None
diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py
index 9281d6cc..92625f1f 100644
--- a/src/compressed_tensors/transform/utils/hadamard.py
+++ b/src/compressed_tensors/transform/utils/hadamard.py
@@ -51,7 +51,9 @@ def deterministic_hadamard_matrix(
     log2 = int(math.log2(size))
     if size != 2**log2:
-        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+        raise ValueError(
+            f"Cannot construct deterministic hadamard of size {size} != 2^n"
+        )
 
     H = torch.tensor([[1]], dtype=dtype, device=device)
diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py
index e60d24dc..a0e6909b 100644
--- a/src/compressed_tensors/transform/utils/utils.py
+++ b/src/compressed_tensors/transform/utils/utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Type
 
 import torch
 from compressed_tensors.transform import TransformLocation
@@ -27,17 +28,28 @@ def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
     :param location: location on module
     :return: size of matrix
     """
-    assert isinstance(module, torch.nn.Linear)
-    if location in ("input", TransformLocation.WEIGHT_INPUT):
-        return module.in_features
-    else:
-        return module.out_features
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            return module.in_features
+        else:
+            return module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            return module.num_embeddings
+        else:
+            return module.embedding_dim
+
+    raise ValueError(
+        f"Unsupported module type {type(module)}, "
+        "should be either Linear or Embedding."
+    )
 
 
 def apply_transform_weight(
-    weight: torch.Tensor,
+    transform_weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
+    module_type: Type,
 ) -> torch.Tensor:
     """
     Using the transform location, determine how to apply the transform weight to the
@@ -69,23 +81,46 @@ def apply_transform_weight(
         = y U
         = yh
 
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
+    :param location: determines how transform_weight should be applied
+    :param module_type: result of type(module), passed in to determine application
+        of the weight transform. This is needed because torch uses the convention:
+        - torch.nn.Linear(in_features, out_features) has weight shape
+          (out_features, in_features)
+        - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
+          (num_embeddings, embedding_dim)
+        The transform has to account for Linear's transposed weights
+    :return: value after transform_weight has been applied
     """
 
     if location == TransformLocation.INPUT:
-        return value @ weight
+        return value @ transform_weight
 
     elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
+        if module_type is torch.nn.Linear:
+            # equivalent to (transform_weight @ value.T).T
+            return value @ transform_weight.T
+        else:
+            raise NotImplementedError(
+                f"{TransformLocation.WEIGHT_INPUT} transform not "
+                f"implemented for module type {module_type}"
+            )
 
     elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
+        if module_type is torch.nn.Linear:
+            # equivalent to (value.T @ transform_weight).T
+            return transform_weight.T @ value
+        elif module_type is torch.nn.Embedding:
+            return value @ transform_weight
+        else:
+            raise NotImplementedError(
+                f"{TransformLocation.WEIGHT_OUTPUT} transform not "
+                f"implemented for module type {module_type}"
+            )
 
     elif location == TransformLocation.OUTPUT:
-        return value @ weight
+        return value @ transform_weight
 
     else:
         raise NotImplementedError(f"{location} has not been implemented yet")
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index d8898ae4..c06d3e3f 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):
 
     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._kwargs = {}
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, **kwargs) -> Any:
+        with patch_attr(self, "_kwargs", kwargs):
+            return self[args]
diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py
new file mode 100644
index 00000000..8681b2f8
--- /dev/null
+++ b/tests/test_transform/conftest.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import TransformArgs
+
+
+class TransformableModel(torch.nn.Module):
+    def __init__(self, *sizes):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList([])
+        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
+        for index in range(1, len(sizes) - 1):
+            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+
+    def forward(self, x):
+        for layer in self.fcs:
+            x = layer(x)
+        return x
+
+
+@pytest.fixture(scope="function")
+def model_apply():
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    apply = [
+        # weight output -> input
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        # output -> weight input
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        # output -> input
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        # weight output -> weight input
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+
+    return model, apply
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index ed7d7f5e..61430dd2 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -16,34 +16,27 @@
 import torch
 from compressed_tensors.transform import (
     TransformArgs,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
-
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_linear(scheme):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_linear(scheme_kwargs):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
+    scheme = TransformScheme(**scheme_kwargs)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -67,50 +60,37 @@
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model(scheme, offload=False):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model(scheme_kwargs, model_apply, offload=False):
     # load model
-    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    model = model_apply[0]
     if offload:
         model = offloaded_dispatch(model, torch.device("cuda"))
 
-    # create factory
-    scheme.apply = [
-        # weight output -> input
-        TransformArgs(targets="fcs.0", location="weight_output"),
-        TransformArgs(targets="fcs.1", location="input", inverse=True),
-        # output -> weight input
-        TransformArgs(targets="fcs.1", location="output"),
-        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
-        # output -> input
-        TransformArgs(targets="fcs.2", location="output"),
-        TransformArgs(targets="fcs.3", location="input", inverse=True),
-        # weight output -> weight input
-        TransformArgs(targets="fcs.3", location="weight_output"),
-        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
-    ]
-    factory = TransformFactory.from_scheme(scheme, name="")
-
-    # create inputs
+    # get output
     input = torch.rand((17, model.fcs[0].in_features))
     if offload:
         input = input.to(torch.device("cuda"))
+    true_output = model(input)
+
+    # apply transforms
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=model_apply[1],
+            )
+        }
+    )
+    apply_transform_config(model, config)
 
     # compare outputs
-    true_output = model(input)
-    factory.apply_to_model(model)
     output = model(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model_offload(scheme):
-    test_correctness_model(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model_offload(scheme_kwargs, model_apply):
+    test_correctness_model(scheme_kwargs, model_apply, offload=True)
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index 7b118f75..713c7687 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -19,49 +19,43 @@
 from compressed_tensors.transform import (
     TransformArgs,
     TransformBase,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
 from compressed_tensors.utils import align_modules, offloaded_dispatch
+from tests.test_transform.conftest import TransformableModel
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
-
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing(scheme, offload=False):
-    # load scheme and factory
-    scheme = TransformScheme(
-        type="hadamard",
-        apply=[
-            TransformArgs(targets="Linear", location="input"),
-            TransformArgs(targets="Linear", location="output"),
-        ],
-    )
-    factory = TransformFactory.from_scheme(scheme, name="")
-
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing(scheme_kwargs, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
         offloaded_dispatch(model, torch.device("cuda"))
 
     # add transforms to model
-    factory.apply_to_model(model)
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=[
+                    TransformArgs(targets="Linear", location="input"),
+                    TransformArgs(targets="Linear", location="output"),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
 
     # check that memory is shared when onloaded
     with align_modules(model.modules()):
@@ -93,20 +87,12 @@
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing_offload(scheme):
-    test_memory_sharing(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_offload(scheme_kwargs):
+    test_memory_sharing(scheme_kwargs, offload=True)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
-def test_memory_sharing_training(scheme):
-    test_memory_sharing(scheme, offload=False)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_training(scheme_kwargs):
+    scheme_kwargs["requires_grad"] = True
+    test_memory_sharing(scheme_kwargs, offload=False)
diff --git a/tests/test_transform/test_transform_scheme.py b/tests/test_transform/test_transform_scheme.py
index ad851762..839ab46a 100644
--- a/tests/test_transform/test_transform_scheme.py
+++ b/tests/test_transform/test_transform_scheme.py
@@ -24,7 +24,7 @@ def test_basic_scheme():
         type="hadamard",
         apply=[basic_args],
     )
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 1
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -43,10 +43,10 @@ def test_multiple_groups_global():
     scheme = TransformScheme(
         type="hadamard",
         apply=[embedding_args, linear_args],
-        randomize_modules=True,
+        randomize=True,
    )
 
-    assert scheme.randomize_modules
+    assert scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 2
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -69,6 +69,6 @@ def test_multiple_groups():
         apply=apply,
     )
 
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20
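
Usage sketch for the new `apply_transform_config` entrypoint. The two-layer model here is illustrative (not part of the patch), but the weight_output/input pairing mirrors tests/test_transform/conftest.py, so the transformed model's outputs should match the originals up to float tolerance:

    import torch
    from compressed_tensors.transform import (
        TransformArgs,
        TransformConfig,
        TransformScheme,
        apply_transform_config,
    )


    class TwoLayer(torch.nn.Module):
        # illustrative model; any module tree with addressable Linears works
        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.ModuleList(
                [torch.nn.Linear(8, 16, bias=False), torch.nn.Linear(16, 32, bias=False)]
            )

        def forward(self, x):
            for fc in self.fcs:
                x = fc(x)
            return x


    model = TwoLayer()
    x = torch.rand(4, 8)
    true_output = model(x)

    # fuse a hadamard into fcs.0's weight, invert it at fcs.1's input;
    # the pair cancels, so model outputs are preserved
    config = TransformConfig(
        config_groups={
            "": TransformScheme(
                type="hadamard",
                apply=[
                    TransformArgs(targets="fcs.0", location="weight_output"),
                    TransformArgs(targets="fcs.1", location="input", inverse=True),
                ],
            )
        }
    )
    apply_transform_config(model, config)

    assert torch.allclose(true_output, model(x), atol=1e-5, rtol=0.0)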
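The module_type branching in apply_transform_weight exists because torch.nn.Linear stores its weight as (out_features, in_features). A quick self-check of the documented Linear equivalence, assuming the import path compressed_tensors.transform.utils.utils is unchanged and using an orthonormal transform:

    import torch
    from compressed_tensors.transform import TransformLocation
    from compressed_tensors.transform.utils.utils import apply_transform_weight

    W = torch.randn(32, 16)  # Linear(16, 32).weight has shape (out_features, in_features)
    U = torch.linalg.qr(torch.randn(16, 16)).Q  # orthonormal transform weight
    x = torch.randn(5, 16)

    # fusing U into the weight at WEIGHT_INPUT ...
    fused = apply_transform_weight(U, W, TransformLocation.WEIGHT_INPUT, torch.nn.Linear)
    # ... is equivalent to transforming the activations at INPUT
    y_fused = x @ fused.T
    y_input = apply_transform_weight(U, x, TransformLocation.INPUT, torch.nn.Linear) @ W.T
    assert torch.allclose(y_fused, y_input, atol=1e-5)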
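The new num_heads option builds a block-diagonal transform: torch.kron(eye(num_heads), H) applies an independent Hadamard of size head_dim = size // num_heads to each head. A standalone sketch of that construction and its orthonormality, using Sylvester doubling and assuming power-of-two head_dim:

    import math

    import torch


    def blockwise_hadamard(size: int, num_heads: int) -> torch.Tensor:
        # one orthonormal head_dim x head_dim Hadamard per head, laid out block-diagonally
        assert size % num_heads == 0
        head_dim = size // num_heads
        H = torch.tensor([[1.0]])
        while H.shape[0] < head_dim:  # Sylvester doubling: [[H, H], [H, -H]]
            H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
        H = H / math.sqrt(head_dim)  # normalize so that H @ H.T == I
        return torch.kron(torch.eye(num_heads), H)


    B = blockwise_hadamard(8, 2)
    assert torch.allclose(B @ B.T, torch.eye(8), atol=1e-6)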