
Commit 9d0518b: add utils and tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 859fd90 commit 9d0518b

File tree: 6 files changed (+417, -1 lines)
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import numpy
import torch


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]


# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
def deterministic_hadamard_matrix(size: int) -> numpy.ndarray:
    """
    Construct an n-by-n Hadamard matrix using Sylvester's construction.
    `n` must be a power of 2.

    :param size: order of the matrix; must be a power of 2
    :return: a (size, size) Hadamard matrix
    """
    if size <= 0:
        raise ValueError("Cannot construct deterministic hadamard of size <= 0")

    log2 = int(math.log(size, 2))
    if size != 2**log2:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

    H = numpy.array([[1]], dtype=int)

    # Sylvester's construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
    for _ in range(log2):
        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

    return H

# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def random_hadamard_matrix(size: int) -> torch.Tensor:
    """
    Produces a randomly generated Hadamard matrix.
    See https://cornell-relaxml.github.io/quip-sharp/ ,
    Section "Randomized Hadamard Transformation"

    :param size: the dimension of the matrix; the generated matrix will have
        dimensions (size, size)
    """
    # TODO: potentially add a "seed" argument to allow the generated matrix
    # to be reproducible

    # Benefits: supports other shapes / non powers of 2, supports randomization
    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
    Q = Q * 2 - 1  # map {0, 1} to {-1, +1}
    Q = torch.diag(Q)
    return _matmul_hadU(Q)

def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
    # NOTE: we can easily extend the list of supported shapes/sizes
    # by adding to these methods
    hadK, K = None, None
    if n % 20 == 0:
        assert _is_pow2(n // 20)
        K = 20
        hadK = _get_had20().T if transpose else _get_had20()
    elif n % 12 == 0:
        assert _is_pow2(n // 12)
        K = 12
        hadK = _get_had12().T if transpose else _get_had12()
    else:
        assert _is_pow2(n)
        K = 1

    return hadK, K

def _matmul_hadU(X, transpose=False) -> torch.Tensor:
    n = X.shape[-1]
    # check if we have a predetermined hadamard matrix for this size
    hadK, K = _get_hadK(n, transpose)
    # reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, n, 1)
    output = input.clone()

    # for cases when hadK is not predetermined, determine the hadamard matrix
    # via a butterfly of additions and subtractions
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    del output

    # K == 1 when hadK is None; this happens when the size dim (n)
    # is not compatible with any of the maintained hadamard matrices

    if K > 1:
        # Do not explicitly repeat - OOM
        # input = torch.bmm(
        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
        # Use broadcasting instead

        # for cases when hadK is pre-determined
        input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape) / torch.tensor(n).sqrt()


def _is_pow2(n: int) -> bool:
    return (n & (n - 1) == 0) and (n > 0)

def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
    had_unpacked = numpy.unpackbits(packed_bits)
    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
    return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12() -> torch.Tensor:
    # fmt: off
    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_12_unpacked = _reshape_bits(had_12, original_size=12)
    return torch.tensor(had_12_unpacked)


def _get_had20() -> torch.Tensor:
    # fmt: off
    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_20_unpacked = _reshape_bits(had_20, original_size=20)
    return torch.tensor(had_20_unpacked)
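
As a quick reference for how these utilities behave, here is a minimal sanity-check sketch (not part of the commit; the sizes are illustrative). Both constructions satisfy the defining Hadamard property H @ H.T == n * I, and random_hadamard_matrix additionally normalizes by sqrt(n), so its output is orthonormal:

import numpy
import torch
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# unnormalized +/-1 entries: H @ H.T == size * I
H = deterministic_hadamard_matrix(64)
assert numpy.array_equal(H @ H.T, 64 * numpy.eye(64))

# normalized by sqrt(size): Q @ Q.T == I
# 768 = 12 * 64, so this exercises the packed had12 branch of _get_hadK
Q = random_hadamard_matrix(768)
assert torch.allclose(Q @ Q.T, torch.eye(768, dtype=Q.dtype), atol=1e-8)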
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformLocation


__all__ = ["get_matrix_size", "apply_transform_weight"]

def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
    """
    Determine the size of a matrix given its location on the module

    :param module: module that matrix will be applied to
    :param location: location on module
    :return: size of matrix
    """
    assert isinstance(module, torch.nn.Linear)
    if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
        return module.in_features
    else:
        return module.out_features

def apply_transform_weight(
    weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
) -> torch.Tensor:
    """
    Using the transform location, determine how to apply the transform weight to the
    given value

    Let x be the input activation, W the weight, and
    yh, xh, Wh the transformed output, input, and weight.

    Note that
        y = x W.T       # torch.nn.Linear
        yh = xh (Wh).T  # transformed

    Let V, Vi be the transform matrix and its inverse on the input side, and
    U, Ui be the transform matrix and its inverse on the output side.

    The following choices of yh, xh, and Wh are consistent:

        xh = x V
        Wh = U.T W Vi.T
        yh = y U

        (xh) (Wh).T = (x V) (U.T W Vi.T).T
                    = (x V) (Vi W.T U)  # transpose of a matrix product
                    = (x W.T) U         # V Vi = I
                    = y U
                    = yh

    :param weight: transform weight to apply
    :param value: value to apply weight to
    :param location: determines how weight should be applied
    :return: value after transform weight has been applied
    """

    if location == TransformLocation.INPUT:
        return value @ weight

    elif location == TransformLocation.WEIGHT_INPUT:
        return value @ weight.T

    elif location == TransformLocation.WEIGHT_OUTPUT:
        return weight.T @ value

    elif location == TransformLocation.OUTPUT:
        return value @ weight

    else:
        raise ValueError(f"Unsupported transform location: {location}")
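
To make the docstring derivation above concrete, here is a small numeric check (a sketch, not part of the commit). It assumes the new file is importable as compressed_tensors.transform.utils.matrix, which the diff view does not confirm, and uses orthogonal transforms so that Vi = V.T:

import torch
from compressed_tensors.transform import TransformLocation
# assumed module path for the new file shown above
from compressed_tensors.transform.utils.matrix import apply_transform_weight

torch.manual_seed(0)
x = torch.randn(2, 4)                      # input activations
W = torch.randn(3, 4)                      # Linear weight (out=3, in=4)
V = torch.linalg.qr(torch.randn(4, 4)).Q   # orthogonal, so Vi = V.T
U = torch.linalg.qr(torch.randn(3, 3)).Q   # orthogonal, so Ui = U.T

y = x @ W.T                                                 # torch.nn.Linear
xh = apply_transform_weight(V, x, TransformLocation.INPUT)  # x V
Wh = U.T @ W @ V                                            # U.T W Vi.T, with Vi.T = V
yh = xh @ Wh.T

assert torch.allclose(yh, y @ U, atol=1e-5)                 # yh == y U, as derived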

src/compressed_tensors/utils/helpers.py

Lines changed: 53 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -38,6 +39,8 @@
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "patch_attr",
+    "ParameterizedDefaultDict",
 ]

 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +331,53 @@ def unpack_bitmasks(
     )

     return unpacked_bitmasks_torch
+
+
+@contextlib.contextmanager
+def patch_attr(base: object, attr: str, value: Any):
+    """
+    Patch the value of an object attribute. Original value is restored upon exit
+
+    :param base: object which has the attribute to patch
+    :param attr: name of the attribute to patch
+    :param value: used to replace original value
+
+    Usage:
+    >>> from types import SimpleNamespace
+    >>> obj = SimpleNamespace()
+    >>> with patch_attr(obj, "attribute", "value"):
+    ...     assert obj.attribute == "value"
+    >>> assert not hasattr(obj, "attribute")
+    """
+    _sentinel = object()
+    original_value = getattr(base, attr, _sentinel)
+
+    setattr(base, attr, value)
+    try:
+        yield
+    finally:
+        if original_value is not _sentinel:
+            setattr(base, attr, original_value)
+        else:
+            delattr(base, attr)
+
+
+class ParameterizedDefaultDict(dict):
+    """
+    Similar to `collections.defaultdict`, but upon fetching a missing key, the
+    key is passed as arguments to the `default_factory` (a tuple key is
+    unpacked into positional arguments)
+
+    :param default_factory: function which takes a key as input and returns the
+        corresponding default value
+    """
+
+    def __init__(self, default_factory: Callable[[Any], Any]):
+        self.default_factory = default_factory
+
+    def __missing__(self, key):
+        if isinstance(key, tuple):
+            value = self.default_factory(*key)
+        else:
+            value = self.default_factory(key)
+        self[key] = value
+        return value
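
A short illustration of the new ParameterizedDefaultDict (a sketch with made-up factories, not from the commit); tuple keys are unpacked into positional arguments by __missing__:

from compressed_tensors.utils.helpers import ParameterizedDefaultDict

# tuple keys are unpacked: default_factory(2, 3)
sums = ParameterizedDefaultDict(lambda a, b: a + b)
assert sums[(2, 3)] == 5
assert (2, 3) in sums  # computed values are cached like normal dict entries

# non-tuple keys are passed as a single argument
squares = ParameterizedDefaultDict(lambda k: k * k)
assert squares[4] == 16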
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy
import pytest
import torch
from compressed_tensors.transform.utils.hadamard import (
    _get_had12,
    _get_had20,
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)


@pytest.mark.parametrize(
    "had_func",
    [
        _get_had12,
        _get_had20,
    ],
)
def test_packed_hadamard_compliant(had_func):
    had_matrix = had_func()
    size = had_matrix.shape[0]
    # H @ H.T == n * I
    val_1 = had_matrix @ had_matrix.T
    assert torch.equal(val_1 / size, torch.eye(size))


@pytest.mark.parametrize(
    "size",
    [4096, 2048],
)
def test_random_hadamard_matrix_compliant(size):
    had_matrix = random_hadamard_matrix(size)
    # output is pre-normalized by sqrt(size), so H @ H.T == I up to rounding
    val_1 = torch.round(had_matrix @ had_matrix.T)
    assert torch.equal(val_1, torch.eye(size))


@pytest.mark.parametrize(
    "size",
    [1024],
)
def test_deterministic_hadamard_compliant(size):
    had_matrix = deterministic_hadamard_matrix(size)
    # H @ H.T == n * I
    val_1 = had_matrix @ had_matrix.T
    assert numpy.array_equal(val_1 / size, numpy.eye(size))
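
A note on coverage: the sizes above are all powers of two, so they exercise only the K == 1 branch of _get_hadK. A sketch of an additional test (not in the commit, reusing the imports above) that would also hit the packed had12/had20 branches:

@pytest.mark.parametrize(
    "size",
    [768, 1280],  # 768 = 12 * 64 (had12 path), 1280 = 20 * 64 (had20 path)
)
def test_random_hadamard_non_pow2_compliant(size):
    had_matrix = random_hadamard_matrix(size)
    val_1 = torch.round(had_matrix @ had_matrix.T)
    assert torch.equal(val_1, torch.eye(size, dtype=had_matrix.dtype))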
