
Commit 2c96312

Author: evian
[KV Cache] support kv cache int8 per channel quant
Signed-off-by: evian <eviantai@u.nus.edu>
Parent: 180226b · Commit: 2c96312

File tree

1 file changed: +11 / -3 lines

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 11 additions & 3 deletions
@@ -83,7 +83,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        _initialize_attn_scales(module, scheme.output_activations)
 
     else:
 
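(For context: in this lifecycle, the kv-cache quantization parameters for attention modules travel on scheme.output_activations, which is what the call above now forwards. Below is a minimal sketch of args that would select the new per-channel branch; the field values are illustrative, not taken from this commit.)

# A minimal sketch (not part of this commit): kv-cache args as they might be
# carried on scheme.output_activations. Field values are illustrative.
from compressed_tensors.quantization import QuantizationArgs

kv_cache_args = QuantizationArgs(
    num_bits=8,          # int8 kv cache
    type="int",
    symmetric=True,
    strategy="channel",  # selects the new per-channel branch added below
)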

@@ -220,10 +220,18 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
+def _initialize_attn_scales(module: Module, quantization_args: QuantizationArgs) -> None:
     """Initialize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        expected_shape = module.k_proj.out_features
+    elif quantization_args.strategy == QuantizationStrategy.TENSOR:
+        expected_shape = 1
+    else:
+        raise ValueError(
+            f"One of {(QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)} must be specified "
+            f"for kv cache quantization."
+        )
 
     param = next(module.parameters())
     scale_dtype = param.dtype
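
For reference, a small self-contained sketch of what the new branch changes: the attention k_scale/v_scale parameters go from a single scalar (per-tensor) to one scale per k_proj output feature (per-channel). The TinyAttention stub, its dimensions, and the stand-in Strategy enum are assumptions for illustration, not code from this repository.

# Self-contained sketch of the expected_shape logic added above. TinyAttention
# and its sizes are hypothetical; Strategy stands in for QuantizationStrategy.
from enum import Enum

import torch
from torch.nn import Linear, Module, Parameter


class Strategy(str, Enum):  # stand-in for QuantizationStrategy
    TENSOR = "tensor"
    CHANNEL = "channel"


class TinyAttention(Module):  # illustrative stub: only k_proj matters here
    def __init__(self, hidden: int = 64, kv_dim: int = 16):
        super().__init__()
        self.k_proj = Linear(hidden, kv_dim)
        self.v_proj = Linear(hidden, kv_dim)


def expected_scale_shape(module: Module, strategy: Strategy) -> int:
    # mirrors the branch added in this commit
    if strategy == Strategy.CHANNEL:
        return module.k_proj.out_features  # one scale per output channel
    elif strategy == Strategy.TENSOR:
        return 1  # a single scale for the whole tensor
    raise ValueError(f"One of {(Strategy.TENSOR, Strategy.CHANNEL)} must be specified")


attn = TinyAttention()
for strategy in (Strategy.TENSOR, Strategy.CHANNEL):
    k_scale = Parameter(torch.empty(expected_scale_shape(attn, strategy)))
    print(strategy.value, tuple(k_scale.shape))  # tensor -> (1,), channel -> (16,)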
