@@ -2575,15 +2575,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
                                                    MambaMixer2)
         if len(mamba_layers) > 0:
-            if len(attn_layers) > 0:
-                # Mamba state must be padded to an integer number of
-                # 16th tokens worth of attention pages
-                attn_layer_name = next(iter(attn_layers))
-                attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
-                multiple_of = 16 * attn_page_size // block_size
-            else:
-                multiple_of = None
-
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
@@ -2594,25 +2585,39 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 raise NotImplementedError(
                     "Prefix caching is not supported for Mamba yet.")
             max_model_len = self.vllm_config.model_config.max_model_len
+
+            if len(attn_layers) > 0:
+                attn_layer_name = next(iter(attn_layers))
+                attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
+                mamba_layer_name = next(iter(mamba_layers))
+                mamba_page_size = MambaSpec(
+                    shapes=mamba_layers[mamba_layer_name].get_state_shape(),
+                    dtype=self.kv_cache_dtype,
+                    block_size=max_model_len).page_size_bytes
+                if attn_page_size < mamba_page_size:
+                    # attention page size (for 16 tokens)
+                    attn_page_size_16 = 16 * attn_page_size // block_size
+                    # some attention backends (e.g. FA) only support setting
+                    # block size to multiple of 16, so let's suggest a value
+                    # that would work (note: FA is currently not compatible
+                    # with mamba layers, use FlashInfer instead).
+                    suggest_attn_block_size = 16 * cdiv(
+                        mamba_page_size, attn_page_size_16)
+                    raise ValueError(
+                        "Attention block size should be increased to at least "
+                        f"{suggest_attn_block_size} in order to match "
+                        "the mamba page size")
+                page_size_padded = attn_page_size
+            else:
+                page_size_padded = None
+
             # Set block_size to max_model_len, so that mamba model will always
             # have only one block in the KV cache.
             for layer_name, mamba_module in mamba_layers.items():
                 kv_cache_spec[layer_name] = MambaSpec(
                     shapes=mamba_module.get_state_shape(),
                     dtype=self.kv_cache_dtype,
                     block_size=max_model_len,
-                    multiple_of=multiple_of)
-
-            if len(attn_layers) > 0:
-                mamba_layer_name = next(iter(mamba_layers))
-                mamba_page_size = kv_cache_spec[
-                    mamba_layer_name].page_size_bytes
-                if attn_page_size < mamba_page_size:
-                    required_attn_block_size = cdiv(mamba_page_size,
-                                                    multiple_of) * 16
-                    raise ValueError(
-                        "Attention block size must be increased to "
-                        f"{required_attn_block_size} in order to match "
-                        "the mamba page size")
+                    page_size_padded=page_size_padded)

         return kv_cache_spec
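
For reference, a small self-contained sketch of the arithmetic behind the suggested block size in the error path above. The helper name and the byte counts below are hypothetical and chosen only for illustration; the real code uses the cdiv helper shown in the diff together with the actual page sizes computed from kv_cache_spec and MambaSpec.

    # Standalone sketch, not part of the diff above.
    def suggested_attn_block_size(mamba_page_size: int, attn_page_size: int,
                                  block_size: int) -> int:
        # Bytes of attention KV cache that 16 tokens occupy in one layer.
        attn_page_size_16 = 16 * attn_page_size // block_size
        # Round up to a multiple of 16 tokens so that one attention page is
        # at least as large as the (single) mamba state page.
        return 16 * -(-mamba_page_size // attn_page_size_16)  # ceil division

    # Example: a 2 MiB mamba state page vs. a 1 MiB attention page built from
    # 16-token blocks -> the error message would suggest a block size of 32.
    assert suggested_attn_block_size(mamba_page_size=2 * 1024 * 1024,
                                     attn_page_size=1 * 1024 * 1024,
                                     block_size=16) == 32

The new page_size_padded path appears to pad the mamba page up to the attention page size, so both layer types end up with one uniform page size in the KV cache, replacing the earlier multiple_of rounding that was removed in the first hunk.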