Commit bd9bd37

reduce number of compile/cudagraph tests
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 961b5e8

2 files changed: 14 additions, 11 deletions

tests/kernels/moe/test_moe.py (13 additions, 10 deletions)

@@ -12,7 +12,7 @@
 from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
@@ -44,7 +44,7 @@


 def run_moe_test(
-    baseline_moe_fn: Callable,
+    baseline: Union[Callable, torch.Tensor],
     moe_fn: Callable,
     a: torch.Tensor,
     w1: torch.Tensor,
@@ -58,8 +58,11 @@ def run_moe_test(
     use_cudagraph: bool = False,
     atol: float = 2e-2,
     rtol: float = 0,
-):
-    baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map)
+) -> torch.Tensor:
+    if isinstance(baseline, torch.Tensor):
+        baseline_output = baseline
+    else:
+        baseline_output = baseline(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map)

     # Pad the weight if moe padding is enabled
     if padding:
@@ -77,7 +80,6 @@ def run_moe_test(
         global_num_experts=global_num_experts,
         expert_map=expert_map)

-
     if use_cudagraph:
         test_output.fill_(0)
         stream = torch.cuda.Stream()
@@ -96,8 +98,9 @@ def run_moe_test(

     torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)

+    return baseline_output
+

-# TODO: reduce combinations
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 511, 1024])
@@ -192,13 +195,13 @@ def m_fused_moe(
         padding=padding,
     )

-    use_compile = m >= chunk_size and current_platform.is_cuda_alike()
+    use_compile = m >= chunk_size and n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
     use_cudagraph = use_compile

     with set_current_vllm_config(vllm_config):
-        runner(torch_moe, iterative_moe)
-        runner(torch_moe, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph)
-        runner(torch_moe, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph)
+        baseline_output = runner(torch_moe, iterative_moe)
+        runner(baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph)
+        runner(baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph)


 @pytest.mark.parametrize("m", [1, 32, 222])

tests/kernels/quantization/test_block_fp8.py (1 addition, 1 deletion)

@@ -455,7 +455,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

-    use_compile = M > chunk_size and current_platform.is_cuda_alike()
+    use_compile = M > chunk_size and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
     use_cudagraph = use_compile

     # Set the context to avoid lots of warning spam.
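Both files now gate `use_compile` on the N/K dimensions as well as M, so only the largest shapes pay for `torch.compile` plus CUDA-graph capture. A hedged sketch of that gating and capture pattern (the shape thresholds come from the diff; everything else is illustrative, and the real tests use vLLM's `current_platform.is_cuda_alike()` rather than `torch.cuda.is_available()`):

```python
import torch

M, N, K = 2048, 1024, 1024  # example problem sizes (illustrative)
chunk_size = 1024           # stand-in for the MoE chunking threshold

# Only the largest shapes exercise the expensive compile/capture paths,
# mirroring the condition added in this commit.
use_compile = (M > chunk_size and N >= 1024 and K >= 1024
               and torch.cuda.is_available())
use_cudagraph = use_compile


def moe_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)  # placeholder for the fused-MoE call


if use_compile:
    moe_fn = torch.compile(moe_fn)

if use_cudagraph:
    x = torch.randn(M, K, device="cuda")
    moe_fn(x)  # warm up (triggers compilation) before capture
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.graph(graph, stream=stream):
        test_output = moe_fn(x)  # record the kernel launches
    test_output.fill_(0)         # zero the static output buffer...
    graph.replay()               # ...and confirm replay recomputes it
    torch.cuda.synchronize()
```

Capturing once and replaying keeps the check that the compiled kernel is cudagraph-safe, while the added shape gate trims how many parameter combinations go through it.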
