
Commit b762e56

Update: CompressedLinear to decompress once (#266)
* Update: CompressedLinear to decompress once
* Update name!

Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent ea8848b commit b762e56

File tree

1 file changed: +9 −2 lines changed

src/compressed_tensors/linear/compressed_linear.py

Lines changed: 9 additions & 2 deletions
@@ -38,6 +38,10 @@ class CompressedLinear(Linear):
     :param quantization_format: compression format module is stored as
     """
 
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._is_compressed = True
+
     @classmethod
     @torch.no_grad()
     def from_linear(
@@ -86,5 +90,8 @@ def forward(self, input: Tensor) -> Tensor:
         """
         Decompresses the weight, then runs the wrapped forward pass
         """
-        uncompressed_weight = self.compressor.decompress_module(self)
-        return linear(input, uncompressed_weight, self.bias)
+        if self._is_compressed:
+            self.weight = self.compressor.decompress_module(self)
+            self._is_compressed = False
+
+        return linear(input, self.weight, self.bias)
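
The change makes decompression a one-time cost: before this commit, decompress_module ran on every forward call, while now the first call decompresses the weight, caches it on self.weight, and clears the _is_compressed flag so later calls reuse the dense weight directly. Below is a minimal, self-contained sketch of that decompress-once pattern; DummyCompressor and its decompress_module() are hypothetical stand-ins for illustration, not the compressors shipped with compressed_tensors.

# Minimal sketch of the decompress-once pattern from this commit.
# DummyCompressor is a hypothetical stand-in; the real CompressedLinear
# obtains its compressor from the model's compression/quantization config.
import torch
from torch import Tensor
from torch.nn import Linear, Parameter
from torch.nn.functional import linear


class DummyCompressor:
    def decompress_module(self, module: Linear) -> Parameter:
        # A real compressor would rebuild a dense weight from the
        # compressed parameters stored on the module.
        return Parameter(module.weight.data.clone(), requires_grad=False)


class DecompressOnceLinear(Linear):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.compressor = DummyCompressor()
        self._is_compressed = True

    def forward(self, input: Tensor) -> Tensor:
        # Decompress only on the first call, then reuse the cached weight.
        if self._is_compressed:
            self.weight = self.compressor.decompress_module(self)
            self._is_compressed = False
        return linear(input, self.weight, self.bias)


layer = DecompressOnceLinear(4, 2)
out = layer(torch.randn(3, 4))   # first call pays the decompression cost
out = layer(torch.randn(3, 4))   # subsequent calls skip decompression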
