Skip to content

Commit 7ecb1b0

Browse files
committed
fix; add update
1 parent 749420b commit 7ecb1b0

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

src/compressed_tensors/transforms/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,21 @@ def register_to_module(self, name: str, module: torch.nn.Module):
7575
# TODO: have to verify serialization/offloading
7676
module.register_buffer(name, self.transform)
7777

78+
def update_transform(
79+
self,
80+
data: torch.Tensor,
81+
module: Optional[torch.nn.Module] = None,
82+
name: Optional[str] = None,
83+
):
84+
if module is None:
85+
self.transform.data.copy_(data)
86+
else:
87+
# If updating the module parameter data, assumes this is also the transform
88+
# data
89+
if name is None:
90+
raise ValueError("Name and module are required to update parma data")
91+
update_parameter_data(module, data, name)
92+
7893
def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
7994
"""
8095
Apply the transform to the module

src/compressed_tensors/transforms/hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,5 @@ def inverse_apply(
9292
transpose=transpose,
9393
first=first,
9494
)
95-
/ transform.shape[0]
95+
/ self.transform.shape[0]
9696
)

src/compressed_tensors/transforms/matrix_multiply.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ def inverse_apply(
5151
"""
5252

5353
# Note: not implemented for lower precision than float32
54-
transform = torch.linalg.inv(transform)
5554
return apply_matrix_transform(
56-
transform=self.transform,
55+
transform=torch.linalg.inv(self.transform),
5756
input_tensor=input_tensor,
5857
transpose=transpose,
5958
first=first,

0 commit comments

Comments
 (0)