    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
-   # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+   # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
+   # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
-   # "mistralai/Mamba-Codestral-7B-v0.1",
+   "mistralai/Mamba-Codestral-7B-v0.1",
]

HYBRID_MODELS = [
    "hmellor/tiny-random-BambaForCausalLM",
]

+V1_SUPPORTED_MODELS = [
+   "mistralai/Mamba-Codestral-7B-v0.1",
+]
+
# Avoid OOM
MAX_NUM_SEQS = 4

@@ -46,24 +51,50 @@ def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
+   monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with hf_runner(model) as hf_model:
-       hf_outputs = hf_model.generate_greedy_logprobs_limit(
-           example_prompts, max_tokens, num_logprobs)
+       if model != "mistralai/Mamba-Codestral-7B-v0.1":
+           hf_outputs = hf_model.generate_greedy_logprobs_limit(
+               example_prompts, max_tokens, num_logprobs)
+       else:
+           hf_outputs = None

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-       vllm_outputs = vllm_model.generate_greedy_logprobs(
+       vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

-   check_logprobs_close(
-       outputs_0_lst=hf_outputs,
-       outputs_1_lst=vllm_outputs,
-       name_0="hf",
-       name_1="vllm",
-   )
+   if model in V1_SUPPORTED_MODELS:
+       with monkeypatch.context() as m:
+           m.setenv("VLLM_USE_V1", "1")
+           with vllm_runner(model,
+                            max_num_seqs=MAX_NUM_SEQS,
+                            enforce_eager=True,
+                            enable_prefix_caching=False) as vllm_model:
+               vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                   example_prompts, max_tokens, num_logprobs)
+   else:
+       vllm_v1_outputs = None
+
+   if hf_outputs is not None:
+       check_logprobs_close(
+           outputs_0_lst=hf_outputs,
+           outputs_1_lst=vllm_v0_outputs,
+           name_0="hf",
+           name_1="vllm-v0",
+       )
+
+   if model in V1_SUPPORTED_MODELS:
+       ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+       check_logprobs_close(
+           outputs_0_lst=ref_outputs,
+           outputs_1_lst=vllm_v1_outputs,
+           name_0="hf" if hf_outputs is not None else "vllm-v0",
+           name_1="vllm-v1",
+       )


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)