Skip to content

Commit 4ae491d

Browse files
committed
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_permutations
2 parents 779956f + dd72b6a commit 4ae491d

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

src/compressed_tensors/transform/transform_args.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,31 @@
1313
# limitations under the License.
1414

1515
from enum import Enum
16-
from typing import Any, List
16+
from typing import List
1717

1818
from pydantic import BaseModel, Field, field_validator
1919

2020

21-
__all__ = ["TransformLocation", "TransformArgs"]
21+
__all__ = ["TransformArgs", "TransformLocation"]
2222

2323

2424
class TransformLocation(str, Enum):
25+
"""
26+
Enum representing which parameters/activations a transform weight should be applied
27+
to on a given module.
28+
29+
| -------------------------------------------------------------------------------------------------------- | # noqa: E501
30+
| Name | Runtime | Values | Locations Where Inverse Could Be Applied | # noqa: E501
31+
| --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501
32+
| `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501
33+
| `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501
34+
| `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501
35+
| `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501
36+
| `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501
37+
| `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501
38+
| -------------------------------------------------------------------------------------------------------- | # noqa: E501
39+
"""
40+
2541
INPUT = "input"
2642
WEIGHT_INPUT = "weight_input"
2743
WEIGHT_OUTPUT = "weight_output"

src/compressed_tensors/transform/utils/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,28 @@ def apply_transform_weight(
4141
) -> torch.Tensor:
4242
"""
4343
Using the transform location, determine how to apply the transform weight to the
44-
given value
44+
given value. For more info on input and output transforms, see `TransformLocation`
45+
46+
The following explains how weights should be applied to values according to location
4547
4648
let x be input activation
4749
W be weight,
4850
yh, xh, Wh be transformed output, input, weight
4951
5052
note that
5153
y = (x W.T) // torch.nn.Linear
52-
yh = (xh) (Wh).T // transformed
54+
55+
Choose values for yh, xh, and Wh which incorporate matrix transforms
5356
5457
let V, Vi be transform matrices on input side
5558
U, Ui be transform matrices on output side
5659
57-
show that the following values for yh, xh, and Wh are consistent
58-
5960
pick xh = (x V)
6061
Wh = (U.T W Vi.T)
6162
yh = (y U)
6263
64+
The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
65+
6366
(xh) (Wh).T = (x V) (U.T W Vi.T).T
6467
= (x V) (Vi W.T U) // transpose matrix product identity
6568
= (x W.T) U
@@ -83,3 +86,6 @@ def apply_transform_weight(
8386

8487
elif location == TransformLocation.OUTPUT:
8588
return value @ weight
89+
90+
else:
91+
raise NotImplementedError(f"{location} has not been implemented yet")

0 commit comments

Comments
 (0)