Skip to content

Commit ab6101e

Browse files
committed
fix typo, change class, remove long test case
1 parent 1adfa30 commit ab6101e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/compressed_tensors/transforms/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
# first or second matirx in torch.matmul depending on dimensions, can be inferred
2828
# by the layer time likely.
2929

30-
MATIRX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]
30+
MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]
3131

3232

33-
class Transforms(RegistryMixin):
33+
class Transforms(torch.nn.Parameter, RegistryMixin):
3434
def __new__(
3535
cls,
3636
transform: torch.Tensor,
@@ -69,7 +69,7 @@ def __new__(
6969

7070
@classmethod
7171
def fetch_apply(cls, name: str):
72-
if name in MATIRX_TRANSFORMS:
72+
if name in MATRIX_TRANSFORMS:
7373
return apply_matrix_transform
7474
raise NotImplementedError("Only matrix transforms are supported")
7575

tests/test_transforms/test_hadamards.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_random_hadamard_matrix_compliant(size):
5151

5252
@pytest.mark.parametrize(
5353
"size",
54-
[1024, 2048],
54+
[1024],
5555
)
5656
def test_deterministic_hadamard_compliant(size):
5757
had_matrix = deterministic_hadamard_matrix(size)

0 commit comments

Comments
 (0)