
Commit d330744

[Transform] Hadamard and Matrix Transform Utils (#330)
* add utils and tests
* standardize random hadamard
* apply sqrt division first
* fix typo
* add random option
* update docstrings
* update docstrings
* update docstring

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent d7ce8ec commit d330744

7 files changed: +461 additions, -3 deletions

src/compressed_tensors/transform/transform_args.py

Lines changed: 18 additions & 2 deletions
@@ -13,15 +13,31 @@
 # limitations under the License.

 from enum import Enum
-from typing import Any, List
+from typing import List

 from pydantic import BaseModel, Field, field_validator


-__all__ = ["TransformArgs"]
+__all__ = ["TransformArgs", "TransformLocation"]


 class TransformLocation(str, Enum):
+    """
+    Enum representing which parameters/activations a transform weight should be applied
+    to on a given module.
+
+    | ------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    | Name            | Runtime | Values       | Locations Where Inverse Could Be Applied                      |  # noqa: E501
+    | --------------- | ------- | ------------ | ------------------------------------------------------------- |  # noqa: E501
+    | `INPUT`         | online  | activations  | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT`      |  # noqa: E501
+    | `WEIGHT_INPUT`  | offline | weight       | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT`             |  # noqa: E501
+    | `WEIGHT_OUTPUT` | offline | weight       | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`              |  # noqa: E501
+    | `OUTPUT`        | online  | activations  | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`       |  # noqa: E501
+    | `K_CACHE`       | online  | key_values   | `q_proj.Q_ATTN`                                               |  # noqa: E501
+    | `Q_ATTN`        | online  | query_values | `k_proj.K_CACHE`                                              |  # noqa: E501
+    | ------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    """
+
     INPUT = "input"
     WEIGHT_INPUT = "weight_input"
     WEIGHT_OUTPUT = "weight_output"
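Usage note (not part of the diff): because `TransformLocation` subclasses `str`, members compare equal to their plain-string values, which keeps serialized configs interchangeable with enum members. A minimal sketch, assuming the package is installed and the enum is importable from `compressed_tensors.transform` (as the matrix utilities later in this commit do):

import compressed_tensors
from compressed_tensors.transform import TransformLocation

# str-valued enum members compare equal to their serialized form
assert TransformLocation.WEIGHT_INPUT == "weight_input"
assert TransformLocation("weight_output") is TransformLocation.WEIGHT_OUTPUT

# per the docstring table: an offline WEIGHT_INPUT transform can be inverted
# online at this module's INPUT, or fused into the previous module's weight/output
print(TransformLocation.WEIGHT_INPUT.value)  # -> "weight_input"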
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import numpy
import torch


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]


# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
    """
    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
    `n` must be a power of 2.

    :param size: order of the matrix, must be a power of 2
    :return: hadamard matrix of size `size`
    """
    if size <= 0:
        raise ValueError("Cannot construct deterministic hadamard of size <= 0")

    log2 = int(math.log(size, 2))
    if size != 2**log2:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

    H = numpy.array([[1]], dtype=int)

    # Sylvester's construction
    for i in range(0, log2):
        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

    return torch.from_numpy(H / math.sqrt(size))


# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def random_hadamard_matrix(
    size: int, gen: Optional[torch.Generator] = None
) -> torch.Tensor:
    """
    Produces a randomly generated Hadamard matrix.
    See https://cornell-relaxml.github.io/quip-sharp/ ,
    Section "Randomized Hadamard Transformation"

    :param size: The dimension of the hadamard matrix
    :param gen: Optional generator for random values
    :return: randomly generated hadamard matrix
    """
    # Benefits: support other shapes / non powers of 2, support randomization
    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q) / math.sqrt(size)


def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
    # NOTE: we can easily extend the list of supported shapes/sizes
    # by adding to these methods
    hadK, K = None, None
    if n % 20 == 0:
        assert _is_pow2(n // 20)
        K = 20
        hadK = _get_had20().T if transpose else _get_had20()
    elif n % 12 == 0:
        assert _is_pow2(n // 12)
        K = 12
        hadK = _get_had12().T if transpose else _get_had12()
    else:
        assert _is_pow2(n)
        K = 1

    return hadK, K


def _matmul_hadU(X, transpose=False) -> torch.Tensor:
    n = X.shape[-1]
    # Check if we have the determined hadamard matrix
    hadK, K = _get_hadK(n, transpose)
    # Reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, n, 1)
    output = input.clone()

    # for cases when hadK is not predetermined, determine hadamard matrix
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    del output

    # K == 1 when hadK is None; this happens when the size dim (n)
    # is not compatible with any of the maintained hadamard matrices

    if K > 1:
        # Do not explicitly repeat - OOM
        # input = torch.bmm(
        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
        # Use bcast instead

        # for cases when hadK is pre-determined
        input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape)


def _is_pow2(n: int) -> bool:
    return (n & (n - 1) == 0) and (n > 0)


def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
    had_unpacked = numpy.unpackbits(packed_bits)
    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
    return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12() -> torch.Tensor:
    # fmt: off
    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
                          62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_12_unpacked = _reshape_bits(had_12, original_size=12)
    return torch.tensor(had_12_unpacked)


def _get_had20() -> torch.Tensor:
    # fmt: off
    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
                          216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
                          97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
                          205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_20_unpacked = _reshape_bits(had_20, original_size=20)
    return torch.tensor(had_20_unpacked)
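A short usage sketch for the two public constructors above. Both return matrices already scaled by 1/sqrt(size), so `H @ H.T` should be close to the identity. The import path is an assumption, since the diff view omits this new file's name:

import torch

# NOTE: assumed module path; only the file contents appear in the diff view
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Sylvester construction: size must be a power of 2
H = deterministic_hadamard_matrix(64)
assert torch.allclose(H @ H.T, torch.eye(64, dtype=H.dtype), atol=1e-8)

# Randomized construction: also supports sizes of the form (12 or 20) * 2^n,
# with an optional torch.Generator for reproducibility
gen = torch.Generator().manual_seed(0)
R = random_hadamard_matrix(768, gen=gen)  # 768 = 12 * 64
assert torch.allclose(R @ R.T, torch.eye(768, dtype=R.dtype), atol=1e-8)
assert torch.allclose(R, random_hadamard_matrix(768, gen=torch.Generator().manual_seed(0)))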
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformLocation


__all__ = ["get_matrix_size", "apply_transform_weight"]


def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
    """
    Determine the size of a matrix given its location on the module

    :param module: module that matrix will be applied to
    :param location: location on module
    :return: size of matrix
    """
    assert isinstance(module, torch.nn.Linear)
    if location in ("input", TransformLocation.WEIGHT_INPUT):
        return module.in_features
    else:
        return module.out_features


def apply_transform_weight(
    weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
) -> torch.Tensor:
    """
    Using the transform location, determine how to apply the transform weight to the
    given value. For more info on input and output transforms, see `TransformLocation`

    The following explains how weights should be applied to values according to location

    let  x  be input activation
         W  be weight,
         yh, xh, Wh be transformed output, input, weight

    note that
        y  = (x W.T)  // torch.nn.Linear

    Choose values for yh, xh, and Wh which incorporate matrix transforms

    let  V, Vi  be transform matrices on input side
         U, Ui  be transform matrices on output side

    pick  xh = (x V)
          Wh = (U.T W Vi.T)
          yh = (y U)

    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh

    (xh) (Wh).T = (x V) (U.T W Vi.T).T
                = (x V) (Vi W.T U)  // transpose matrix product identity
                = (x W.T) U
                = y U
                = yh

    :param weight: transform weight to apply
    :param value: value to apply weight to
    :param location: determines how weight should be applied
    :return: value after transform weight has been applied
    """

    if location == TransformLocation.INPUT:
        return value @ weight

    elif location == TransformLocation.WEIGHT_INPUT:
        return value @ weight.T

    elif location == TransformLocation.WEIGHT_OUTPUT:
        return weight.T @ value

    elif location == TransformLocation.OUTPUT:
        return value @ weight

    else:
        raise NotImplementedError(f"{location} has not been implemented yet")
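As a sanity check of the docstring derivation, the sketch below fuses Hadamard rotations into a Linear weight offline and applies the matching online rotations to the activations and output; the layer's function is preserved. The import paths for the new utility modules are assumptions, since the diff view omits the new files' names:

import torch

from compressed_tensors.transform import TransformLocation
# NOTE: assumed module paths; only the file contents appear in the diff view
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from compressed_tensors.transform.utils.matrix import apply_transform_weight, get_matrix_size

torch.manual_seed(0)
linear = torch.nn.Linear(64, 32, bias=False, dtype=torch.float64)
x = torch.randn(8, 64, dtype=torch.float64)
y = linear(x)  # y = x W.T

# input-side (V) and output-side (U) rotations; normalized Sylvester Hadamards
# are symmetric and orthogonal, so each serves as its own inverse
V = deterministic_hadamard_matrix(get_matrix_size(linear, TransformLocation.WEIGHT_INPUT))   # 64
U = deterministic_hadamard_matrix(get_matrix_size(linear, TransformLocation.WEIGHT_OUTPUT))  # 32

xh = apply_transform_weight(V, x, TransformLocation.INPUT)                     # x V
Wh = apply_transform_weight(V, linear.weight, TransformLocation.WEIGHT_INPUT)  # W V.T
Wh = apply_transform_weight(U, Wh, TransformLocation.WEIGHT_OUTPUT)            # U.T W V.T
yh = apply_transform_weight(U, y, TransformLocation.OUTPUT)                    # y U

# matches the docstring derivation: (xh) (Wh).T == yh
assert torch.allclose(xh @ Wh.T, yh, atol=1e-8)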

src/compressed_tensors/utils/helpers.py

Lines changed: 53 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -38,6 +39,8 @@
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "patch_attr",
+    "ParameterizedDefaultDict",
 ]

 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +331,53 @@ def unpack_bitmasks(
     )

     return unpacked_bitmasks_torch
+
+
+@contextlib.contextmanager
+def patch_attr(base: object, attr: str, value: Any):
+    """
+    Patch the value of an object attribute. Original value is restored upon exit
+
+    :param base: object which has the attribute to patch
+    :param attr: name of the attribute to patch
+    :param value: used to replace original value
+
+    Usage:
+    >>> from types import SimpleNamespace
+    >>> obj = SimpleNamespace()
+    >>> with patch_attr(obj, "attribute", "value"):
+    ...     assert obj.attribute == "value"
+    >>> assert not hasattr(obj, "attribute")
+    """
+    _sentinel = object()
+    original_value = getattr(base, attr, _sentinel)
+
+    setattr(base, attr, value)
+    try:
+        yield
+    finally:
+        if original_value is not _sentinel:
+            setattr(base, attr, original_value)
+        else:
+            delattr(base, attr)
+
+
+class ParameterizedDefaultDict(dict):
+    """
+    Similar to `collections.DefaultDict`, but upon fetching a key which is missing,
+    the key is passed as arguments to the `default_factory`
+
+    :param default_factory: function which takes a key as input and returns the
+        corresponding default value
+    """
+
+    def __init__(self, default_factory: Callable[[Any], Any]):
+        self.default_factory = default_factory
+
+    def __missing__(self, key):
+        if isinstance(key, tuple):
+            value = self.default_factory(*key)
+        else:
+            value = self.default_factory(key)
+        self[key] = value
+        return value
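Quick illustrations of the two helpers added above (a minimal sketch; the example class and factory functions are arbitrary):

from compressed_tensors.utils.helpers import ParameterizedDefaultDict, patch_attr

# ParameterizedDefaultDict forwards the missing key (or tuple of keys) to the
# factory, which is convenient for caching values built from their key
squares = ParameterizedDefaultDict(lambda n: n * n)
assert squares[12] == 144          # built on first access, then cached
assert 12 in squares

sums = ParameterizedDefaultDict(lambda a, b: a + b)
assert sums[(3, 4)] == 7           # tuple keys are unpacked into factory arguments

# patch_attr temporarily overrides an attribute and restores it on exit
class Config:
    precision = "fp16"

cfg = Config()
with patch_attr(cfg, "precision", "fp64"):
    assert cfg.precision == "fp64"
assert cfg.precision == "fp16"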
