Commit 1e1760b

clean-up

1 parent ab6101e commit 1e1760b

5 files changed: +97 -55 lines

src/compressed_tensors/transforms/base.py

Lines changed: 34 additions & 28 deletions
@@ -17,6 +17,7 @@
 import torch
 from compressed_tensors.registry.registry import RegistryMixin
 from compressed_tensors.transforms.utils import apply_matrix_transform
+from compressed_tensors.utils import register_offload_parameter, update_parameter_data


 __all__ = ["Transforms"]
@@ -27,18 +28,16 @@
 # first or second matrix in torch.matmul depending on dimensions, can be inferred
 # by the layer type likely.

-MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]

-
-class Transforms(torch.nn.Parameter, RegistryMixin):
-    def __new__(
-        cls,
+class Transforms(RegistryMixin):
+    def __init__(
+        self,
         transform: torch.Tensor,
+        learnable: Optional[bool] = True,
         device: Optional[Union[str, torch.device]] = "cuda",
         dtype: Optional[torch.dtype] = torch.bfloat16,
-        *args,
-        **kwargs,
     ):
+        self.learnable = learnable
         """
         Base class for setting up transforms. The registry creates transforms
         as parameters which can be attached to modules.
@@ -48,38 +47,45 @@ def __new__(
         size = 1024
         dtype = torch.bfloat16
         module = torch.nn.Linear(size, size)
+        name = "weight_transform"

         hadamard_transform = Transforms.load_from_registry(
             "random_hadamard", size=size, dtype=dtype
         )
-        hadamard_apply = Transforms.fetch_apply("random_hadamard")
-        module.weight_transform = hadamard_transform

-        transformed_output = hadamard_apply(input_tensor=module.weight,
-            transform=moduel.weight_transform)
+        hadamard_transform.register_to_module(name, module)
+        module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}}

-        hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
-        original_weight = hadamard_inverse(input_tensor=transformed_output,
-            transform=model.weight_trainsform,
-            transpose=True)
+        transformed_output = hadamard_transform.apply(input_tensor=module.weight)
+        original_weight = hadamard_transform.inverse_apply(
+            input_tensor=transformed_output)

         :param transform: transform (e.g. torch.Tensor, scalar) to be applied
         """
-        return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)
-
-    @classmethod
-    def fetch_apply(cls, name: str):
-        if name in MATRIX_TRANSFORMS:
-            return apply_matrix_transform
-        raise NotImplementedError("Only matrix transforms are supported")
-
-    @classmethod
-    def fetch_inverse_apply(cls, name: str):
-        return cls.get_value_from_registry(name=name).inverse_apply
+        if self.learnable:
+            self.transform = torch.nn.Parameter(
+                transform.to(dtype).to(device), requires_grad=False
+            )
+        else:
+            self.transform = torch.nn.Buffer(transform.to(dtype).to(device))
+
+    # register to class for easy offloading, serialization, deserialization
+    def register_to_module(self, name: str, module: torch.nn.Module):
+        if self.learnable:
+            register_offload_parameter(module, name, self.transform)
+        else:
+            # TODO: have to verify serialization/offloading
+            module.register_buffer(name, self.transform)
+
+    def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        """
+        Apply the transform to the module
+        """
+        raise NotImplementedError()

-    @staticmethod
+    # TODO: potentially split into its own transform using the same shared set-up
     def inverse_apply(
-        transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
+        self, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
         """
         Apply the inverse operation applied by the apply method
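
Taken together, this change replaces the old torch.nn.Parameter subclass and its fetch_apply/fetch_inverse_apply classmethod lookups with ordinary instances that own their tensor. A minimal usage sketch of the new flow, assuming the "hadamard" registry name and the register_offload_parameter helper imported above (the size, device, and variable names here are illustrative):

    import torch
    from compressed_tensors.transforms import Transforms

    size = 64  # assumed to be a size the Hadamard constructor supports
    module = torch.nn.Linear(size, size)

    # Instantiate through the registry; "hadamard" is registered in hadamard.py below.
    hadamard = Transforms.load_from_registry(
        "hadamard", size=size, device="cpu", dtype=torch.float32
    )

    # learnable=True (the default) attaches self.transform as an offloadable
    # parameter; learnable=False registers it as a buffer instead.
    hadamard.register_to_module("weight_transform", module)

    # Instance methods replace the old fetch_apply / fetch_inverse_apply lookups.
    transformed = hadamard.apply(input_tensor=module.weight)
    restored = hadamard.inverse_apply(input_tensor=transformed)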

src/compressed_tensors/transforms/hadamard.py

Lines changed: 20 additions & 6 deletions
@@ -22,12 +22,14 @@

 @Transforms.register("hadamard")
 class Hadamard(Transforms):
-    def __new__(
-        cls,
+    def __init__(
+        self,
         size: int,
         empty: Optional[bool] = False,
         device: Optional[Union[str, torch.device]] = "cuda",
         dtype: Optional[torch.dtype] = torch.bfloat16,
+        *args,
+        **kwargs,
     ):
         """
         Produces a hadamard matrix with dims (size, size), with values
@@ -50,11 +52,23 @@ def __new__(
         else:
             transform = torch.empty((size, size))

-        return super().__new__(cls, transform=transform, device=device, dtype=dtype)
+        super().__init__(transform=transform, dtype=dtype, device=device)
+
+    def apply(
+        self,
+        input_tensor: torch.Tensor,
+        transpose: bool = False,
+        first: bool = True,
+    ) -> torch.Tensor:
+        return apply_matrix_transform(
+            transform=self.transform,
+            input_tensor=input_tensor,
+            transpose=transpose,
+            first=first,
+        )

-    @staticmethod
     def inverse_apply(
-        transform: torch.Tensor,
+        self,
         input_tensor: torch.Tensor,
         transpose: bool = False,
         first: bool = True,
@@ -73,7 +87,7 @@ def inverse_apply(
         # need to normalize before sending back
         return (
             apply_matrix_transform(
-                transform=transform,
+                transform=self.transform,
                 input_tensor=input_tensor,
                 transpose=transpose,
                 first=first,
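
The "need to normalize before sending back" branch rests on the defining Hadamard property HH^T = nI, which makes the inverse a transpose plus a 1/n scaling rather than a torch.linalg.inv call. A self-contained illustration in plain torch (the Sylvester doubling used here is one standard construction, not necessarily the one this module uses):

    import torch

    def sylvester_hadamard(n: int) -> torch.Tensor:
        # Build a Hadamard matrix for n a power of two by repeated doubling.
        H = torch.ones(1, 1)
        while H.shape[0] < n:
            H = torch.cat(
                [torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0
            )
        return H

    H = sylvester_hadamard(16)
    assert torch.equal(H @ H.T, 16 * torch.eye(16))  # HH.T == nI
    # The inverse is therefore H.T / n, so inverse_apply can transpose and
    # normalize instead of computing an explicit matrix inverse.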

src/compressed_tensors/transforms/matrix_multiply.py

Lines changed: 16 additions & 3 deletions
@@ -14,14 +14,27 @@

 import torch
 from compressed_tensors.transforms import Transforms
+from compressed_tensors.transforms.utils import apply_matrix_transform


 # TODO: fix loading
 @Transforms.register("matrix-mul")
 class MatrixMultiply(Transforms):
-    @staticmethod
+    def apply(
+        self,
+        input_tensor: torch.Tensor,
+        transpose: bool = False,
+        first: bool = True,
+    ) -> torch.Tensor:
+        return apply_matrix_transform(
+            transform=self.transform,
+            input_tensor=input_tensor,
+            transpose=transpose,
+            first=first,
+        )
+
     def inverse_apply(
-        transform: torch.Tensor,
+        self,
         input_tensor: torch.Tensor,
         transpose: bool = False,
         first: bool = True,
@@ -40,7 +53,7 @@ def inverse_apply(
         # Note: not implemented for lower precision than float32
         transform = torch.linalg.inv(transform)
         return apply_matrix_transform(
-            transform=transform,
+            transform=self.transform,
             input_tensor=input_tensor,
             transpose=transpose,
             first=first,
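
One thing worth flagging in this hunk: torch.linalg.inv(transform) now reads a name that is no longer a parameter, and its result is discarded in favor of self.transform, so the uninverted matrix is what gets applied (the "TODO: fix loading" above suggests the class is still in flux). The intended round-trip, sketched in plain torch and independent of the library API:

    import torch

    size = 8
    T = torch.randn(size, size, dtype=torch.float32)
    T += size * torch.eye(size)  # keep the matrix comfortably invertible

    x = torch.rand(size, size)
    y = x @ T                             # apply: multiply by the transform
    x_restored = y @ torch.linalg.inv(T)  # inverse_apply: undo with the inverse

    assert torch.allclose(x, x_restored, atol=1e-4)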

src/compressed_tensors/transforms/random_hadamard.py

Lines changed: 18 additions & 6 deletions
@@ -22,8 +22,8 @@

 @Transforms.register("random-hadamard")
 class RandomHadamard(Transforms):
-    def __new__(
-        cls,
+    def __init__(
+        self,
         size: int,
         empty: Optional[bool] = False,
         device: Optional[Union[str, torch.device]] = "cuda",
@@ -58,11 +58,23 @@ def __new__(
         else:
             transform = torch.empty((size, size))

-        return super().__new__(cls, transform=transform, device=device, dtype=dtype)
+        super().__init__(transform=transform, device=device, dtype=dtype)
+
+    def apply(
+        self,
+        input_tensor: torch.Tensor,
+        transpose: bool = False,
+        first: bool = True,
+    ) -> torch.Tensor:
+        return apply_matrix_transform(
+            transform=self.transform,
+            input_tensor=input_tensor,
+            transpose=transpose,
+            first=first,
+        )

-    @staticmethod
     def inverse_apply(
-        transform: torch.Tensor,
+        self,
         input_tensor: torch.Tensor,
         transpose: bool = False,
         first: bool = True,
@@ -80,7 +92,7 @@ def inverse_apply(

         transpose = not transpose
         return apply_matrix_transform(
-            transform=transform,
+            transform=self.transform,
             input_tensor=input_tensor,
             transpose=transpose,
             first=first,
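
Here inverse_apply only flips the transpose flag: the stored random Hadamard is normalized and orthonormal, so its inverse is its transpose. A sketch of that property, using a random ±1 diagonal times a Kronecker-built Hadamard as an illustrative stand-in for the library's construction:

    import torch

    n = 16
    # Sylvester Hadamard via Kronecker powers of the 2x2 base block.
    H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
    H = torch.ones(1, 1)
    for _ in range(4):  # 2**4 == n
        H = torch.kron(H, H2)

    signs = torch.randint(0, 2, (n,)).float() * 2 - 1  # random ±1 diagonal
    Q = (torch.diag(signs) @ H) / (n ** 0.5)            # normalized random Hadamard

    # Orthonormal, so the inverse is the transpose; no torch.linalg.inv needed.
    assert torch.allclose(Q @ Q.T, torch.eye(n), atol=1e-6)
    assert torch.all(torch.isin(Q * n ** 0.5, torch.tensor([-1.0, 1.0])))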

tests/test_transforms/test_transforms.py

Lines changed: 9 additions & 12 deletions
@@ -44,19 +44,18 @@ def test_random_hadamard_transform(size: int, dtype: torch.dtype):
     # check initialize
     assert hadamard_transform is not None

-    val_1 = torch.round(hadamard_transform @ hadamard_transform.T)
+    val_1 = torch.round(hadamard_transform.transform @ hadamard_transform.transform.T)

     # output will be normalized, multiply by sqrt(size) to ensure form
-    normalized = math.sqrt(size) * hadamard_transform
+    normalized = math.sqrt(size) * hadamard_transform.transform
     # all values should be -1 or +1
     assert torch.all(torch.isin(normalized, torch.Tensor([-1, +1])))
     # check creation; HH.T == I
     assert torch.equal(val_1, torch.eye(size))

     # check apply
     x = torch.rand((size, size), dtype=dtype)
-    apply = Transforms.fetch_apply("random-hadamard")
-    transformed_value = apply(input_tensor=x, transform=hadamard_transform)
+    transformed_value = hadamard_transform.apply(input_tensor=x)
     # TODO: check to make sure the matrix was applied correctly?
     assert transformed_value.shape == (size, size)

@@ -75,16 +74,15 @@ def test_deterministic_hadamard_transform(size: int, dtype: torch.dtype):

     # check initialize
     assert hadamard_transform is not None
-    assert torch.all(torch.isin(hadamard_transform, torch.Tensor([-1, +1])))
+    assert torch.all(torch.isin(hadamard_transform.transform, torch.Tensor([-1, +1])))

-    val_1 = hadamard_transform @ hadamard_transform.T
+    val_1 = hadamard_transform.transform @ hadamard_transform.transform.T
     # check creation; HH.T == nI
     assert torch.equal(val_1 / size, torch.eye(size))

     # check apply
     x = torch.rand((size, size), dtype=dtype)
-    apply = Transforms.fetch_apply("hadamard")
-    transformed_value = apply(input_tensor=x, transform=hadamard_transform)
+    transformed_value = hadamard_transform.apply(input_tensor=x)
     # TODO: check to make sure the matrix was applied correctly?
     assert transformed_value.shape == (size, size)

@@ -103,9 +101,8 @@ def test_multiplier_transform(size: int, dtype: torch.dtype):
         "matrix-mul", transform=multiplier, device="cpu", dtype=dtype
     )
     assert multiplier_transform is not None
-    assert torch.equal(multiplier_transform, multiplier)
+    assert torch.equal(multiplier_transform.transform, multiplier)

     x = torch.rand((size, size), dtype=dtype)
-    apply = Transforms.fetch_apply("matrix-mul")
-    transformed_value = apply(input_tensor=x, transform=multiplier_transform)
-    assert torch.equal(transformed_value, x)
+    transformed_output = multiplier_transform.apply(x)
+    assert torch.equal(transformed_output, x)
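
Both "TODO: check to make sure the matrix was applied correctly?" comments could be resolved with a round-trip assertion. A hypothetical helper, assuming apply followed by inverse_apply reconstructs the input to within dtype tolerance:

    import torch

    def roundtrip_check(transform, size: int, dtype: torch.dtype, atol: float = 1e-2):
        # Hypothetical helper: verifies that inverse_apply undoes apply.
        x = torch.rand((size, size), dtype=dtype)
        y = transform.apply(input_tensor=x)
        x_restored = transform.inverse_apply(input_tensor=y)
        assert torch.allclose(x, x_restored.to(dtype), atol=atol)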
