
Commit a89209b

[v1] Support mamba2 (#19327)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent ffacb22 commit a89209b

File tree

9 files changed: +583, -121 lines


tests/models/language/generation/test_hybrid.py

Lines changed: 42 additions & 11 deletions
@@ -17,9 +17,10 @@
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
     # TODO: Compare to a Mamba2 model. The HF transformers implementation of
-    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
+    # doesn't compare vLLM output with HF output.
     # See https://github.com/huggingface/transformers/pull/35943
-    # "mistralai/Mamba-Codestral-7B-v0.1",
+    "mistralai/Mamba-Codestral-7B-v0.1",
 ]

 HYBRID_MODELS = [
@@ -35,6 +36,10 @@
     "hmellor/tiny-random-BambaForCausalLM",
 ]

+V1_SUPPORTED_MODELS = [
+    "mistralai/Mamba-Codestral-7B-v0.1",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4

@@ -46,24 +51,50 @@ def test_models(
     hf_runner,
     vllm_runner,
     example_prompts,
+    monkeypatch,
     model: str,
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
     with hf_runner(model) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+        if model != "mistralai/Mamba-Codestral-7B-v0.1":
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None

     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    if model in V1_SUPPORTED_MODELS:
+        with monkeypatch.context() as m:
+            m.setenv("VLLM_USE_V1", "1")
+            with vllm_runner(model,
+                             max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=True,
+                             enable_prefix_caching=False) as vllm_model:
+                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+    else:
+        vllm_v1_outputs = None
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    if model in V1_SUPPORTED_MODELS:
+        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+        check_logprobs_close(
+            outputs_0_lst=ref_outputs,
+            outputs_1_lst=vllm_v1_outputs,
+            name_0="hf" if hf_outputs is not None else "vllm-v0",
+            name_1="vllm-v1",
+        )


 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)

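For context, the V1 path exercised by the test above can be reproduced outside the test harness with vLLM's offline LLM API. The following is a minimal sketch under the same settings the test uses (VLLM_USE_V1=1, eager mode, prefix caching disabled, max_num_seqs=4); the prompt text and output handling are illustrative and not part of this commit.

    import os

    # Opt in to the V1 engine before vLLM is imported, mirroring the
    # monkeypatched VLLM_USE_V1=1 in the test above.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Mamba-Codestral-7B-v0.1",
        max_num_seqs=4,               # matches MAX_NUM_SEQS in the test
        enforce_eager=True,           # same flag the V1 branch of the test passes
        enable_prefix_caching=False,  # same flag the V1 branch of the test passes
    )

    outputs = llm.generate(["Mamba2 is a state-space model that"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)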
tests/v1/test_oracle.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 UNSUPPORTED_MODELS_V1 = [
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
-    "mistralai/Mamba-Codestral-7B-v0.1",  # mamba
+    "state-spaces/mamba-130m-hf",  # mamba1
     "hmellor/tiny-random-BambaForCausalLM",  # hybrid
     "BAAI/bge-m3",  # embedding
 ]
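The swap above means mamba2 (Mamba-Codestral) is no longer on the V1 blocklist, while a mamba1 model (state-spaces/mamba-130m-hf) takes its place. A rough sketch of how to observe the oracle, assuming the usual behavior that an explicitly requested V1 engine rejects a blocklisted model rather than silently falling back (the exact exception type is not specified by this diff):

    import os

    os.environ["VLLM_USE_V1"] = "1"  # explicitly request V1 so the oracle cannot fall back

    from vllm import LLM

    try:
        # mamba1 remains in UNSUPPORTED_MODELS_V1, so this is expected to be rejected.
        LLM(model="state-spaces/mamba-130m-hf", enforce_eager=True)
    except Exception as exc:  # exact exception type depends on _raise_or_fallback
        print(f"V1 rejected the mamba1 model: {exc}")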

vllm/engine/arg_utils.py

Lines changed: 6 additions & 1 deletion
@@ -1355,12 +1355,17 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False

-        # No Mamba or Encoder-Decoder so far.
+        # No Encoder-Decoder, not all Mamba so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,
                                recommend_to_remove=False)
             return False

+        # V1 mamba models are unoptimized.
+        if model_config.has_inner_state and _warn_or_fallback(
+                feature_name="Mamba"):
+            return False
+
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != SchedulerConfig.max_num_partial_prefills

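Note that the new check calls _warn_or_fallback rather than _raise_or_fallback, so Mamba models no longer hard-block V1; they only trigger a fallback unless V1 was requested explicitly. The helper below is a hypothetical sketch of that contract inferred solely from how it is called in this hunk; the real implementation in vllm/engine/arg_utils.py may differ in details.

    import logging
    import os

    logger = logging.getLogger(__name__)

    def _warn_or_fallback(feature_name: str) -> bool:
        """Hypothetical sketch: return True to fall back to V0, False to stay on V1."""
        if os.environ.get("VLLM_USE_V1") == "1":
            # The user explicitly opted in to V1: keep it, but warn that the
            # feature (here, Mamba) is currently unoptimized on V1.
            logger.warning("%s is unoptimized on the V1 engine; continuing because "
                           "VLLM_USE_V1=1 was set explicitly.", feature_name)
            return False
        logger.warning("%s detected; falling back to the V0 engine.", feature_name)
        return True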