Skip to content

Commit c8f6b53

Browse files
committed
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_construct_cache_device
2 parents c1a4a34 + feba695 commit c8f6b53

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414

1515
import math
16-
import os
16+
from pathlib import Path
1717
from typing import Optional
1818

1919
import torch
2020
from safetensors import safe_open
2121

2222

23-
REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
23+
REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
2424

2525

2626
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
@@ -42,6 +42,8 @@ def deterministic_hadamard_matrix(
4242
Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501
4343
4444
:param size: order of the matrix, must be a power of 2
45+
:param dtype: data type of matrix
46+
:param device: device to construct matrix on
4547
:return: hadamard matrix of size `size`
4648
"""
4749
if size <= 0:
@@ -54,7 +56,7 @@ def deterministic_hadamard_matrix(
5456
H = torch.tensor([[1]], dtype=dtype, device=device)
5557

5658
# Sylvester's construction
57-
for _ in range(0, log2):
59+
for _ in range(log2):
5860
H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
5961

6062
return H / math.sqrt(size)
@@ -67,17 +69,16 @@ def random_hadamard_matrix(
6769
gen: Optional[torch.Generator] = None,
6870
) -> torch.Tensor:
6971
"""
70-
Produces a randomly generated Hadamard matrix.
71-
See https://cornell-relaxml.github.io/quip-sharp/ ,
72-
Section "Randomized Hadamard Transformation"
73-
74-
Improves upon deterministic_hadamard_matrix
75-
in that this supports non powers of 2 and random seeds
72+
Produces a randomly generated Hadamard matrix. Differs from
73+
`deterministic_hadamard_matrix` in that this function supports non powers of 2
74+
and randomization using a seeded generator
7675
7776
Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501
7877
Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/ # noqa: E501
7978
8079
:param size: The dimension of the hadamard matrix
80+
:param dtype: data type of matrix
81+
:param device: device to construct matrix on
8182
:param gen: Optional generator for random values
8283
:return: randomly generated hadamard matrix
8384
"""
@@ -89,10 +90,16 @@ def random_hadamard_matrix(
8990

9091

9192
def is_pow2(n: int) -> bool:
92-
return (n & (n - 1) == 0) and (n > 0)
93+
"""
94+
Check if a number is a power of 2
95+
96+
:param n: number to check
97+
:return: True iff `n` is a power of 2
98+
"""
99+
return n > 0 and (n & (n - 1) == 0)
93100

94101

95-
def _get_known_divisor(
102+
def _fetch_hadamard_divisor(
96103
n: int,
97104
dtype: torch.dtype,
98105
device: torch.device = torch.device("cpu"),
@@ -110,11 +117,11 @@ def _get_known_divisor(
110117
:param n: size of known hadamard matrix
111118
:return: a known hadamard matrix of size `n` if one exists, else None
112119
"""
113-
with safe_open(file_path, framework="pt", device="cpu") as file:
114-
divisors = sorted([int(key) for key in file.keys()], reverse=True)
120+
with safe_open(file_path, framework="pt", device=str(device)) as file:
121+
divisors = sorted((int(key) for key in file.keys()), reverse=True)
115122
for divisor in divisors:
116123
if n % divisor == 0 and is_pow2(n // divisor):
117-
return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)
124+
return file.get_tensor(str(divisor)).to(dtype=dtype)
118125

119126
return None
120127

@@ -125,7 +132,7 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
125132
device = X.device
126133

127134
# Check if we have the determined hadamard matrix
128-
hadK = _get_known_divisor(size, dtype, device=device)
135+
hadK = _fetch_hadamard_divisor(size, dtype, device=device)
129136
if hadK is None:
130137
raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
131138
K = hadK.size(0)

0 commit comments

Comments
 (0)