From 2c963123ec0b4a5c4cab9a56525ce151f6deedf2 Mon Sep 17 00:00:00 2001
From: evian
Date: Sat, 19 Jul 2025 16:21:14 +0800
Subject: [PATCH] [KV Cache] support kv cache int8 per-channel quant

Signed-off-by: evian
---
 .../quantization/lifecycle/initialize.py      | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index d816f855..5dd53014 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -83,7 +83,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        _initialize_attn_scales(module, scheme.output_activations)
 
     else:
 
@@ -220,10 +220,18 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
-    """Initlaize k_scale, v_scale for self_attn"""
+def _initialize_attn_scales(module: Module, quantization_args: QuantizationArgs) -> None:
+    """Initialize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        expected_shape = module.k_proj.out_features
+    elif quantization_args.strategy == QuantizationStrategy.TENSOR:
+        expected_shape = 1
+    else:
+        raise ValueError(
+            f"One of {(QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)} "
+            "must be specified for kv cache quantization."
+        )
 
     param = next(module.parameters())
     scale_dtype = param.dtype
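
A minimal sketch (reviewer note, not part of the patch) of the shape
selection this change introduces. `ToyAttention`, `hidden_size`, and
`kv_dim` are hypothetical stand-ins for a real self-attention module;
only the `k_proj`/`v_proj` submodules matter to `_initialize_attn_scales`:

    import torch
    from torch.nn import Linear, Module

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationStrategy,
    )


    class ToyAttention(Module):
        # Hypothetical stand-in exposing k_proj/v_proj like a real attn block.
        def __init__(self, hidden_size: int = 64, kv_dim: int = 16):
            super().__init__()
            self.k_proj = Linear(hidden_size, kv_dim)
            self.v_proj = Linear(hidden_size, kv_dim)


    module = ToyAttention()
    args = QuantizationArgs(
        num_bits=8, type="int", strategy=QuantizationStrategy.CHANNEL
    )

    # Mirrors the patched branch: CHANNEL allocates one scale per kv
    # channel (k_proj.out_features), TENSOR keeps a single scalar scale.
    if args.strategy == QuantizationStrategy.CHANNEL:
        expected_shape = module.k_proj.out_features
    elif args.strategy == QuantizationStrategy.TENSOR:
        expected_shape = 1

    k_scale = torch.empty(expected_shape, dtype=torch.float32)
    print(k_scale.shape)  # torch.Size([16])

Since k_proj and v_proj share the same out_features in standard MHA and
GQA attention, one expected_shape covers both k_scale and v_scale.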