 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv, next_power_of_2
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, next_power_of_2

 logger = init_logger(__name__)

@@ -137,8 +137,6 @@ def __init__(
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")

@@ -151,6 +149,14 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            if tpu_version < 5:
+                raise NotImplementedError(
+                    "FP8 KV cache dtype is only supported when TPU version"
+                    " is 5 or higher.")
+            self.kv_cache_quantized_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())

     def forward(
         self,
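Note on the new gate: kv_cache_dtype strings other than "auto" are resolved to a torch dtype via STR_DTYPE_TO_TORCH_DTYPE, but only on TPU v5 or newer. The snippet below is a minimal, standalone sketch of that gate; the dtype mapping it uses is illustrative (it does not reproduce vllm.utils.STR_DTYPE_TO_TORCH_DTYPE), and it assumes a PyTorch build that ships the float8 dtypes.

from typing import Optional

import torch

# Illustrative stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE (assumed entries).
_STR_TO_DTYPE = {
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}


def resolve_kv_cache_dtype(kv_cache_dtype: str,
                           tpu_version: int) -> Optional[torch.dtype]:
    # "auto" keeps the KV cache in the model dtype: no quantization.
    if kv_cache_dtype == "auto":
        return None
    # Quantized KV cache is gated on TPU v5+, mirroring the check in __init__.
    if tpu_version < 5:
        raise NotImplementedError(
            "FP8 KV cache dtype is only supported when TPU version"
            " is 5 or higher.")
    return _STR_TO_DTYPE.get(kv_cache_dtype.lower().strip())


print(resolve_kv_cache_dtype("auto", 4))      # None
print(resolve_kv_cache_dtype("fp8_e4m3", 6))  # torch.float8_e4m3fn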
@@ -179,15 +185,16 @@ def forward(
             output = torch.ones_like(query)
             return output

-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)

         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            write_to_kv_cache(key, value, kv_cache, slot_mapping,
+                              self.kv_cache_quantized_dtype,
+                              layer._k_scale_float, layer._v_scale_float)

         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -206,6 +213,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=1 / layer._k_scale_float,
+            v_scale=1 / layer._v_scale_float,
         )

         return output.reshape(num_tokens, hidden_size)
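One way to read the reciprocal: write_to_kv_cache multiplies keys/values by layer._k_scale_float / layer._v_scale_float before the FP8 cast, and the kernel is handed one over those scales. Assuming ragged_paged_attention multiplies the cached values by the k_scale / v_scale arguments it receives (an assumption, not verified from the kernel source), the two factors cancel and the round trip approximately recovers the original tensors, as in this sketch:

import torch

x = torch.randn(4, 2, 128, dtype=torch.bfloat16)  # [num_tokens, num_kv_heads, head_size]
k_scale_float = 0.05                              # illustrative per-layer scale

# Write path (as in write_to_kv_cache): scale, then cast to the quantized dtype.
cached = (x * k_scale_float).to(torch.float8_e4m3fn)

# Read path (assumed kernel behaviour): multiply by the reciprocal scale,
# i.e. the k_scale=1 / layer._k_scale_float argument passed above.
recovered = cached.to(torch.bfloat16) * (1 / k_scale_float)

print(torch.allclose(x, recovered, atol=0.1, rtol=0.1))  # expected True, up to FP8 rounding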
@@ -216,6 +225,9 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """Write the key and values to the KV cache.

@@ -230,6 +242,11 @@ def write_to_kv_cache(

     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
+    if kv_cache_quantized_dtype is not None:
+        key = key * k_scale
+        key = key.to(kv_cache_quantized_dtype)
+        value = value * v_scale
+        value = value.to(kv_cache_quantized_dtype)

     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
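Why scale before the cast: FP8 E4M3 has a narrow representable range (torch.finfo(torch.float8_e4m3fn).max is 448.0) and a coarse mantissa, so the per-layer scale is presumably what keeps key/value magnitudes inside that range before the lossy cast. Below is a minimal, self-contained sketch of the quantization step added above, assuming torch's float8 dtypes; the helper name is made up for illustration and is not part of the PR.

from typing import Optional, Tuple

import torch


def quantize_kv(key: torch.Tensor,
                value: torch.Tensor,
                kv_cache_quantized_dtype: Optional[torch.dtype] = None,
                k_scale: float = 1.0,
                v_scale: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the new branch in write_to_kv_cache: scale first, then cast.
    if kv_cache_quantized_dtype is not None:
        key = (key * k_scale).to(kv_cache_quantized_dtype)
        value = (value * v_scale).to(kv_cache_quantized_dtype)
    return key, value


k = torch.randn(16, 2, 128, dtype=torch.bfloat16)
v = torch.randn(16, 2, 128, dtype=torch.bfloat16)
qk, qv = quantize_kv(k, v, torch.float8_e4m3fn, k_scale=0.1, v_scale=0.1)
print(qk.dtype, qv.dtype)                     # torch.float8_e4m3fn torch.float8_e4m3fn
print(torch.finfo(torch.float8_e4m3fn).max)   # 448.0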