
Commit 2f35a02

tdoubleps, s3woz, tlrmchlsmth, and heheda12345 authored
Enable V1 for Hybrid SSM/Attention Models (vllm-project#20016)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
1 parent ffe00ef commit 2f35a02

File tree

14 files changed, +401 -136 lines changed

tests/models/language/generation/test_hybrid.py

Lines changed: 60 additions & 10 deletions
@@ -3,6 +3,7 @@
 
 import pytest
 
+from tests.models.registry import HF_EXAMPLE_MODELS
 from tests.utils import multi_gpu_test
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
@@ -19,31 +20,55 @@
 SSM_MODELS = [
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
-    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
-    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
-    # doesn't compare vLLM output with HF output.
-    # See https://github.com/huggingface/transformers/pull/35943
     "mistralai/Mamba-Codestral-7B-v0.1",
 ]
 
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # NOTE: Currently the test failes due to HF transformers issue fixed in:
-    # https://github.com/huggingface/transformers/pull/39033
-    # We will enable vLLM test for Granite after next HF transformers release.
-    # "ibm-granite/granite-4.0-tiny-preview",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's
     # not compatible with pip-compile.
     "pfnet/plamo-2-1b",
     "Zyphra/Zamba2-1.2B-instruct",
     "hmellor/tiny-random-BambaForCausalLM",
+    "ibm-ai-platform/Bamba-9B-v1",
+    "nvidia/Nemotron-H-8B-Base-8K",
+    "ibm-granite/granite-4.0-tiny-preview",
+    "tiiuae/Falcon-H1-0.5B-Base",
+]
+
+HF_UNSUPPORTED_MODELS = [
+    # The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
+    # doesn't compare vLLM output with HF output.
+    # See https://github.com/huggingface/transformers/pull/35943
+    "mistralai/Mamba-Codestral-7B-v0.1",
+    # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
+    # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
+    "nvidia/Nemotron-H-8B-Base-8K",
+    # NOTE: Currently the test fails due to HF transformers issue fixed in:
+    # https://github.com/huggingface/transformers/pull/39033
+    # We will enable vLLM test for Granite after next HF transformers release.
+    "ibm-granite/granite-4.0-tiny-preview",
 ]
 
 V1_SUPPORTED_MODELS = [
     "mistralai/Mamba-Codestral-7B-v0.1",
+    "ibm-ai-platform/Bamba-9B-v1",
+    "Zyphra/Zamba2-1.2B-instruct",
+    "nvidia/Nemotron-H-8B-Base-8K",
+    "ibm-granite/granite-4.0-tiny-preview",
+    "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
+ATTN_BLOCK_SIZES = {
+    "ibm-ai-platform/Bamba-9B-v1": 528,
+    "Zyphra/Zamba2-1.2B-instruct": 80,
+    "nvidia/Nemotron-H-8B-Base-8K": 528,
+    "ibm-granite/granite-4.0-tiny-preview": 400,
+    "tiiuae/Falcon-H1-0.5B-Base": 800,
+}
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -60,8 +85,16 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
     with hf_runner(model) as hf_model:
-        if model != "mistralai/Mamba-Codestral-7B-v0.1":
+        if model not in HF_UNSUPPORTED_MODELS:
             hf_outputs = hf_model.generate_greedy_logprobs_limit(
                 example_prompts, max_tokens, num_logprobs)
         else:
@@ -72,12 +105,21 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
+            block_size = ATTN_BLOCK_SIZES[model]
+        else:
+            block_size = 16
+
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
+            if model in HYBRID_MODELS:
+                # required due to reorder_batch behaviour
+                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enforce_eager=True,
-                             enable_prefix_caching=False) as vllm_model:
+                             enable_prefix_caching=False,
+                             block_size=block_size) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
     else:
@@ -111,6 +153,14 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
     for_loop_outputs = []
     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
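With `SupportsV0Only` dropped from these architectures (see the model files below), the hybrid checkpoints added above can also be run on the V1 engine outside the test harness. A minimal sketch, assuming the same settings the test uses (V1 enabled, FlashInfer attention backend, eager mode, prefix caching off, and the per-model attention block size from `ATTN_BLOCK_SIZES`), and assuming the chosen checkpoint fits on your hardware:

```python
# Sketch only: mirrors the engine arguments used in test_hybrid.py above.
import os

os.environ["VLLM_USE_V1"] = "1"
# The test forces FlashInfer for hybrid models due to reorder_batch behaviour.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    max_num_seqs=4,               # MAX_NUM_SEQS in the test, to avoid OOM
    enforce_eager=True,
    enable_prefix_caching=False,
    block_size=528,               # ATTN_BLOCK_SIZES["ibm-ai-platform/Bamba-9B-v1"]
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

The non-default block sizes in `ATTN_BLOCK_SIZES` appear to be chosen per model so that an attention block and a mamba state page line up under the hybrid allocator; treat the values above as test configuration rather than tuning advice.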

tests/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def check_available_online(
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
-    "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
+    "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
                                           min_transformers_version="4.53"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),

tests/v1/test_oracle.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
     "state-spaces/mamba-130m-hf",  # mamba1
-    "hmellor/tiny-random-BambaForCausalLM",  # hybrid
     "BAAI/bge-m3",  # embedding
 ]
 

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def _selective_scan_update_kernel(
     # is the same as the batch id.
     if HAS_STATE_BATCH_INDICES:
         state_batch_indices_ptr += pid_b
-        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
         state_ptr += (state_batch_idx * stride_state_batch +
                       pid_h * stride_state_head)
     else:
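The one-line change above widens the loaded cache-slot index to 64 bits before it is multiplied by `stride_state_batch`, so the pointer-offset arithmetic is done in 64-bit and cannot wrap for large state tensors. A rough illustration of the failure mode in plain PyTorch, with made-up sizes:

```python
# Illustration only, not kernel code: a 32-bit index * stride product wraps,
# while the widened 64-bit product stays correct.
import torch

state_batch_idx = torch.tensor(300, dtype=torch.int32)  # hypothetical slot id
stride_state_batch = 8_000_000                          # hypothetical elements per slot

wrapped = state_batch_idx * torch.tensor(stride_state_batch, dtype=torch.int32)
widened = state_batch_idx.to(torch.int64) * stride_state_batch

print(wrapped.item())  # -1894967296: the true offset 2.4e9 overflowed int32
print(widened.item())  # 2400000000
```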

vllm/model_executor/models/bamba.py

Lines changed: 30 additions & 15 deletions
@@ -9,6 +9,7 @@
 from torch import nn
 from transformers import BambaConfig
 
+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -36,7 +37,7 @@
 from vllm.utils import LayerBlockType
 
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsQuant, SupportsV0Only)
+                         SupportsQuant)
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -97,7 +98,9 @@ def __init__(self,
             head_dim=config.mamba_d_head,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size)
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -313,10 +316,14 @@ def forward(
 
         attn_metadata = get_forward_context().attn_metadata
 
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
 
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -337,7 +344,8 @@ def forward(
                 num_attn += 1
 
             layer_mamba_cache_params = None
-            if isinstance(layer, BambaMixerDecoderLayer):
+            if isinstance(layer,
+                          BambaMixerDecoderLayer) and mamba_cache_params:
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     i - num_attn)
 
@@ -411,7 +419,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid, SupportsV0Only, SupportsQuant):
+                       IsHybrid, SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -475,15 +483,22 @@ def forward(self,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if self.mamba_cache is None:
 
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_mamba_layers = \
+                    self.model_config.get_num_layers_by_block_type(
+                        self.vllm_config.parallel_config,
+                        LayerBlockType.mamba
+                    )
+
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config, self.lm_head.weight.dtype,
+                    num_mamba_layers, *self._get_mamba_cache_shape())
+
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
-            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
 
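The same gating pattern appears in the other hybrid model files below: the V0-only `MambaCacheManager` is created only when `envs.VLLM_USE_V1` is false, `mamba_cache_params` stays `None` on V1 (where the engine-side cache manager owns the SSM state), and per-layer lookups are guarded accordingly. A toy sketch of that guard with hypothetical names, not vLLM source:

```python
# Hypothetical names; only illustrates the None-guard the models adopt.
from typing import Optional


class FakeCacheParams:
    """Stands in for the V0 MambaCacheManager per-run tensors."""

    def at_layer_idx(self, idx: int) -> str:
        return f"layer-{idx}-state"


def layer_cache(use_v1: bool, layer_idx: int) -> Optional[str]:
    mamba_cache_params = None if use_v1 else FakeCacheParams()

    layer_mamba_cache_params = None
    if mamba_cache_params:  # skipped entirely on V1
        layer_mamba_cache_params = mamba_cache_params.at_layer_idx(layer_idx)
    return layer_mamba_cache_params


print(layer_cache(use_v1=False, layer_idx=3))  # layer-3-state
print(layer_cache(use_v1=True, layer_idx=3))   # None
```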
vllm/model_executor/models/falcon_h1.py

Lines changed: 40 additions & 17 deletions
@@ -8,6 +8,7 @@
 from torch import nn
 from transformers import FalconH1Config
 
+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -33,8 +34,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -85,6 +85,7 @@ def __init__(
         config: FalconH1Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -107,6 +108,8 @@ def __init__(
            activation=config.hidden_act,
            quant_config=quant_config,
            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size,
        )
        # n_groups is overridden later by `MambaMixer2`
        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,18 +319,26 @@ def __init__(
        prefix: str = "",
    ) -> None:
        super().__init__()
+
        # Instantiate the attention branch
        self.self_attn = FalconH1AttentionDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )
+
+        # In V1 all attention/ssm layers must have
+        # different index in prefix
+        ssm_layer_idx = config.num_hidden_layers + layer_idx
+        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
+
        # Instantiate the SSM branch
        self.mamba = FalconH1SSMDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
+            prefix=ssm_prefix,
        )
        self.ssm_out_multiplier = config.ssm_out_multiplier
        self.ssm_in_multiplier = config.ssm_in_multiplier
@@ -452,10 +463,16 @@ def forward(
        # proper continuous batching computation including
        # chunked prefill
        attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
+
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds * self.embedding_multiplier
@@ -468,7 +485,9 @@ def forward(
 
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
-            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
+            layer_mamba_cache_params = None
+            if mamba_cache_params:
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
@@ -484,7 +503,7 @@ def forward(
 
 
 class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid, SupportsV0Only):
+                          IsHybrid):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ def forward(
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
-        if self.mamba_cache is None:
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                self.lm_head.weight.dtype
-                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
-                self.config.num_hidden_layers,
-                *self._get_mamba_cache_shape(),
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config,
+                    self.lm_head.weight.dtype if hasattr(
+                        self.lm_head, 'weight') else torch.bfloat16,
+                    self.config.num_hidden_layers,
+                    *self._get_mamba_cache_shape(),
+                )
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
        hidden_states = self.model(
            input_ids,
            positions,
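The prefix arithmetic added in the decoder layer `__init__` above exists because V1 tracks attention and SSM state per layer prefix, so the two branches of the same Falcon-H1 decoder layer must not share an index; the SSM branch is shifted by `config.num_hidden_layers`. A small sketch with hypothetical values (not vLLM source) showing the shifted indices never collide with the attention ones:

```python
# Hypothetical prefixes; demonstrates the uniqueness argument only.
num_hidden_layers = 4

prefixes = []
for layer_idx in range(num_hidden_layers):
    attn_prefix = f"model.layers.{layer_idx}"
    ssm_layer_idx = num_hidden_layers + layer_idx  # same shift as the PR
    ssm_prefix = attn_prefix.split(".")[0] + f".{ssm_layer_idx}"
    prefixes += [attn_prefix, ssm_prefix]

assert len(prefixes) == len(set(prefixes))  # no collisions between branches
print(prefixes)
```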
