File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
src/compressed_tensors/quantization/lifecycle Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -83,7 +83,7 @@ def initialize_module_for_quantization(
83
83
84
84
if is_attention_module (module ):
85
85
# quantized actions based on calltime status
86
- _initialize_attn_scales (module )
86
+ _initialize_attn_scales (module , scheme . output_activations )
87
87
88
88
else :
89
89
@@ -220,10 +220,18 @@ def _initialize_scale_zero_point(
220
220
register_offload_parameter (module , f"{ base_name } _g_idx" , init_g_idx )
221
221
222
222
223
- def _initialize_attn_scales (module : Module ) -> None :
223
+ def _initialize_attn_scales (module : Module , quantization_args : QuantizationArgs ) -> None :
224
224
"""Initialize k_scale, v_scale for self_attn"""
225
225
226
- expected_shape = 1 # per tensor
226
+ if quantization_args .strategy == QuantizationStrategy .CHANNEL :
227
+ expected_shape = module .k_proj .out_features
228
+ elif quantization_args .strategy == QuantizationStrategy .TENSOR :
229
+ expected_shape = 1
230
+ else :
231
+ raise ValueError (
232
+ f"One of { (QuantizationStrategy .TENSOR , QuantizationStrategy .CHANNEL )} must be specified "
233
+ f"for kv cache quantization."
234
+ )
227
235
228
236
param = next (module .parameters ())
229
237
scale_dtype = param .dtype
You can’t perform that action at this time.
0 commit comments