
Commit 4f605a6

Fix noisy warning for uncalibrated q_scale/p_scale (vllm-project#17414)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 8342e3a commit 4f605a6

File tree

1 file changed: +5 −4 lines


vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 5 additions & 4 deletions
@@ -124,11 +124,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
         layer._prob_scale.copy_(prob_scale)
-        if q_scale == 1.0 or prob_scale == 1.0:
+        if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
+                                              or prob_scale == 1.0):
             logger.warning_once(
-                f"Using Q scale {q_scale} and prob scale {prob_scale} "
-                "with fp8 attention. This may cause accuracy issues. "
-                "Please make sure Q/prob scaling factors are "
+                f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
+                f"{prob_scale} with fp8 attention. This may cause accuracy "
+                "issues. Please make sure q/prob scaling factors are "
                 "available in the fp8 checkpoint.")

         del layer.k_scale
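The effect of the change can be sketched as a standalone predicate. This is a minimal illustration, not vLLM's actual code: the function name `should_warn_uncalibrated` and the standalone signature are hypothetical, but the boolean condition mirrors the diff above. A scale of 1.0 is the uninitialized default, meaning no calibrated factor was loaded from the checkpoint; the fix adds the `kv_cache_dtype == "fp8"` guard so non-fp8 runs no longer warn.

```python
def should_warn_uncalibrated(kv_cache_dtype: str,
                             q_scale: float,
                             prob_scale: float) -> bool:
    """Hypothetical sketch of the guard this commit introduces.

    Returns True only when fp8 KV cache is actually in use AND at least
    one of the scales is still the default 1.0 (i.e. uncalibrated).
    """
    return kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0)


# Before the fix the dtype check was missing, so this configuration
# warned even though fp8 attention was not in use:
assert not should_warn_uncalibrated("auto", 1.0, 1.0)

# With fp8 and a default scale, the warning still fires as intended:
assert should_warn_uncalibrated("fp8", 1.0, 0.5)
assert not should_warn_uncalibrated("fp8", 0.02, 0.5)
```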
