diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index dd515976..f55089b2 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -18,9 +18,9 @@
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
@@ -60,7 +60,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))
 
     def _create_weight(
         self,
@@ -81,12 +81,17 @@ def _create_permutation(self, weight: Parameter) -> Parameter:
 
 class HadamardTransform(TransformBase):
     def __init__(
-        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+        self,
+        weight: Parameter,
+        perm: Optional[Parameter],
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -97,4 +102,6 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T
 
-        return apply_transform_weight(weight, value, self.args.location)
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        )
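
The head-wise application that `HadamardTransform.forward` now routes through `apply_transform_weight` is equivalent to multiplying by a block-diagonal matrix built from one small transform. A minimal standalone sketch of that equivalence (not part of the patch; shapes invented for illustration):

```python
# Head-wise transform application vs. an explicit block-diagonal matrix.
# Standalone sketch: an 8-dim activation split into two 4-dim "heads".
import torch

value = torch.randn(17, 5, 8)        # (batch, seq, hidden)
weight = torch.randn(4, 4)           # one 4x4 transform shared across heads

# head-wise application, as done for location="input" (fn: value @ weight)
heads = value.unflatten(-1, (2, 4))  # (batch, seq, num_heads, head_dim)
out = (heads @ weight).flatten(-2, -1)

# equivalent single transform over the full hidden dimension
block = torch.block_diag(weight, weight)
assert torch.allclose(out, value @ block, atol=1e-6)
```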
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index 47e9bcbb..7279befc 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -17,9 +17,9 @@
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
 
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@ def _create_inverse(self, weight: Parameter) -> Parameter:
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 64d646e0..1620c541 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    head_dim: Optional[int] = Field(default=None)
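
For reference, a hypothetical scheme using the new field might look like the following (the head size and structure are illustrative, not taken from the patch):

```python
# Hypothetical TransformScheme using the new head_dim field
from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    head_dim=128,  # apply the transform per 128-dim head slice, not per full feature dim
    apply=[
        TransformArgs(targets="v_proj", location="weight_output"),
        TransformArgs(targets="o_proj", location="weight_input", inverse=True),
    ],
)
assert scheme.head_dim == 128
```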
diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py
new file mode 100644
index 00000000..d3f95fd0
--- /dev/null
+++ b/src/compressed_tensors/transform/utils/matrix.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_transform_size", "apply_transform_weight"]
+
+
+def get_transform_size(
+    module: torch.nn.Module,
+    location: TransformLocation,
+    head_dim: Optional[int] = None,
+) -> int:
+    """
+    Determine the size of a transform matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :param head_dim: size of head when the transform is applied per attention head
+    :return: size of matrix
+    """
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.in_features
+        else:
+            size = module.out_features
+    else:
+        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+
+    if head_dim is not None:
+        if size % head_dim != 0:
+            raise ValueError(
+                f"{head_dim} must divide {size} for {type(module)} at {location}"
+            )
+
+        size = head_dim
+
+    return size
+
+
+def apply_transform_weight(
+    weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+    module_type: type[torch.nn.Module],
+) -> torch.Tensor:
+    """
+    :param weight: transform weight to apply
+    :param value: value to apply weight to
+    :param location: determines how weight should be applied
+    :param module_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after weight has been applied
+    """
+    fn, axis = _get_transform_method(module_type, location)
+
+    assert weight.shape[0] == weight.shape[1]
+    head_dim = weight.shape[0]
+    num_heads = value.shape[axis] // head_dim
+
+    value = value.unflatten(axis, (num_heads, head_dim))
+    value = fn(weight, value)
+    value = value.flatten(axis - 1, axis)
+
+    return value
+
+
+def _get_transform_method(
+    module_type: type[torch.nn.Module],
+    location: TransformLocation,
+) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
+    """
+    Using the module type and transform location, determine how the transform
+    weight should be applied to a value. For more info on input and output
+    transforms, see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x          be input activation
+         W          be weight,
+         yh, xh, Wh be transformed output, input, weight
+
+    note that
+         y  = (x W.T)       // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi      be transform matrices on input side
+         U, Ui      be transform matrices on output side
+
+    pick xh = (x V)
+         Wh = (U.T W Vi.T)
+         yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)    // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param module_type: type of module the transform weight will be applied to
+    :param location: determines how the weight should be applied
+    :return: function that applies the transform weight to a value, and the axis
+        along which the value is unflattened into heads
+    """
+    fn = axis = None
+
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.INPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            fn = lambda weight, value: value @ weight.T
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            fn = lambda weight, value: weight.T @ value
+            axis = -2
+
+        elif location == TransformLocation.OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+    if fn is None:
+        raise NotImplementedError(
+            f"Applying transforms to {module_type} {location} is not supported"
+        )
+
+    return fn, axis
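
The `yh = (xh) (Wh).T` derivation in the docstring can be spot-checked numerically; a standalone sketch (not part of the patch), picking an arbitrary invertible V and any U:

```python
# Numeric spot-check of the docstring identity yh = (xh) (Wh).T
import torch

torch.manual_seed(0)
x = torch.randn(3, 8, dtype=torch.float64)  # input activations
W = torch.randn(6, 8, dtype=torch.float64)  # Linear weight, y = x @ W.T
V = torch.randn(8, 8, dtype=torch.float64)  # input-side transform (invertible w.h.p.)
U = torch.randn(6, 6, dtype=torch.float64)  # output-side transform
Vi = torch.linalg.inv(V)

y = x @ W.T
xh = x @ V           # transformed input
Wh = U.T @ W @ Vi.T  # transformed weight
yh = y @ U           # transformed output

assert torch.allclose(xh @ Wh.T, yh)
```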
diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py
deleted file mode 100644
index e60d24dc..00000000
--- a/src/compressed_tensors/transform/utils/utils.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from compressed_tensors.transform import TransformLocation
-
-
-__all__ = ["get_matrix_size", "apply_transform_weight"]
-
-
-def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
-    """
-    Determine the size of a matrix given its location on the module
-
-    :param module: module that matrix will be applied to
-    :param location: location on module
-    :return: size of matrix
-    """
-    assert isinstance(module, torch.nn.Linear)
-    if location in ("input", TransformLocation.WEIGHT_INPUT):
-        return module.in_features
-    else:
-        return module.out_features
-
-
-def apply_transform_weight(
-    weight: torch.Tensor,
-    value: torch.Tensor,
-    location: TransformLocation,
-) -> torch.Tensor:
-    """
-    Using the transform location, determine how to apply the transform weight to the
-    given value. For more info on input and output transforms, see `TransformLocation`
-
-    The following explains how weights should be applied to values according to location
-
-    let  x          be input activation
-         W          be weight,
-         yh, xh, Wh be transformed output, input, weight
-
-    note that
-         y  = (x W.T)       // torch.nn.Linear
-
-    Choose values for yh, xh, and Wh which incorporate matrix transforms
-
-    let  V, Vi      be transform matrices on input side
-         U, Ui      be transform matrices on output side
-
-    pick xh = (x V)
-         Wh = (U.T W Vi.T)
-         yh = (y U)
-
-    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
-
-    (xh) (Wh).T = (x V) (U.T W Vi.T).T
-                = (x V) (Vi W.T U)    // transpose matrix product identity
-                = (x W.T) U
-                = y U
-                = yh
-
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
-    """
-
-    if location == TransformLocation.INPUT:
-        return value @ weight
-
-    elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
-
-    elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
-
-    elif location == TransformLocation.OUTPUT:
-        return value @ weight
-
-    else:
-        raise NotImplementedError(f"{location} has not been implemented yet")
diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py
index 2067a647..83756ffd 100644
--- a/tests/test_transform/conftest.py
+++ b/tests/test_transform/conftest.py
@@ -33,6 +33,67 @@ def forward(self, x):
         return x
 
 
+class MockAttention(torch.nn.Module):
+    def __init__(
+        self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
+        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        assert hidden_size >= num_attention_heads * self.head_dim
+
+        self.q_proj = torch.nn.Linear(
+            hidden_size, num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = torch.nn.Linear(
+            hidden_size, num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = torch.nn.Linear(
+            hidden_size, num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = torch.nn.Linear(
+            num_attention_heads * self.head_dim, hidden_size, bias=False
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, hidden_size = hidden_states.shape
+        hidden_shape = (batch_size, seq_len, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        key_states = self.repeat_kv(key_states, self.num_key_value_groups)
+        value_states = self.repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = (
+            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+        )
+
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape((batch_size, seq_len, -1)).contiguous()
+
+        return self.o_proj(attn_output)
+
+    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(
+            batch, num_key_value_heads, n_rep, slen, head_dim
+        )
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 @pytest.fixture(scope="function")
 def model_apply():
     model = TransformableModel(2, 4, 8, 16, 32, 64)
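
`MockAttention` follows the Hugging Face-style grouped-query attention pattern, where `repeat_kv` tiles each KV head `n_rep` times along the head axis. A standalone shape walk-through (dimensions invented for illustration):

```python
# Shape walk-through of the repeat_kv tiling used by MockAttention
import torch

batch, n_kv_heads, n_rep, seq, head_dim = 2, 4, 2, 5, 8
kv = torch.randn(batch, n_kv_heads, seq, head_dim)

tiled = kv[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq, head_dim)
out = tiled.reshape(batch, n_kv_heads * n_rep, seq, head_dim)

assert out.shape == (2, 8, 5, 8)
# copies of the same KV head stay adjacent, so per-head transforms survive tiling
assert torch.equal(out[:, 0], out[:, 1])
```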
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index b34ca51a..b9f18f0c 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -22,15 +22,17 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
+from tests.test_transform.conftest import MockAttention
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_linear(type, randomized):
+@pytest.mark.parametrize("head_dim", (None, 2, 4))
+def test_correctness_linear(type, randomized, head_dim):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
-    scheme = TransformScheme(type=type, randomized=randomized)
+    scheme = TransformScheme(type=type, randomize=randomized, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -46,7 +48,7 @@ def test_correctness_linear(type, randomized, head_dim):
         module, TransformArgs(targets="Linear", location="output", inverse=True)
     )
 
-    input = torch.rand((17, size[0]))
+    input = torch.rand((17, 5, size[0]))
     true_output = input @ module.weight.T
     input_transformed = input_tfm(input)
     weight_transformed = w_out_tfm(w_in_tfm(module.weight))
@@ -63,7 +65,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
         model = offloaded_dispatch(model, torch.device("cuda"))
 
     # get output
-    input = torch.rand((17, model.fcs[0].in_features))
+    input = torch.rand((17, 5, model.fcs[0].in_features))
     if offload:
         input = input.to(torch.device("cuda"))
     true_output = model(input)
@@ -87,3 +89,40 @@
 @pytest.mark.parametrize("randomized", (True, False))
 def test_correctness_model_offload(type, randomized, model_apply):
     test_correctness_model(type, randomized, model_apply, offload=True)
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("head_dim", (4, 8))
+def test_correctness_attention_heads(type, randomized, head_dim):
+    hidden_size = 64
+    num_attention_heads = 8
+
+    attention = MockAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=head_dim,
+    )
+
+    input = torch.rand(17, 5, hidden_size)
+    true_output = attention(input)
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomize=randomized,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="v_proj", location="weight_output"),
+                    TransformArgs(
+                        targets="o_proj", location="weight_input", inverse=True
+                    ),
+                ],
+            )
+        }
+    )
+    apply_transform_config(attention, config)
+
+    output = attention(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
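
The invariance the new test asserts can also be seen directly with plain Linear layers: a per-head rotation fused into `v_proj`'s weight output cancels against its inverse fused into `o_proj`'s weight input. A standalone sketch of that cancellation (all names local to the example; an orthogonal Q stands in for the Hadamard, so Q^-1 = Q.T):

```python
# Why the v_proj/o_proj transform pair is output-preserving (standalone sketch)
import torch

hidden, head_dim = 16, 4
num_heads = hidden // head_dim
v_proj = torch.nn.Linear(hidden, hidden, bias=False)
o_proj = torch.nn.Linear(hidden, hidden, bias=False)

x = torch.randn(3, hidden)
true_out = o_proj(v_proj(x))

Q, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))  # orthogonal per-head transform
B = torch.block_diag(*[Q] * num_heads)                   # applied to every head

with torch.no_grad():
    v_proj.weight.copy_(B.T @ v_proj.weight)  # location="weight_output"
    o_proj.weight.copy_(o_proj.weight @ B)    # location="weight_input", inverse

assert torch.allclose(true_out, o_proj(v_proj(x)), atol=1e-5)
```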