Enable V1 for Hybrid SSM/Attention Models #20016
Changes from all commits
@@ -3,6 +3,7 @@

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
@@ -19,31 +20,55 @@

SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
-   # TODO: Compare to a Mamba2 model. The HF transformers implementation of
-   # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
-   # doesn't compare vLLM output with HF output.
-   # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
]

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
-   # NOTE: Currently the test failes due to HF transformers issue fixed in:
-   # https://github.com/huggingface/transformers/pull/39033
-   # We will enable vLLM test for Granite after next HF transformers release.
-   # "ibm-granite/granite-4.0-tiny-preview",
    # NOTE: Running Plamo2 in transformers implementation requires to install
    # causal-conv1d package, which is not listed as a test dependency as it's
    # not compatible with pip-compile.
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-ai-platform/Bamba-9B-v1",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]
HF_UNSUPPORTED_MODELS = [
    # The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
    # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
    # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
    # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
    "nvidia/Nemotron-H-8B-Base-8K",
    # NOTE: Currently the test fails due to HF transformers issue fixed in:
    # https://github.com/huggingface/transformers/pull/39033
    # We will enable vLLM test for Granite after next HF transformers release.
    "ibm-granite/granite-4.0-tiny-preview",
]

Comment on lines +41 to +48:
"I wonder if the Nemotron issue is also caused by this."
"@tlrmchlsmth @tdoublep is this a blocker to this PR?"
"I don't think so."

V1_SUPPORTED_MODELS = [
    "mistralai/Mamba-Codestral-7B-v0.1",
    "ibm-ai-platform/Bamba-9B-v1",
    "Zyphra/Zamba2-1.2B-instruct",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]
ATTN_BLOCK_SIZES = {
    "ibm-ai-platform/Bamba-9B-v1": 528,
    "Zyphra/Zamba2-1.2B-instruct": 80,
    "nvidia/Nemotron-H-8B-Base-8K": 528,
    "ibm-granite/granite-4.0-tiny-preview": 400,
    "tiiuae/Falcon-H1-0.5B-Base": 800,
}

Comment on lines +65 to +69 (the ATTN_BLOCK_SIZES entries):
"@tdoublep do you know what attention backends are used for these block sizes?"
"I'm using FlashInfer in all of the tests. This is required because it reorders the batch in the same way as the Mamba backend. I saw that in some of the cases where the block size is really big (e.g., Falcon-H1), it triggers some (JIT?) compilation from FlashInfer when running the serving benchmark. The results still look good, though."
"To my understanding, FlashAttention is compatible with the decode-first order needed by Mamba2, as it accepts arbitrary order. […] But due to the kv_cache_shape problem I mentioned below, we cannot use FlashAttention in this PR, so I think it's fine to use FlashInfer in this PR."
"For the extremely big block size, we only need to make […]. And does the JIT compilation happen during engine initialization or during execution?"
"Yes, you are right. The (2, num_blocks, ...) issue is the bigger reason why FlashAttention can't be supported right now."
"In the serving benchmark, I think it happens for the 1st test prompt that is sent to warm things up. Will double check (confirmed)."
"Is it possible to call that kernel during the engine's warm-up stage instead of the warmup prompt? (I'm OK with leaving a warning during engine warmup in this PR when a strange block_size is used and fixing it later.)"
"OK, so this is a false alarm, sorry. The JIT recompilation that I'm seeing is not related to the large attention block size. The only hybrid model I see it happening for is […]"

# Avoid OOM
MAX_NUM_SEQS = 4
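For context, here is a minimal standalone sketch (mine, not part of the PR) of how one of the hybrid checkpoints listed above could be run under V1 with the FlashInfer backend and a model-specific attention block size, mirroring the test configuration in this diff. The model name and block size are taken from ATTN_BLOCK_SIZES, and the environment variables are the same ones the test sets via monkeypatch; whether a given checkpoint works end to end depends on the support matrix in this PR.

import os

# Mirror the env vars the test sets via monkeypatch (sketch, not the test itself).
os.environ["VLLM_USE_V1"] = "1"
# FlashInfer is needed here because it reorders the batch the same way as the Mamba backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",  # one of the hybrid models listed above
    max_num_seqs=4,                       # kept small to avoid OOM, like MAX_NUM_SEQS in the test
    enforce_eager=True,
    enable_prefix_caching=False,
    block_size=528,                       # attention block size from ATTN_BLOCK_SIZES
)

outputs = llm.generate(
    ["The future of hybrid SSM/attention models is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)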
@@ -60,8 +85,16 @@ def test_models(
    max_tokens: int,
    num_logprobs: int,
) -> None:

    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
-       if model != "mistralai/Mamba-Codestral-7B-v0.1":
        if model not in HF_UNSUPPORTED_MODELS:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)
        else:
@@ -72,12 +105,21 @@ def test_models(
            example_prompts, max_tokens, num_logprobs)

    if model in V1_SUPPORTED_MODELS:
        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
            block_size = ATTN_BLOCK_SIZES[model]
        else:
            block_size = 16

        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enforce_eager=True,
-                            enable_prefix_caching=False) as vllm_model:
                             enable_prefix_caching=False,
                             block_size=block_size) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)
    else:
@@ -111,6 +153,14 @@ def test_batching(
    max_tokens: int,
    num_logprobs: int,
) -> None:

    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    for_loop_outputs = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
Review thread:
"I think these will be too large for the CI, which runs on L4 GPUs."
"The models are close to the limit, but the tests do pass in CI."
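For readers wondering why these models are "close to the limit", a rough back-of-the-envelope check (my own arithmetic, not from the thread; assumes bf16 weights and the L4's 24 GB of memory):

# Rough memory estimate for an 8B-parameter hybrid model in bf16 on a 24 GB L4.
params = 8e9                  # e.g. an 8B-class checkpoint from the lists above
bytes_per_param = 2           # bf16 / fp16
weights_gb = params * bytes_per_param / 1e9
l4_gb = 24
print(f"weights: ~{weights_gb:.0f} GB")   # ~16 GB
# What's left for KV cache, Mamba state, and activations:
print(f"remaining: ~{l4_gb - weights_gb:.0f} GB")  # ~8 GB, hence max_num_seqs=4 and eager mode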