Skip to content

Commit ccb88ed

Browse files
committed
nits and docstrings
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5807ee1 commit ccb88ed

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414

1515
import math
16-
import os
16+
from pathlib import Path
1717
from typing import Optional
1818

1919
import torch
2020
from safetensors import safe_open
2121

2222

23-
REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
23+
REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
2424

2525

2626
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
@@ -42,6 +42,8 @@ def deterministic_hadamard_matrix(
4242
Adapated from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501
4343
4444
:param size: order of the matrix, must be a power of 2
45+
:param dtype: data type of matrix
46+
:param device: device to construct matrix on
4547
:return: hadamard matrix of size `size`
4648
"""
4749
if size <= 0:
@@ -54,7 +56,7 @@ def deterministic_hadamard_matrix(
5456
H = torch.tensor([[1]], dtype=dtype, device=device)
5557

5658
# Sylvester's construction
57-
for _ in range(0, log2):
59+
for _ in range(log2):
5860
H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
5961

6062
return H / math.sqrt(size)
@@ -78,6 +80,8 @@ def random_hadamard_matrix(
7880
Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/ # noqa: E501
7981
8082
:param size: The dimension of the hamadard matrix
83+
:param dtype: data type of matrix
84+
:param device: device to construct matrix on
8185
:param gen: Optional generator random values
8286
:return: randomly generated hadamard matrix
8387
"""
@@ -89,10 +93,16 @@ def random_hadamard_matrix(
8993

9094

9195
def is_pow2(n: int) -> bool:
92-
return (n & (n - 1) == 0) and (n > 0)
96+
"""
97+
Check if a number is a power of 2
98+
99+
:param n: number to check
100+
:return: True iff `n` is a power of 2
101+
"""
102+
return n > 0 and (n & (n - 1) == 0)
93103

94104

95-
def _get_known_divisor(
105+
def _fetch_hadamard_divisor(
96106
n: int,
97107
dtype: torch.dtype,
98108
device: torch.device = torch.device("cpu"),
@@ -111,7 +121,7 @@ def _get_known_divisor(
111121
:return: a known hadamard matrix of size `n` if one exists, else None
112122
"""
113123
with safe_open(file_path, framework="pt", device=str(device)) as file:
114-
divisors = sorted([int(key) for key in file.keys()], reverse=True)
124+
divisors = sorted((int(key) for key in file.keys()), reverse=True)
115125
for divisor in divisors:
116126
if n % divisor == 0 and is_pow2(n // divisor):
117127
return file.get_tensor(str(divisor)).to(dtype=dtype)
@@ -125,7 +135,7 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
125135
device = X.device
126136

127137
# Check if we have the determined hadamard matrix
128-
hadK = _get_known_divisor(size, dtype, device=device)
138+
hadK = _fetch_hadamard_divisor(size, dtype, device=device)
129139
if hadK is None:
130140
raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
131141
K = hadK.size(0)

0 commit comments

Comments
 (0)