Commit ea8848b

[Bugfix] Support offloaded parameters when initializing KV cache parameters (#261)
* use register_offload_parameter
* fix typo

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent b7dd816 commit ea8848b
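
For context: the fix replaces direct module.register_parameter calls with the register_offload_parameter helper from compressed_tensors.utils, so newly created parameters are also handled correctly on modules whose weights have been offloaded (presumably via accelerate-style hooks; the diff itself only shows the call change). A minimal sketch of the new call pattern follows; the toy Linear module, the parameter name "my_scale", and the scale shape are illustrative assumptions, not part of this commit.

import torch
from torch.nn import Linear, Parameter

from compressed_tensors.utils import register_offload_parameter

# Toy stand-in for a layer being quantized; name and shape are illustrative.
module = Linear(16, 16)
device = next(module.parameters()).device

param = Parameter(
    torch.empty((1,), dtype=torch.float32, device=device),
    requires_grad=False,
)

# Previously the call sites used module.register_parameter("my_scale", param),
# which only attaches the tensor to the module itself; the helper is the
# offload-aware replacement introduced by this commit.
register_offload_parameter(module, "my_scale", param)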

2 files changed: +4 -4 lines changed

src/compressed_tensors/linear/compressed_linear.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
     QuantizationStatus,
     initialize_module_for_quantization,
 )
+from compressed_tensors.utils import register_offload_parameter
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear
@@ -68,7 +69,7 @@ def from_linear(
             param = Parameter(
                 torch.empty(shape, device=device, dtype=dtype), requires_grad=False
             )
-            module.register_parameter(name, param)
+            register_offload_parameter(module, name, param)
 
         # mark module as compressed
         module.quantization_status = QuantizationStatus.COMPRESSED

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 3 deletions
@@ -203,11 +203,10 @@ def _initialize_attn_scales(module: Module) -> None:
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-
-    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
 
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
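
For illustration, a minimal sketch of the pattern _initialize_attn_scales now follows for the KV cache scales, registering both the key and value scale through the offload-aware helper. The stand-in attention module, the per-tensor scale shape, and the import path for KVCacheScaleType are assumptions made for the sketch; the real function derives shapes and dtypes from the module's quantization arguments.

import torch
from torch.nn import Linear, Parameter

from compressed_tensors.quantization import KVCacheScaleType  # import path assumed
from compressed_tensors.utils import register_offload_parameter

# Stand-in for an attention module; a per-tensor scale of shape (1,) is assumed.
attn = Linear(8, 8)
device = next(attn.parameters()).device
scale_dtype = torch.float32
expected_shape = (1,)

for scale_name in (KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value):
    init_scale = Parameter(
        torch.empty(expected_shape, dtype=scale_dtype, device=device),
        requires_grad=False,
    )
    # Offload-aware registration, as in the hunk above.
    register_offload_parameter(attn, scale_name, init_scale)

# The scales are now ordinary parameters on the module (offloaded or not).
print(getattr(attn, KVCacheScaleType.KEY.value).shape)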
