@@ -12,7 +12,7 @@
 from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
@@ -44,7 +44,7 @@
 
 
 def run_moe_test(
-    baseline_moe_fn: Callable,
+    baseline: Union[Callable, torch.Tensor],
     moe_fn: Callable,
     a: torch.Tensor,
     w1: torch.Tensor,
@@ -58,8 +58,11 @@ def run_moe_test(
     use_cudagraph: bool = False,
     atol: float = 2e-2,
     rtol: float = 0,
-):
-    baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map)
+) -> torch.Tensor:
+    if isinstance(baseline, torch.Tensor):
+        baseline_output = baseline
+    else:
+        baseline_output = baseline(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map)
 
     # Pad the weight if moe padding is enabled
     if padding:
@@ -77,7 +80,6 @@ def run_moe_test(
         global_num_experts=global_num_experts,
         expert_map=expert_map)
 
-
     if use_cudagraph:
         test_output.fill_(0)
         stream = torch.cuda.Stream()
@@ -96,8 +98,9 @@ def run_moe_test(
 
     torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
 
+    return baseline_output
+
 
-# TODO: reduce combinations
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 511, 1024])
@@ -192,13 +195,13 @@ def m_fused_moe(
         padding=padding,
     )
 
-    use_compile = m >= chunk_size and current_platform.is_cuda_alike()
+    use_compile = m >= chunk_size and n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
     use_cudagraph = use_compile
 
     with set_current_vllm_config(vllm_config):
-        runner(torch_moe, iterative_moe)
-        runner(torch_moe, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph)
-        runner(torch_moe, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph)
+        baseline_output = runner(torch_moe, iterative_moe)
+        runner(baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph)
+        runner(baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph)
 
 
 @pytest.mark.parametrize("m", [1, 32, 222])