
Commit 852b1fa

[Transform] Extend set of known Hadamard matrices (#351)
* use hadamards database file
* try manifest
* try setup, update hadamards list
* fix setup
* add docstrings, cleanup
* fix setup, thank you @dbarbuzzi
* remove numpy, add tests
* solidify dtype, add gpu tests
* fix docstring
* add device option
* construct on execution device, cache on offload device
* save construction device changes for later
* cite nja sloane
* put on device via safe_open
* nits and docstrings

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 54f5b4e commit 852b1fa

File tree

6 files changed: +132 −126 lines changed

setup.py
src/compressed_tensors/transform/factory/hadamard.py
src/compressed_tensors/transform/factory/random_hadamard.py
src/compressed_tensors/transform/utils/hadamard.py
src/compressed_tensors/transform/utils/hadamards.safetensors
tests/test_transform/utils/test_hadamard.py

setup.py

Lines changed: 1 addition & 0 deletions
@@ -113,5 +113,6 @@ def _setup_extras() -> Dict:
     extras_require=_setup_extras(),
     install_requires=_setup_install_requires(),
     package_dir={"": "src"},
+    package_data={"": ["transform/utils/hadamards.safetensors"]},
     packages=_setup_packages(),
 )
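Since the Hadamard database ships as package data rather than Python source, it is worth verifying that built distributions actually include it. A minimal sketch (the module path comes from the diffs below; the `hadamard_utils` alias and assertion message are illustrative):

    # Sketch: confirm the safetensors database is bundled with the installed package.
    # Resolution mirrors the REPO_PATH definition added in transform/utils/hadamard.py.
    from pathlib import Path

    import compressed_tensors.transform.utils.hadamard as hadamard_utils

    db_path = Path(hadamard_utils.__file__).parent / "hadamards.safetensors"
    assert db_path.exists(), "hadamards.safetensors missing from built package"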

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return HadamardTransform(weight, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size)
+        data = deterministic_hadamard_matrix(size, dtype, device)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
     """
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, self.generator)
+        data = random_hadamard_matrix(size, dtype, device, self.generator)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
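Both factory changes thread `dtype` and `device` through to the matrix constructors, so weights are materialized directly in their target precision and location instead of being built in float64 on CPU and converted afterwards; the trailing `data.to(...)` remains as a safeguard. A sketch of the resulting call pattern (size and device values are illustrative):

    import torch

    from compressed_tensors.transform.utils.hadamard import (
        deterministic_hadamard_matrix,
        random_hadamard_matrix,
    )

    # construct directly at the target dtype/device, as the factories now do
    weight = deterministic_hadamard_matrix(4096, torch.bfloat16, torch.device("cuda"))
    random_weight = random_hadamard_matrix(4096, torch.bfloat16, torch.device("cuda"))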

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 91 additions & 92 deletions
@@ -13,149 +13,148 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional
 
-import numpy
 import torch
+from safetensors import safe_open
 
 
-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
 
-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
     `n` must be a power of 2.
 
+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
     :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :return: hadamard matrix of size `size`
     """
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")
 
-    log2 = int(math.log(size, 2))
+    log2 = int(math.log2(size))
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
 
-    H = numpy.array([[1]], dtype=int)
+    H = torch.tensor([[1]], dtype=dtype, device=device)
 
     # Sylvester's construction
-    for i in range(0, log2):
-        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
-
-    return torch.from_numpy(H / math.sqrt(size))
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
 
-
-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+    return H / math.sqrt(size)
 
 
 def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
-    Produces a randomly generated Hadamard matrix.
-    See https://cornell-relaxml.github.io/quip-sharp/ ,
-    Section "Randomized Hadamard Transformation"
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non powers of 2
+    and randomization using a seeded generator
+
+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501
 
     :param size: The dimension of the hadamard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :param gen: Optional generator random values
     :return: randomly generated hadamard matrix
     """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)
 
 
-def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
-    # NOTE: we can easily extend the list of supported shapes/sizes
-    # by adding to these methods
-    hadK, K = None, None
-    if n % 20 == 0:
-        assert _is_pow2(n // 20)
-        K = 20
-        hadK = _get_had20().T if transpose else _get_had20()
-    elif n % 12 == 0:
-        assert _is_pow2(n // 12)
-        K = 12
-        hadK = _get_had12().T if transpose else _get_had12()
-    else:
-        assert _is_pow2(n)
-        K = 1
-
-    return hadK, K
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of the matrix for which a known hadamard divisor is needed
+    :return: a known hadamard matrix whose size divides `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
+
+
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
 
-
-def _matmul_hadU(X, transpose=False) -> torch.Tensor:
-    n = X.shape[-1]
     # Check if we have the determined hadamard matrix
-    hadK, K = _get_hadK(n, transpose)
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
     # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
+    input = X.clone().view(-1, size, 1)
     output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
     while input.shape[1] > K:
         input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
         output = output.view(input.shape)
         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
         output = output.view(input.shape[0], input.shape[1], -1)
         (input, output) = (output, input)
+    assert input.shape[1] == K
     del output
 
-    # K == 1 when hadK is None; this happens when the size dim (n)
-    # is not comaptible with any of the maintained hadamard matrices
-
-    if K > 1:
-        # Do not explicitly repeat - OOM
-        # input = torch.bmm(
-        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-        # Use bcast instead
-
-        # for cases when hadK is pre-determined
-        input = hadK.view(1, K, K).to(input) @ input
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input
 
     # normalize
     return input.view(X.shape)
-
-
-def _is_pow2(n: int) -> bool:
-    return (n & (n - 1) == 0) and (n > 0)
-
-
-def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
-    had_unpacked = numpy.unpackbits(packed_bits)
-    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
-    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
-    return had_unpacked
-
-
-# http://www.neilsloane.com/hadamard/index.html
-def _get_had12() -> torch.Tensor:
-    # fmt: off
-    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
-        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_12_unpacked = _reshape_bits(had_12, original_size=12)
-    return torch.tensor(had_12_unpacked)
-
-
-def _get_had20() -> torch.Tensor:
-    # fmt: off
-    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
-        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
-        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
-        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_20_unpacked = _reshape_bits(had_20, original_size=20)
-    return torch.tensor(had_20_unpacked)
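Taken together, the rewrite makes both constructors pure torch, adds `dtype`/`device` arguments, and replaces the hardcoded 12x12 and 20x20 cases with a lookup into the bundled database of known matrices. A CPU-only sketch of the public API (sizes, dtype, and tolerance are illustrative; 768 is not a power of two, so it exercises the known-matrix path):

    import torch

    from compressed_tensors.transform.utils.hadamard import (
        deterministic_hadamard_matrix,
        is_pow2,
        random_hadamard_matrix,
    )

    assert is_pow2(1024) and not is_pow2(768)

    # power-of-two sizes use Sylvester's construction
    H = deterministic_hadamard_matrix(1024, dtype=torch.float32)
    assert torch.allclose(H @ H.T, torch.eye(1024), atol=1e-3)

    # 768 is not a power of two, so a known divisor matrix (e.g. 12x12,
    # since 768 = 12 * 64) is fetched from the database
    gen = torch.Generator().manual_seed(0)
    R = random_hadamard_matrix(768, dtype=torch.float32, gen=gen)
    assert torch.allclose(R @ R.T, torch.eye(768), atol=1e-3)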
src/compressed_tensors/transform/utils/hadamards.safetensors

Binary file not shown.
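Although the binary diff cannot be rendered, the file is a plain safetensors archive keyed by matrix size, which `_fetch_hadamard_divisor` scans from largest to smallest divisor. A sketch for listing the known sizes it carries:

    from safetensors import safe_open

    from compressed_tensors.transform.utils.hadamard import REPO_PATH

    # each key is the size of one known hadamard matrix bundled with the package
    with safe_open(REPO_PATH, framework="pt", device="cpu") as file:
        print(sorted(int(key) for key in file.keys()))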

tests/test_transform/utils/test_hadamard.py

Lines changed: 38 additions & 32 deletions
@@ -13,46 +13,48 @@
 # limitations under the License.
 
 
-import numpy
 import pytest
 import torch
 from compressed_tensors.transform.utils.hadamard import (
-    _get_had12,
-    _get_had20,
     deterministic_hadamard_matrix,
+    is_pow2,
     random_hadamard_matrix,
 )
+from tests.testing_utils import requires_gpu
 
 
-@pytest.mark.parametrize(
-    "had_func",
-    [
-        _get_had12,
-        _get_had20,
-    ],
-)
-def test_packed_hadamard_compliant(had_func):
-    had_matrix = had_func()
-    size = had_matrix.size(0)
-    # HH.T == nI
-    product = had_matrix @ had_matrix.T
-    assert torch.equal(product, size * torch.eye(size))
+_sizes_to_test = [
+    768,  # gpt2 small
+    1024,  # gpt2 medium
+    1280,  # qwen_2_5_vl vision
+    1600,  # gpt2 xl
+    2048,  # gpt3 small
+    3584,  # qwen_2_5_vl
+    3840,  # qwen_2_5_vl vision qkv
+    4096,  # llama3
+    7168,  # deepseek_v3
+    14336,  # llama3 intermediate
+    18432,  # deepseek_v3 intermediate
+    18944,  # qwen_2_5_vl intermediate
+]
+_atol = 1e-1  # bfloat16 is low precision for large matrices
 
 
-@pytest.mark.parametrize(
-    "size",
-    [4096, 2048],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
-    had_matrix = random_hadamard_matrix(size)
-    product = torch.round(had_matrix @ had_matrix.T)
-    assert torch.equal(product, torch.eye(size))
+    # (H / sqrt(n))(H.T / sqrt(n)) == I
+    matrix = random_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
 
 
 def test_random_hadamard_generator():
+    # check that generation is deterministic with a seed
     generator = torch.Generator().manual_seed(42)
-    one = random_hadamard_matrix(2048, generator)
-    two = random_hadamard_matrix(2048, generator)
+    one = random_hadamard_matrix(2048, gen=generator)
+    two = random_hadamard_matrix(2048, gen=generator)
 
     one_true = torch.tensor(
         [
@@ -73,12 +75,16 @@ def test_random_hadamard_generator():
     assert torch.all(two[:3, :3].sign() == two_true.sign())
 
 
-@pytest.mark.parametrize(
-    "size",
-    [1024],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_deterministic_hadamard_compliant(size):
-    had_matrix = deterministic_hadamard_matrix(size)
+    if not is_pow2(size):
+        with pytest.raises(ValueError):
+            matrix = deterministic_hadamard_matrix(size, device="cuda")
+        return
+
     # (H / sqrt(n))(H.T / sqrt(n)) == I
-    product = had_matrix @ had_matrix.T
-    assert numpy.array_equal(product, numpy.eye(size))
+    matrix = deterministic_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
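The GPU-gated tests check the normalized identity (H / sqrt(n))(H / sqrt(n)).T == I in bfloat16, hence the loose `_atol`. For a quick local run without a GPU, roughly the same assertion can be evaluated on CPU at a single size, as a sketch (size, dtype, and tolerance are illustrative):

    import torch

    from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix

    size = 2048
    matrix = random_hadamard_matrix(size, dtype=torch.float32)  # defaults to cpu
    product = matrix @ matrix.T
    assert torch.allclose(product, torch.eye(size), atol=1e-3)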
