
Commit fbaf47a

add device option
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent c373345 commit fbaf47a
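
For context, a minimal sketch of what the new option enables for the two helpers touched below; the size and the "cuda" device are illustrative, not taken from the commit:

import torch
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Materialize matrices directly on the target device instead of
# constructing them on CPU and moving them afterwards
H = deterministic_hadamard_matrix(4096, dtype=torch.bfloat16, device=torch.device("cuda"))
R = random_hadamard_matrix(4096, dtype=torch.bfloat16, device=torch.device("cuda"))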

File tree

4 files changed, +36 -24 lines changed


src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def create_transform(self, module: Module, args: TransformArgs):
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = deterministic_hadamard_matrix(size, dtype=dtype)
-        data = data.to(dtype=dtype, device=device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
 

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -30,5 +30,5 @@ class RandomHadamardFactory(HadamardFactory):
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = random_hadamard_matrix(size, dtype=dtype, gen=self.generator)
-        data = data.to(dtype=dtype, device=device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 18 additions & 9 deletions
@@ -31,7 +31,9 @@
 
 
 def deterministic_hadamard_matrix(
-    size: int, dtype: torch.dtype = torch.bfloat16
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
 ) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
@@ -49,7 +51,7 @@ def deterministic_hadamard_matrix(
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
 
-    H = torch.tensor([[1]], dtype=dtype)
+    H = torch.tensor([[1]], dtype=dtype, device=device)
 
     # Sylvester's construction
     for _ in range(0, log2):
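
For reference, Sylvester's construction doubles a Hadamard matrix H via the block form [[H, H], [H, -H]], and each doubling preserves H @ H.T == n * I. A minimal standalone sketch of the loop body elided from the hunk above (the repository's exact implementation may differ):

import torch

def sylvester_hadamard(size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # size must be a power of two; start from the 1x1 Hadamard [[1]]
    H = torch.tensor([[1.0]], device=device)
    while H.shape[0] < size:
        # H_{2n} = [[H, H], [H, -H]]
        top = torch.cat([H, H], dim=1)
        bottom = torch.cat([H, -H], dim=1)
        H = torch.cat([top, bottom], dim=0)
    return H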
@@ -61,6 +63,7 @@
 def random_hadamard_matrix(
     size: int,
     dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
     gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
@@ -75,7 +78,9 @@ def random_hadamard_matrix(
     :return: randomly generated hadamard matrix
     """
     # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)
+    Q = torch.randint(
+        low=0, high=2, size=(size,), generator=gen, dtype=dtype, device=device
+    )
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)
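
The randomization works because right-multiplying a Hadamard H by a random +-1 diagonal D preserves orthogonality: (HD)(HD).T = H D D.T H.T = H H.T = n * I, since D squared is the identity. A quick self-contained check of that identity (the size 256 is arbitrary):

import torch

n = 256
H = torch.tensor([[1.0]])
while H.shape[0] < n:  # Sylvester construction, entries are +-1
    H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)

D = torch.diag(torch.randint(0, 2, (n,)).float() * 2 - 1)  # random +-1 diagonal
HD = (H @ D) / n**0.5  # normalize rows to unit length

assert torch.allclose(HD @ HD.T, torch.eye(n), atol=1e-4)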
@@ -86,16 +91,18 @@ def is_pow2(n: int) -> bool:
 
 
 def _get_known_divisor(
-    n: int, dtype: torch.dtype, file_path: str = REPO_PATH
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
 ) -> Optional[torch.Tensor]:
     """
     Fetch a known hadamard matrix from the given file path. The returned matrix will
     be of size `k` such that `n / k` is a power of two. Return None if no such
     matrix exists.
 
     Note: This function reopens the safetensors file every time it is called.
-    This is inefficient, but inconsequential because hadamards are typically
-    cached by size through the factory that produced them. This is also simpler
+    This is technically inefficient, but a very small runtime cost and simpler
     than forcing callers to manage the file open context
 
     :param n: size of known hadamard matrix
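
To illustrate the divisor search described in the docstring: for n = 18944 (an intermediate size exercised by the tests below), a known Hadamard of size k = 148, if present in the file, satisfies n / k = 128, a power of two, so that entry would be returned. A sketch assuming a bit-trick is_pow2 (the repository defines its own version above this function):

def is_pow2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

# 18944 = 148 * 128, and 128 is a power of two
assert 18944 % 148 == 0 and is_pow2(18944 // 148)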
@@ -105,17 +112,18 @@ def _get_known_divisor(
         divisors = sorted([int(key) for key in file.keys()], reverse=True)
         for divisor in divisors:
             if n % divisor == 0 and is_pow2(n // divisor):
-                return file.get_tensor(str(divisor)).to(dtype=dtype)
+                return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)
 
     return None
 
 
 def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
-    size = X.shape[-1]
+    size = X.size(0)
     dtype = X.dtype
+    device = X.device
 
     # Check if we have the determined hadamard matrix
-    hadK = _get_known_divisor(size, dtype)
+    hadK = _get_known_divisor(size, dtype, device=device)
     if hadK is None:
         raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
     K = hadK.size(0)
@@ -130,6 +138,7 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
         output = output.view(input.shape[0], input.shape[1], -1)
         (input, output) = (output, input)
+    assert input.shape[1] == K
     del output
 
     # Do not explicitly repeat - OOM
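
For intuition, the loop above is the fast Walsh-Hadamard butterfly: each pass replaces adjacent pairs (a, b) with (a + b, a - b), halving the number of blocks until only the K x K known kernel remains to be applied. A simplified 1-D version of the same recurrence, restricted to power-of-two lengths and omitting the known-kernel step:

import torch

def fwht(x: torch.Tensor) -> torch.Tensor:
    # Computes H @ x for the Sylvester Hadamard H without materializing H
    n = x.shape[-1]
    assert n & (n - 1) == 0, "length must be a power of two"
    x = x.clone()
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            a = x[..., i : i + h].clone()
            b = x[..., i + h : i + 2 * h].clone()
            x[..., i : i + h] = a + b
            x[..., i + h : i + 2 * h] = a - b
        h *= 2
    return x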

tests/test_transform/utils/test_hadamard.py

Lines changed: 16 additions & 13 deletions
@@ -32,19 +32,22 @@
     3584,  # qwen_2_5_vl
     3840,  # qwen_2_5_vl vision qkv
     4096,  # llama3
+    7168,  # deepseek_v3
     14336,  # llama3 intermediate
+    18432,  # deepseek_v3 intermediate
     18944,  # qwen_2_5_vl intermediate
 ]
+_atol = 1e-1  # bfloat16 is low precision for large matrices
 
 
 @requires_gpu
 @pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
     # (H / sqrt(n))(H.T / sqrt(n)) == I
-    with torch.device("cuda"):
-        had_matrix = random_hadamard_matrix(size)
-        product = torch.round(had_matrix @ had_matrix.T)
-        assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
+    had_matrix = random_hadamard_matrix(size, device="cuda")
+    product = had_matrix @ had_matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
 
 
 def test_random_hadamard_generator():
@@ -75,13 +78,13 @@ def test_random_hadamard_generator():
 @requires_gpu
 @pytest.mark.parametrize("size", _sizes_to_test)
 def test_deterministic_hadamard_compliant(size):
-    with torch.device("cuda"):
-        if not is_pow2(size):
-            with pytest.raises(ValueError):
-                had_matrix = deterministic_hadamard_matrix(size)
-            return
+    if not is_pow2(size):
+        with pytest.raises(ValueError):
+            matrix = deterministic_hadamard_matrix(size, device="cuda")
+        return
 
-        # (H / sqrt(n))(H.T / sqrt(n)) == I
-        had_matrix = deterministic_hadamard_matrix(size)
-        product = had_matrix @ had_matrix.T
-        assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
+    # (H / sqrt(n))(H.T / sqrt(n)) == I
+    matrix = deterministic_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
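
On the loosened tolerance: bfloat16 keeps only 8 mantissa bits, so each value carries a relative error of roughly 2^-8 (about 0.4%), and the length-n dot products in the identity check (n up to 18944 here) accumulate many such rounding errors. An atol of 1e-5 is therefore unattainable at these sizes in bfloat16, and 1e-1 is a rough but serviceable bound.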
