# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformLocation


__all__ = ["get_transform_size", "apply_transform_weight"]


def get_transform_size(
    module: torch.nn.Module,
    location: TransformLocation,
    head_dim: Optional[int] = None,
) -> int:
    """
    Determine the size of a transform matrix given its location on the module

    :param module: module that the transform matrix will be applied to
    :param location: location of the transform on the module
    :param head_dim: size of each head when the transform is applied to
        multi-headed attention
    :return: side length of the (square) transform matrix
    """
    if isinstance(module, torch.nn.Linear):
        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
            size = module.in_features
        else:
            size = module.out_features
    elif isinstance(module, torch.nn.Embedding):
        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
            size = module.num_embeddings
        else:
            size = module.embedding_dim
    else:
        raise NotImplementedError(f"Transforms on {type(module)} are not supported")

    if head_dim is not None:
        if size % head_dim != 0:
            raise ValueError(
                f"{head_dim} must divide {size} for {type(module)} at {location}"
            )

        size = head_dim

    return size

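# Example (illustrative, not executed): sizing a per-head transform for a
# multi-headed attention projection, assuming a hypothetical head_dim of 128:
#
#   proj = torch.nn.Linear(4096, 4096)
#   get_transform_size(proj, TransformLocation.WEIGHT_INPUT, head_dim=128)
#   # -> 128 (one 128x128 transform applied block-wise across 32 heads)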

def apply_transform_weight(
    transform_weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
    module_type: type[torch.nn.Module],
) -> torch.Tensor:
    """
    Using the transform location, apply the transform_weight to the
    given value with respect to linear weights. For more info on input
    and output transforms, see `TransformLocation`

    The following explains how transform weights should be applied to
    values according to location (a numerical check of this derivation
    is sketched in comments below this function)

    let x be the input activation
        W be the weight
        yh, xh, Wh be the transformed output, input, and weight

    note that
        y = (x W.T)  // torch.nn.Linear

    Choose values for yh, xh, and Wh which incorporate matrix transforms

    let V, Vi be transform matrices on the input side
        U, Ui be transform matrices on the output side

    pick xh = (x V)
         Wh = (U.T W Vi.T)
         yh = (y U)

    The following shows that `yh = (xh) (Wh).T` for the chosen values of
    yh, xh, and Wh

    (xh) (Wh).T = (x V) (U.T W Vi.T).T
                = (x V) (Vi W.T U)  // transpose of a matrix product
                = (x W.T) U         // V Vi = I
                = y U
                = yh

    :param transform_weight: transform weight to apply
    :param value: value to apply transform_weight to
    :param location: determines how the weight should be applied
    :param module_type: result of type(module), used to determine how the
        weight transform is applied
    :return: value after transform_weight has been applied
    """

    # transform weights must be square
    assert transform_weight.shape[0] == transform_weight.shape[1]

    if module_type == torch.nn.Linear:
        if location == TransformLocation.INPUT:
            return _multihead_matmul(value, transform_weight)

        elif location == TransformLocation.WEIGHT_INPUT:
            # equivalent to (transform_weight @ value.T).T
            return _multihead_matmul(value, transform_weight.T)

        elif location == TransformLocation.WEIGHT_OUTPUT:
            # equivalent to (value.T @ transform_weight).T
            return _multihead_matmul(transform_weight.T, value)

        elif location == TransformLocation.OUTPUT:
            return _multihead_matmul(value, transform_weight)

    # similar derivation to torch.nn.Linear, but `y = (x W)`
    elif module_type == torch.nn.Embedding:
        if location == TransformLocation.INPUT:
            return _multihead_matmul(value, transform_weight)

        elif location == TransformLocation.WEIGHT_INPUT:
            return _multihead_matmul(transform_weight, value)

        elif location == TransformLocation.WEIGHT_OUTPUT:
            return _multihead_matmul(value, transform_weight)

        elif location == TransformLocation.OUTPUT:
            return _multihead_matmul(value, transform_weight)

    raise NotImplementedError(
        f"Applying transforms to {module_type} {location} is not supported"
    )

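# Hedged numerical check of the docstring derivation (not executed; names and
# shapes are illustrative). With orthogonal V and U, so that Vi = V.inverse()
# equals V.T, the transformed pipeline reproduces y U:
#
#   x, W = torch.randn(3, 8), torch.randn(6, 8)  # Linear weight is (out, in)
#   V = torch.linalg.qr(torch.randn(8, 8))[0]
#   U = torch.linalg.qr(torch.randn(6, 6))[0]
#   xh = apply_transform_weight(V, x, TransformLocation.INPUT, torch.nn.Linear)
#   Wh = apply_transform_weight(
#       U,
#       apply_transform_weight(V.T, W, TransformLocation.WEIGHT_INPUT, torch.nn.Linear),
#       TransformLocation.WEIGHT_OUTPUT,
#       torch.nn.Linear,
#   )
#   yh = apply_transform_weight(U, x @ W.T, TransformLocation.OUTPUT, torch.nn.Linear)
#   assert torch.allclose(xh @ Wh.T, yh, atol=1e-5)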

def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Performs A @ B over the last two dims of two matrices A and B whose
    shapes may differ, as is the case with multi-headed attention. If the
    shapes differ, this is equivalent to converting the last two dims of
    the smaller matrix into a block-diagonal matrix with the same shape as
    the last two dims of the larger matrix.

    E.g. if A is half the size of B, this function will perform
        [[A ]  @ B
         [ A]]

    If B is a third of the size of A, this function will perform
        A @ [[B  ]
             [ B ]
             [  B]]

    This function will error out if the shapes are not evenly divisible

    :param A: left-hand tensor
    :param B: right-hand tensor
    :return: result
    """
    if A.shape[-1] > B.shape[-2]:
        # B is smaller: apply B block-wise across the heads of A
        head_dim = B.shape[-2]
        num_heads = A.shape[-1] // head_dim
        A = A.unflatten(-1, (num_heads, head_dim))
        return (A @ B).flatten(-2, -1)
    elif A.shape[-1] < B.shape[-2]:
        # A is smaller: apply A block-wise across the heads of B
        head_dim = A.shape[-1]
        num_heads = B.shape[-2] // head_dim
        B = B.unflatten(-2, (num_heads, head_dim))
        return (A @ B).flatten(-3, -2)
    else:
        return A @ B
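
# Hedged sketch (not executed; shapes are illustrative) of the block-diagonal
# equivalence described in the docstring, checked against torch.block_diag:
#
#   A, B = torch.randn(4, 6), torch.randn(2, 5)  # B tiles 3x on the diagonal
#   expected = A @ torch.block_diag(B, B, B)     # (4, 6) @ (6, 15) -> (4, 15)
#   assert torch.allclose(_multihead_matmul(A, B), expected)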