Skip to content

Commit 7c02bb2

Browse files
committed
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_factory
2 parents b117523 + cb1cb52 commit 7c02bb2

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Tuple
16+
from typing import Optional, Tuple
1717

1818
import numpy
1919
import torch
@@ -58,21 +58,20 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
5858
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
5959

6060

61-
def random_hadamard_matrix(size: int) -> torch.Tensor:
61+
def random_hadamard_matrix(
62+
size: int, gen: Optional[torch.Generator] = None
63+
) -> torch.Tensor:
6264
"""
6365
Produces a randomly generated Hadamard matrix.
6466
See https://cornell-relaxml.github.io/quip-sharp/ ,
6567
Section "Randomized Hadamard Transformation"
6668
67-
:param size: The dimension of the matrix. Matrix generated will have dimensions
68-
(size, size)
69-
69+
:param size: The dimension of the hamadard matrix
70+
:param gen: Optional generator random values
71+
:return: randomly generated hadamard matrix
7072
"""
71-
# TODO: potentially update to add "seed" as an arugment, to allow
72-
# the matrix generated to be reproducible
73-
7473
# Benefits: support other shapes / non powers of 2, support randomization
75-
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
74+
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
7675
Q = Q * 2 - 1
7776
Q = torch.diag(Q)
7877
return _matmul_hadU(Q) / math.sqrt(size)

tests/test_transform/utils/test_hadamard.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,30 @@ def test_random_hadamard_matrix_compliant(size):
4949
assert torch.equal(product, torch.eye(size))
5050

5151

52+
def test_random_hadamard_generator():
53+
generator = torch.Generator().manual_seed(42)
54+
one = random_hadamard_matrix(2048, generator)
55+
two = random_hadamard_matrix(2048, generator)
56+
57+
one_true = torch.tensor(
58+
[
59+
[-1, -1, -1],
60+
[+1, -1, +1],
61+
[-1, -1, +1],
62+
]
63+
)
64+
two_true = torch.tensor(
65+
[
66+
[-1, -1, -1],
67+
[-1, +1, -1],
68+
[+1, +1, -1],
69+
]
70+
)
71+
72+
assert torch.all(one[:3, :3].sign() == one_true.sign())
73+
assert torch.all(two[:3, :3].sign() == two_true.sign())
74+
75+
5276
@pytest.mark.parametrize(
5377
"size",
5478
[1024],

0 commit comments

Comments
 (0)