Skip to content

Commit ecbe770

Browse files
authored
[Transform] Do not fuse div operation into hadamard matrices (#395)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 76b3023 commit ecbe770

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Optional, Union
1616

17+
import math
1718
import torch
1819
from compressed_tensors.transform import TransformArgs, TransformScheme
1920
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -87,6 +88,7 @@ def __init__(
8788
self.weight = weight
8889
self.perm = perm
8990
self.args = args
91+
self._scale = math.sqrt(weight.size(0))
9092

9193
def forward(self, value: Tensor) -> Tensor:
9294
weight = self.weight
@@ -97,4 +99,4 @@ def forward(self, value: Tensor) -> Tensor:
9799
if self.args.inverse:
98100
weight = weight.T
99101

100-
return apply_transform_weight(weight, value, self.args.location)
102+
return apply_transform_weight(weight, value, self.args.location) / self._scale

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def deterministic_hadamard_matrix(
5959
for _ in range(log2):
6060
H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
6161

62-
return H / math.sqrt(size)
62+
return H
6363

6464

6565
def random_hadamard_matrix(
@@ -86,7 +86,7 @@ def random_hadamard_matrix(
8686
Q = Q.to(device=device)
8787
Q = Q * 2 - 1
8888
Q = torch.diag(Q)
89-
return _matmul_hadU(Q) / math.sqrt(size)
89+
return _matmul_hadU(Q)
9090

9191

9292
def is_pow2(n: int) -> bool:

tests/test_transform/utils/test_hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
def test_random_hadamard_matrix_compliant(size):
4646
# (H / sqrt(n))(H.T / sqrt(n)) == I
4747
matrix = random_hadamard_matrix(size, device="cuda")
48-
product = matrix @ matrix.T
48+
product = (matrix @ matrix.T) / matrix.size(0)
4949
eye = torch.eye(size, dtype=product.dtype, device="cuda")
5050
assert torch.allclose(product, eye, atol=_atol)
5151

@@ -85,6 +85,6 @@ def test_deterministic_hadamard_compliant(size):
8585

8686
# (H / sqrt(n))(H.T / sqrt(n)) == I
8787
matrix = deterministic_hadamard_matrix(size, device="cuda")
88-
product = matrix @ matrix.T
88+
product = (matrix @ matrix.T) / matrix.size(0)
8989
eye = torch.eye(size, dtype=product.dtype, device="cuda")
9090
assert torch.allclose(product, eye, atol=_atol)

0 commit comments

Comments
 (0)