Enable V1 for Hybrid SSM/Attention Models #20016
Changes from 4 commits
```diff
@@ -2399,7 +2399,8 @@ def _reshape_kv_cache_tensors(
                     tensor = tensor.view(dtype).view(target_shape)
                     state_tensors.append(tensor)
                     start_pos += size_in_bytes
-                assert start_pos == raw_tensor.numel()
+                if kv_cache_spec.multiple_of is None:
+                    assert start_pos == raw_tensor.numel()
                 kv_caches[layer_name] = tuple(state_tensors)
             else:
                 raise NotImplementedError
```
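The relaxed assertion reflects the padding introduced later in this PR: when `multiple_of` is set, the mamba page is rounded up so that an attention page can be made exactly the same size, and the packed state tensors then stop short of the end of the raw per-layer tensor. A minimal sketch of that arithmetic, with made-up sizes (not the PR's code):

```python
# Minimal sketch with made-up sizes: why start_pos == raw_tensor.numel()
# only holds when no padding granularity (multiple_of) is applied.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in the later diff hunk."""
    return -(-a // b)

state_sizes_bytes = [151_552, 8_192]          # hypothetical mamba state tensors
start_pos = sum(state_sizes_bytes)            # bytes consumed by the packing loop

multiple_of = 65_536                          # hypothetical padding granularity (bytes)
page_size_bytes = cdiv(start_pos, multiple_of) * multiple_of  # padded page

assert start_pos <= page_size_bytes           # equality only if already aligned
print(page_size_bytes - start_pos, "padding bytes left at the end of the page")
```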
```diff
@@ -2513,6 +2514,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
                                                    MambaMixer2)
         if len(mamba_layers) > 0:
+            if len(attn_layers) > 0:
+                # Mamba state must be padded to an integer number of
+                # 16th tokens worth of attention pages
+                attn_layer_name = next(iter(attn_layers))
+                attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
+                multiple_of = 16 * attn_page_size // block_size
+            else:
+                multiple_of = None
+
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
```

Review discussion attached to the `multiple_of = 16 * attn_page_size // block_size` line:

- to clarify here, …
- Not exactly. Why do we want to know the size in bytes of an attention page that stores 16 tokens? It's because we want to ensure that the mamba page size is padded up to a value that the user can align the attention page size with. Since the user can only set the attention block size in multiples of 16, that is why the factor of 16 is needed.
- Can you add a comment to explain the magic number "16"?
- @tlrmchlsmth If I remember correctly, you told me that FlashMLA does not support block_size 16. Can you confirm? If it is true, we may need some other assertion here.
- So actually, I think this magic "16" is not necessarily needed. The constraint that the block size must be a multiple of 16 only comes from the FlashAttention backend (which is not compatible with Mamba right now for the reasons discussed). I just checked, and with FlashInfer it is possible to set the block size to any number. Still, it probably makes sense to keep "16" since we want to support FlashAttention in the near future. Do you agree @heheda12345?
- Done
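To make the factor of 16 concrete, here is a hedged worked example of the padding arithmetic. The shapes and byte counts are hypothetical, not taken from the PR: `attn_page_size` is the attention KV page size in bytes for one `block_size`-token block, so `multiple_of` is the attention bytes needed for 16 tokens, the smallest step by which a user can grow the attention page.

```python
# Illustrative numbers only; the real sizes depend on the model config.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

block_size = 16                               # default attention block size (tokens)
attn_page_size = 16 * 2 * 32 * 128 * 2        # 16 tokens x K/V x 32 heads x head_dim 128 x fp16
multiple_of = 16 * attn_page_size // block_size   # = attention bytes per 16 tokens

raw_mamba_state = 10_000_000                  # hypothetical unpadded mamba state (bytes)
mamba_page_size = cdiv(raw_mamba_state, multiple_of) * multiple_of  # padded mamba page

# Because the mamba page is a multiple of `multiple_of`, there is always an
# attention block size (a multiple of 16) whose page matches it exactly:
required_attn_block_size = cdiv(mamba_page_size, multiple_of) * 16
assert (required_attn_block_size // 16) * multiple_of == mamba_page_size
print(multiple_of, mamba_page_size, required_attn_block_size)   # 262144 10223616 624
```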
```diff
@@ -2529,5 +2539,19 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 kv_cache_spec[layer_name] = MambaSpec(
                     shapes=mamba_module.get_state_shape(),
                     dtype=self.kv_cache_dtype,
-                    block_size=max_model_len)
+                    block_size=max_model_len,
+                    multiple_of=multiple_of)
+
+            if len(attn_layers) > 0:
+                mamba_layer_name = next(iter(mamba_layers))
+                mamba_page_size = kv_cache_spec[
+                    mamba_layer_name].page_size_bytes
+                if attn_page_size < mamba_page_size:
+                    required_attn_block_size = cdiv(mamba_page_size,
+                                                    multiple_of) * 16
+                    raise ValueError(
+                        "Attention block size must be increased to "
+                        f"{required_attn_block_size} in order to match "
+                        "the mamba page size")
+
         return kv_cache_spec
```

Review discussion attached to the `ValueError`:

- I think this is a fairly reasonable approach, especially for a first pass.
- The main question is whether we are OK with vLLM V1 failing under default parameters for hybrid models. If not, we could automatically scale up the attention block size and log what is happening to inform the user, rather than explicitly asking the user to do it.
- I think that's a better option, paired with logging a warning. But that could also wait for a follow-up.
- I think the printed …
- It needs to be at least this value in order to work, though, right? I can't really think of practical scenarios where we would want the attention page size to be bigger than the mamba page size. The mamba page size is typically orders of magnitude bigger than the attention page size (per token). If the attention page size is bigger, we will need to pad the mamba page to align it and waste more space.
- I've changed the language in the exception, please take a look.
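The follow-up idea discussed above is to scale the attention block size up automatically instead of raising. A hedged sketch of what that could look like; `cache_config.block_size` and `cdiv` mirror names used in this diff, but the function itself is illustrative, not the code merged in this PR:

```python
import logging

logger = logging.getLogger(__name__)

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def maybe_grow_attn_block_size(cache_config, attn_page_size: int,
                               mamba_page_size: int, multiple_of: int) -> None:
    """Illustrative only: bump the attention block size so that one attention
    page is at least as large as the (padded) mamba page, and warn the user."""
    if attn_page_size >= mamba_page_size:
        return
    required_attn_block_size = cdiv(mamba_page_size, multiple_of) * 16
    logger.warning(
        "Hybrid SSM/attention model: increasing attention block size from %d "
        "to %d so that attention pages match the mamba page size.",
        cache_config.block_size, required_attn_block_size)
    cache_config.block_size = required_attn_block_size
```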