From 3d288bfd51be822fda5675ea86155f564a204e1e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 16:58:17 +0000 Subject: [PATCH 01/18] turn try_get_optimal_moe_config into an op so it can be torch.compiled Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 88 +++++++++++++++++-- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 437e80696ac6..e43577318223 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -810,15 +810,15 @@ def get_default_config( return config -def try_get_optimal_moe_config( - w1_shape: tuple[int, ...], - w2_shape: tuple[int, ...], +def try_get_optimal_moe_config_list( + w1_shape: list[int], + w2_shape: list[int], top_k: int, dtype: Optional[str], M: int, is_marlin: bool = False, block_shape: Optional[list[int]] = None, -): +) -> tuple[int, int, int, int, int, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: @@ -840,7 +840,59 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - return config + + return [config['BLOCK_SIZE_M'], + config['BLOCK_SIZE_N'], + config['BLOCK_SIZE_K'], + config['GROUP_SIZE_SIZE_M'], + config['num_warps'] if 'num_warps' in config else 0, + config['num_stages'] if 'num_stages' in config else 0] + + +def try_get_optimal_moe_config_list_fake( + w1_shape: list[int], + w2_shape: list[int], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[list[int]] = None, +) -> tuple[int, int, int, int]: + return [64, 64, 64, 8, 4, 3] + + +direct_register_custom_op( + op_name="try_get_optimal_moe_config_list", + op_func=try_get_optimal_moe_config_list, + fake_impl=try_get_optimal_moe_config_list_fake, + mutates_args=[], +) + + +def try_get_optimal_moe_config( + w1_shape: list[int], + w2_shape: list[int], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[list[int]] = None, +) -> dict[str, int]: + block_m, block_n, block_k, group_m, num_warps, num_stages = torch.ops.vllm.try_get_optimal_moe_config_list( + w1_shape, + w2_shape, + top_k, + dtype, + M, + is_marlin, + block_shape, + ) + return dict(BLOCK_SIZE_M=block_m, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + GROUP_SIZE_M=group_m, + num_warps=num_warps, + num_stages=num_stages) def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, @@ -1182,6 +1234,32 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + elif True: + fn = modular_triton_fused_moe( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) + + return fn( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, From 385e0c5e58877f5bfe20d9dbc334cf64c5544df4 
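PATCH 01's approach is to hide the data-dependent tuning-config lookup behind a registered custom op, so torch.compile records one opaque call (resolved through a fake implementation during tracing) instead of tracing the Python dict/JSON lookup itself. A minimal sketch of that registration pattern, mirroring the patch's use of vllm.utils.direct_register_custom_op; the lookup_moe_block_sizes name, its arguments, and its constant values are illustrative assumptions, not code from the patch.

from typing import Optional

from vllm.utils import direct_register_custom_op


def lookup_moe_block_sizes(M: int, E: int, N: int,
                           dtype: Optional[str]) -> tuple[int, int, int]:
    # Host-side lookup; the real op consults tuned JSON configs or falls
    # back to get_default_config(). Values here are placeholders.
    return (64, 64, 32) if M <= 256 else (128, 128, 64)


def lookup_moe_block_sizes_fake(M: int, E: int, N: int,
                                dtype: Optional[str]) -> tuple[int, int, int]:
    # Stand-in used while torch.compile traces the graph.
    return (64, 64, 32)


direct_register_custom_op(
    op_name="lookup_moe_block_sizes",
    op_func=lookup_moe_block_sizes,
    fake_impl=lookup_moe_block_sizes_fake,
    mutates_args=[],
)

# Callers go through torch.ops.vllm.lookup_moe_block_sizes(...), keeping the
# lookup outside the traced Python, which is what the patch does for
# try_get_optimal_moe_config.
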
Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 17:03:46 +0000 Subject: [PATCH 02/18] lint Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e43577318223..6427b7152312 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -841,12 +841,14 @@ def try_get_optimal_moe_config_list( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - return [config['BLOCK_SIZE_M'], - config['BLOCK_SIZE_N'], - config['BLOCK_SIZE_K'], - config['GROUP_SIZE_SIZE_M'], - config['num_warps'] if 'num_warps' in config else 0, - config['num_stages'] if 'num_stages' in config else 0] + return [ + config['BLOCK_SIZE_M'], + config['BLOCK_SIZE_N'], + config['BLOCK_SIZE_K'], + config['GROUP_SIZE_SIZE_M'], + config.get('num_warps', 0), + config.get('num_stages', 0), + ] def try_get_optimal_moe_config_list_fake( @@ -878,15 +880,16 @@ def try_get_optimal_moe_config( is_marlin: bool = False, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: - block_m, block_n, block_k, group_m, num_warps, num_stages = torch.ops.vllm.try_get_optimal_moe_config_list( - w1_shape, - w2_shape, - top_k, - dtype, - M, - is_marlin, - block_shape, - ) + block_m, block_n, block_k, group_m, num_warps, num_stages = ( + torch.ops.vllm.try_get_optimal_moe_config_list( + w1_shape, + w2_shape, + top_k, + dtype, + M, + is_marlin, + block_shape, + )) return dict(BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, @@ -1235,13 +1238,12 @@ def fused_experts(hidden_states: torch.Tensor, apply_router_weight_on_input=apply_router_weight_on_input, ) elif True: - fn = modular_triton_fused_moe( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape) + fn = modular_triton_fused_moe(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) return fn( hidden_states=hidden_states, From c98ffbea568de863e104c929ef79e97767557236 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 17:37:47 +0000 Subject: [PATCH 03/18] torch.compile tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 8 ++++++++ .../layers/fused_moe/fused_moe.py | 20 ++++--------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index bed374cf4d56..b80a1fcbbdbe 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -57,7 +57,11 @@ def test_fused_moe( ep_size: int, dtype: torch.dtype, padding: bool, + monkeypatch, ): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -84,6 +88,10 @@ def test_fused_moe( per_channel_quant=False, block_shape=None) + m_fused_moe = torch.compile(m_fused_moe, + backend='inductor', + fullgraph=True) + with set_current_vllm_config(vllm_config): torch_output = torch_moe(a, w1, w2, score, topk, e_map) iterative_output 
= iterative_moe(a, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6427b7152312..645724f69ab9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -845,29 +845,17 @@ def try_get_optimal_moe_config_list( config['BLOCK_SIZE_M'], config['BLOCK_SIZE_N'], config['BLOCK_SIZE_K'], - config['GROUP_SIZE_SIZE_M'], - config.get('num_warps', 0), - config.get('num_stages', 0), + config['GROUP_SIZE_M'], + config.get('num_warps', 4), + config.get('num_stages', 3 if not current_platform.is_rocm() else 2), ] -def try_get_optimal_moe_config_list_fake( - w1_shape: list[int], - w2_shape: list[int], - top_k: int, - dtype: Optional[str], - M: int, - is_marlin: bool = False, - block_shape: Optional[list[int]] = None, -) -> tuple[int, int, int, int]: - return [64, 64, 64, 8, 4, 3] - - direct_register_custom_op( op_name="try_get_optimal_moe_config_list", op_func=try_get_optimal_moe_config_list, - fake_impl=try_get_optimal_moe_config_list_fake, mutates_args=[], + dispatch_key="", ) From c1c362a22b3d6c1441186dcf03e06504bf2d494e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 18:31:52 +0000 Subject: [PATCH 04/18] add tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 52 +++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index b80a1fcbbdbe..5a2edfde737b 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from typing import Callable import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe @@ -40,7 +41,54 @@ vllm_config.scheduler_config.max_model_len = 8192 -@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) +def run_moe_test( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + padding: bool, + baseline_moe_fn: Callable, + moe_fn: Callable, + use_compile: bool = False, + use_cudagraph: bool = False, +): + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device="cuda", + dtype=torch.int32) + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + + with set_current_vllm_config(vllm_config): + baseline_output = baseline_moe_fn(a, w1, w2, score, topk, e_map) + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + torch.testing.assert_close(test_output, baseline_output, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -60,7 +108,7 @@ def test_fused_moe( monkeypatch, ): current_platform.seed_everything(7) - monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + 
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "1024") a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 From 776ad95a313a1b33710a951de9ae27eaf6949153 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 21:13:03 +0000 Subject: [PATCH 05/18] add compiler + cudagraph tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 34 +++-- tests/kernels/moe/test_moe.py | 182 +++++++++++++++++--------- tests/kernels/moe/test_nvfp4_moe.py | 2 +- tests/kernels/utils.py | 2 +- 4 files changed, 145 insertions(+), 75 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ce420901e317..6be6bf4d773c 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -17,19 +17,21 @@ TOP_KS = [6, 8] MNK_FACTORS = [ - (2, 1024, 1024), - (2, 1024, 1536), - (2, 3072, 1024), - (2, 3072, 1536), - (64, 1024, 1024), - (64, 1024, 1536), - (64, 3072, 1024), - (64, 3072, 1536), - (224, 1024, 1024), - (224, 1024, 1536), - (224, 3072, 1024), - (224, 3072, 1536), - (1024 * 128, 1024, 1024), + # (2, 1024, 1024), + # (2, 1024, 1536), + # (2, 3072, 1024), + # (2, 3072, 1536), + # (64, 1024, 1024), + # (64, 1024, 1536), + # (64, 3072, 1024), + # (64, 3072, 1536), + # (224, 1024, 1024), + # (224, 1024, 1536), + # (224, 3072, 1024), + # (224, 3072, 1536), + (7232, 2048, 5120), + (32768, 1024, 1024), + (40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( @@ -232,8 +234,10 @@ def test_cutlass_moe_8_bit_no_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -274,8 +278,10 @@ def test_cutlass_moe_8_bit_cuda_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): dtype = torch.half @@ -329,8 +335,10 @@ def test_cutlass_moe_8_bit_EP( per_act_token: bool, per_out_channel: bool, ep_size: int, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 5a2edfde737b..a1c78b7567ae 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,13 +4,15 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +import functools import pytest import torch + from torch.nn import Parameter from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from typing import Callable +from typing import Callable, Optional import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe @@ -41,53 +43,6 @@ vllm_config.scheduler_config.max_model_len = 8192 -def run_moe_test( - m: int, - n: int, - k: int, - e: int, - topk: int, - ep_size: int, - dtype: torch.dtype, - padding: bool, - baseline_moe_fn: Callable, - moe_fn: Callable, - use_compile: bool = False, - use_cudagraph: bool = False, -): - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - if ep_size > 1: - local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) - e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) - w1 = w1[e_ids] - w2 = w2[e_ids] - else: - e_map = None - - with set_current_vllm_config(vllm_config): - baseline_output = baseline_moe_fn(a, w1, w2, score, topk, e_map) - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - torch.testing.assert_close(test_output, baseline_output, atol=2e-2, rtol=2e-2) - - @pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -108,8 +63,11 @@ def test_fused_moe( monkeypatch, ): current_platform.seed_everything(7) - monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "1024") + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + # + # Setup test data + # a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -129,19 +87,67 @@ def test_fused_moe( else: e_map = None - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None) + # + # Setup test functions + # + m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + + def m_fused_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor]= None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + return m_fused_moe_fn(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map) + + fused_moe_fn = functools.partial(fused_moe, renormalize=False) + + # + # Run tests + # + runner = functools.partial( + run_moe_test, + a=a, + w1=w1, + w2=w2, + score=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + padding=padding, + ) + + use_compile=True # TODO: platform supports torch.compile + use_cudagraph=True # TODO: platform supports cudagraphs + + with set_current_vllm_config(vllm_config): + runner(torch_moe, iterative_moe) + runner(torch_moe, 
fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph) + runner(torch_moe, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph) + + return m_fused_moe = torch.compile(m_fused_moe, backend='inductor', fullgraph=True) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w1, w2, score, topk, e_map) + torch_output = torch_moe(a, w1, w2, score, topk, expert_map=e_map) iterative_output = iterative_moe(a, w1, w2, @@ -171,8 +177,8 @@ def test_fused_moe( m_triton_output = m_fused_moe(a, w1, w2, - topk_weights, - topk_ids, + score, + topk, global_num_experts=e, expert_map=e_map) @@ -187,6 +193,62 @@ def test_fused_moe( rtol=0) +def run_moe_test( + baseline_moe_fn: Callable, + moe_fn: Callable, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + padding: bool = False, + use_compile: bool = False, + use_cudagraph: bool = False, + atol:float=2e-2, + rtol:float=0, +): + baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + if use_compile: + moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + + if use_cudagraph: + test_output.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) + + @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @@ -294,7 +356,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -602,7 +664,7 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 22482d9ca85a..76b560e1bb41 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, device=w2.device, block_size=quant_blocksize) - torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch.testing.assert_close(torch_output, cutlass_output, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 
d1db6a8eb1ba..b686bbd78cb7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1054,7 +1054,7 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk, expert_map): +def torch_moe(a, w1, w2, score, topk, global_num_experts=-1, expert_map=None): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) From 961b5e8a1781b8a299ebdc30ddd417260dcf743a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 22:21:26 +0000 Subject: [PATCH 06/18] tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 31 ++-- tests/kernels/moe/test_moe.py | 174 +++++++----------- tests/kernels/moe/test_pplx_cutlass_moe.py | 22 +-- tests/kernels/moe/test_pplx_moe.py | 29 +-- tests/kernels/quantization/test_block_fp8.py | 38 +++- tests/kernels/utils.py | 30 ++- .../layers/fused_moe/deep_gemm_moe.py | 12 +- .../layers/quantization/deepgemm.py | 22 +++ 8 files changed, 173 insertions(+), 185 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 6be6bf4d773c..0d6654aad6d4 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -17,21 +17,22 @@ TOP_KS = [6, 8] MNK_FACTORS = [ - # (2, 1024, 1024), - # (2, 1024, 1536), - # (2, 3072, 1024), - # (2, 3072, 1536), - # (64, 1024, 1024), - # (64, 1024, 1536), - # (64, 3072, 1024), - # (64, 3072, 1536), - # (224, 1024, 1024), - # (224, 1024, 1536), - # (224, 3072, 1024), - # (224, 3072, 1536), - (7232, 2048, 5120), - (32768, 1024, 1024), - (40000, 2048, 5120), + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 3072, 1536), + (224, 1024, 1024), + (224, 1024, 1536), + (224, 3072, 1024), + (224, 3072, 1536), + # These sizes trigger wrong answers. 
+ # (7232, 2048, 5120), + # (32768, 1024, 1024), + # (40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index a1c78b7567ae..1bb282c85fb8 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -43,6 +43,61 @@ vllm_config.scheduler_config.max_model_len = 8192 +def run_moe_test( + baseline_moe_fn: Callable, + moe_fn: Callable, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + padding: bool = False, + use_compile: bool = False, + use_cudagraph: bool = False, + atol:float=2e-2, + rtol:float=0, +): + baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + + if use_compile: + moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + + if use_cudagraph: + test_output.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) + + +# TODO: reduce combinations @pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -51,6 +106,7 @@ @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) +@pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( m: int, n: int, @@ -60,14 +116,17 @@ def test_fused_moe( ep_size: int, dtype: torch.dtype, padding: bool, + chunk_size: int, monkeypatch, ): current_platform.seed_everything(7) - monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) # # Setup test data # + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -90,6 +149,7 @@ def test_fused_moe( # # Setup test functions # + m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, use_int8_w8a8=False, use_int8_w8a16=False, @@ -132,122 +192,14 @@ def m_fused_moe( padding=padding, ) - use_compile=True # TODO: platform supports torch.compile - use_cudagraph=True # TODO: platform supports cudagraphs + use_compile = m >= chunk_size and current_platform.is_cuda_alike() + use_cudagraph = use_compile with set_current_vllm_config(vllm_config): runner(torch_moe, iterative_moe) runner(torch_moe, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph) runner(torch_moe, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph) - return - - m_fused_moe = torch.compile(m_fused_moe, - backend='inductor', - fullgraph=True) - - with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w1, w2, score, topk, expert_map=e_map) - 
iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_triton_output = m_fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map) - - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(m_triton_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(iterative_output, - torch_output, - atol=2e-2, - rtol=0) - - -def run_moe_test( - baseline_moe_fn: Callable, - moe_fn: Callable, - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - padding: bool = False, - use_compile: bool = False, - use_cudagraph: bool = False, - atol:float=2e-2, - rtol:float=0, -): - baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - if use_compile: - moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) - - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) - - - if use_cudagraph: - test_output.fill_(0) - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) - @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index d90202dfcb3b..fef53b73ba0a 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -15,6 +15,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from tests.kernels.utils import torch_experts + from .deepep_utils import ProcessGroupInfo, parallel_launch try: @@ -164,22 +166,6 @@ def pplx_cutlass_moe( vllm_config.scheduler_config.max_model_len = 8192 -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -210,8 +196,8 @@ def _pplx_moe( group_name = cpu_group.group_name with 
set_current_vllm_config(vllm_config): - torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, - topk_ids) + torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights, + topk_ids) pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 2d6a8f39cec5..61b2144098e1 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -29,6 +29,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from tests.kernels.utils import torch_experts + from .deepep_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( @@ -163,29 +165,6 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) -# Note: same as torch_moe but with fused_topk factored out. -def torch_moe2( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -209,7 +188,7 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) @@ -576,7 +555,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792d..a449c12f240b 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -403,19 +403,23 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] - dtype = torch.bfloat16 - +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") if not _valid_deep_gemm_shape(M, N, K): pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") + chunk_size = 1024 + 
torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -451,6 +455,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + use_compile = M > chunk_size and current_platform.is_cuda_alike() + use_cudagraph = use_compile + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -463,7 +470,24 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + if use_compile: + deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, + backend="inductor", + fullgraph=True) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + + if use_cudagraph: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b686bbd78cb7..ca106e79a685 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1054,12 +1054,22 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk, global_num_experts=-1, expert_map=None): +def torch_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None +) -> torch.Tensor: + assert (global_num_experts == -1 or + (global_num_experts == w1.shape[0] and expert_map is None) or + global_num_experts == expert_map.shape[0]) + topk = topk_ids.shape[1] B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) if expert_map is not None: @@ -1073,6 +1083,20 @@ def torch_moe(a, w1, w2, score, topk, global_num_experts=-1, expert_map=None): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None +) -> torch.Tensor: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map) + + def torch_moe_single(a, w, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b4473b907381..9c55ef4b7dc6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,6 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +import vllm.model_executor.layers.quantization.deepgemm + from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) @@ -108,8 +110,6 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ): - import deep_gemm as dg - a1q = hidden_states _, N, K = w1.size() @@ -144,8 +144,8 @@ def apply( (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) + torch.ops.vllm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( + a1q, a1q_scale, w1, w1_scale, mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -155,8 +155,8 @@ def apply( column_major_scales=True, out_q=quant_out) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) + torch.ops.vllm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( + a2q, a2q_scale, w2, w2_scale, mm2_out, expert_ids) torch.index_select(mm2_out, 0, inv_perm, out=output) diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 1d40f4915a1b..0bbcc7f3d913 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -4,6 +4,8 @@ import torch +from typing import Optional + from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils import direct_register_custom_op @@ -75,6 +77,18 @@ def w8a8_block_fp8_matmul_deepgemm_fake( return C +def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( + a: torch.Tensor, + a_scale: Optional[torch.Tensor], + b: torch.Tensor, + b_scale: Optional[torch.Tensor], + output: torch.Tensor, + expert_ids: torch.Tensor, +) -> None: + import deep_gemm as dg + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), output, expert_ids) + + direct_register_custom_op( op_name="w8a8_block_fp8_matmul_deepgemm", op_func=w8a8_block_fp8_matmul_deepgemm, @@ -82,3 +96,11 @@ def w8a8_block_fp8_matmul_deepgemm_fake( fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, dispatch_key=current_platform.dispatch_key, ) + + +direct_register_custom_op( + op_name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm", + op_func=m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm, + mutates_args=["output"], + dispatch_key=current_platform.dispatch_key, +) From bd9bd379e611c5aa8a096bdae06f047a667557a5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 22:31:02 +0000 Subject: [PATCH 07/18] reduce number of compile/cudagraph tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 23 +++++++++++--------- tests/kernels/quantization/test_block_fp8.py | 2 +- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 1bb282c85fb8..f32e42c19b4a 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -12,7 +12,7 @@ from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from typing import Callable, Optional +from typing import Callable, Optional, Union import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, 
torch_moe @@ -44,7 +44,7 @@ def run_moe_test( - baseline_moe_fn: Callable, + baseline: Union[Callable, torch.Tensor], moe_fn: Callable, a: torch.Tensor, w1: torch.Tensor, @@ -58,8 +58,11 @@ def run_moe_test( use_cudagraph: bool = False, atol:float=2e-2, rtol:float=0, -): - baseline_output = baseline_moe_fn(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) +) -> torch.Tensor: + if isinstance(baseline, torch.Tensor): + baseline_output = baseline + else: + baseline_output = baseline(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) # Pad the weight if moe padding is enabled if padding: @@ -77,7 +80,6 @@ def run_moe_test( global_num_experts=global_num_experts, expert_map=expert_map) - if use_cudagraph: test_output.fill_(0) stream = torch.cuda.Stream() @@ -96,8 +98,9 @@ def run_moe_test( torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) + return baseline_output + -# TODO: reduce combinations @pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -192,13 +195,13 @@ def m_fused_moe( padding=padding, ) - use_compile = m >= chunk_size and current_platform.is_cuda_alike() + use_compile = m >= chunk_size and n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() use_cudagraph = use_compile with set_current_vllm_config(vllm_config): - runner(torch_moe, iterative_moe) - runner(torch_moe, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph) - runner(torch_moe, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph) + baseline_output = runner(torch_moe, iterative_moe) + runner(baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph) + runner(baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph) @pytest.mark.parametrize("m", [1, 32, 222]) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index a449c12f240b..c8cffb6700b5 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -455,7 +455,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - use_compile = M > chunk_size and current_platform.is_cuda_alike() + use_compile = M > chunk_size and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike() use_cudagraph = use_compile # Set the context to avoid lots of warning spam. From 23f26c992301a9506afdda83c96c9694f03de0b8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 22:39:36 +0000 Subject: [PATCH 08/18] lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 50 ++++++++++++++----- tests/kernels/moe/test_pplx_cutlass_moe.py | 4 +- tests/kernels/moe/test_pplx_moe.py | 4 +- tests/kernels/quantization/test_block_fp8.py | 12 +++-- tests/kernels/utils.py | 41 +++++++-------- .../layers/fused_moe/deep_gemm_moe.py | 2 - .../layers/quantization/deepgemm.py | 7 ++- 7 files changed, 70 insertions(+), 50 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index f32e42c19b4a..df825ea65e9a 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -5,14 +5,14 @@ Run `pytest tests/kernels/test_moe.py`. 
""" import functools +from typing import Callable, Optional, Union + import pytest import torch - from torch.nn import Parameter from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from typing import Callable, Optional, Union import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe @@ -56,13 +56,19 @@ def run_moe_test( padding: bool = False, use_compile: bool = False, use_cudagraph: bool = False, - atol:float=2e-2, - rtol:float=0, + atol: float = 2e-2, + rtol: float = 0, ) -> torch.Tensor: if isinstance(baseline, torch.Tensor): baseline_output = baseline else: - baseline_output = baseline(a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map) + baseline_output = baseline(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) # Pad the weight if moe padding is enabled if padding: @@ -96,7 +102,10 @@ def run_moe_test( graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) + torch.testing.assert_close(test_output, + baseline_output, + atol=atol, + rtol=rtol) return baseline_output @@ -167,7 +176,7 @@ def m_fused_moe( score: torch.Tensor, topk: int, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor]= None, + expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) return m_fused_moe_fn(a, @@ -195,13 +204,20 @@ def m_fused_moe( padding=padding, ) - use_compile = m >= chunk_size and n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() + use_compile = (m >= chunk_size and n >= 1024 and k >= 1024 + and current_platform.is_cuda_alike()) use_cudagraph = use_compile with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) - runner(baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph) - runner(baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph) + runner(baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph) + runner(baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph) @pytest.mark.parametrize("m", [1, 32, 222]) @@ -311,7 +327,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) + torch_output = torch_moe(a, + w1_ref, + w2_ref, + score, + topk, + expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -619,7 +640,12 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) + torch_output = torch_moe(a, + w_ref1, + w_ref2, + score, + topk, + expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index fef53b73ba0a..0caf14f040bb 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -6,17 +6,15 @@ import pytest import torch +from tests.kernels.utils import torch_experts from vllm import 
_custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform -from tests.kernels.utils import torch_experts - from .deepep_utils import ProcessGroupInfo, parallel_launch try: diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 61b2144098e1..945fd391fba5 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,8 +18,8 @@ except ImportError: has_pplx = False +from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) @@ -29,8 +29,6 @@ FusedMoEModularKernel) from vllm.platforms import current_platform -from tests.kernels.utils import torch_experts - from .deepep_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c8cffb6700b5..c2fe75f510d6 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -403,7 +403,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, + monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") @@ -455,7 +456,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - use_compile = M > chunk_size and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike() + use_compile = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) use_cudagraph = use_compile # Set the context to avoid lots of warning spam. 
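The cudagraph variants added by this series (in test_block_fp8.py below and in run_moe_test above) all rely on the same capture-and-replay idiom: record one invocation under torch.cuda.graph, then graph.replay() re-runs the recorded kernels into the tensors allocated during capture, which is why each test zeroes its output before replaying. A minimal self-contained sketch of that idiom; the run_under_cudagraph helper is an illustrative name, not something the patch defines.

import torch


def run_under_cudagraph(fn, *args):
    # Record one call to fn; tensors allocated during capture are reused on
    # every replay, so `out` is refreshed in place by graph.replay().
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        out = fn(*args)
    torch.cuda.synchronize()
    graph.replay()
    torch.cuda.synchronize()
    return out
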
@@ -477,14 +479,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) if use_cudagraph: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ca106e79a685..4cdaee6da878 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1054,18 +1054,16 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_experts( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None -) -> torch.Tensor: - assert (global_num_experts == -1 or - (global_num_experts == w1.shape[0] and expert_map is None) or - global_num_experts == expert_map.shape[0]) +def torch_experts(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + assert (global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or global_num_experts == expert_map.shape[0]) topk = topk_ids.shape[1] B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -1083,18 +1081,17 @@ def torch_experts( topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -def torch_moe( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None -) -> torch.Tensor: +def torch_moe(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) - return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map) + return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, + expert_map) def torch_moe_single(a, w, score, topk): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 9c55ef4b7dc6..5eb026400ef2 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,8 +7,6 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -import vllm.model_executor.layers.quantization.deepgemm - from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 0bbcc7f3d913..d5079e3ba405 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util import logging +from typing import Optional import torch 
-from typing import Optional - from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils import direct_register_custom_op @@ -86,7 +85,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( expert_ids: torch.Tensor, ) -> None: import deep_gemm as dg - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), output, expert_ids) + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), + output, expert_ids) direct_register_custom_op( @@ -97,7 +97,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( dispatch_key=current_platform.dispatch_key, ) - direct_register_custom_op( op_name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm", op_func=m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm, From 5d564f68fca07be0905f5e018620b4e16584ae89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 16 Jun 2025 23:00:06 +0000 Subject: [PATCH 09/18] fix lint Signed-off-by: Bill Nell --- tests/kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 4cdaee6da878..dcda8e479b29 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1063,7 +1063,8 @@ def torch_experts(a: torch.Tensor, expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: assert (global_num_experts == -1 or (global_num_experts == w1.shape[0] and expert_map is None) - or global_num_experts == expert_map.shape[0]) + or (expert_map is not None + and global_num_experts == expert_map.shape[0])) topk = topk_ids.shape[1] B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) From 06b45832a7500dd1fe0bf6c39284b2598edf4978 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 17 Jun 2025 01:14:44 +0000 Subject: [PATCH 10/18] fix lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 645724f69ab9..ae63f16f9bf8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -841,14 +841,14 @@ def try_get_optimal_moe_config_list( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - return [ + return ( config['BLOCK_SIZE_M'], config['BLOCK_SIZE_N'], config['BLOCK_SIZE_K'], config['GROUP_SIZE_M'], config.get('num_warps', 4), config.get('num_stages', 3 if not current_platform.is_rocm() else 2), - ] + ) direct_register_custom_op( From 463ccaa30dfea26838afcb1d2f90068c46fa7814 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 17 Jun 2025 02:31:08 +0000 Subject: [PATCH 11/18] replace import that lint removed Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 5eb026400ef2..1ffa1c1a8f22 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,6 +6,7 @@ import torch +import vllm.model_executor.layers.quantization.deepgemm import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( From 4ab6af7e287dc067cad0cfdde20932b31e34bd55 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 17 Jun 2025 
02:52:58 +0000 Subject: [PATCH 12/18] fixes Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 77 ++++++++----------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ae63f16f9bf8..2b8704862846 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -842,12 +842,12 @@ def try_get_optimal_moe_config_list( is_marlin, block_shape) return ( - config['BLOCK_SIZE_M'], - config['BLOCK_SIZE_N'], - config['BLOCK_SIZE_K'], - config['GROUP_SIZE_M'], - config.get('num_warps', 4), - config.get('num_stages', 3 if not current_platform.is_rocm() else 2), + config.get('BLOCK_SIZE_M', 0), + config.get('BLOCK_SIZE_N', 0), + config.get('BLOCK_SIZE_K', 0), + config.get('GROUP_SIZE_M', 0), + config.get('num_warps', 0), + config.get('num_stages', 0), ) @@ -868,22 +868,30 @@ def try_get_optimal_moe_config( is_marlin: bool = False, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: - block_m, block_n, block_k, group_m, num_warps, num_stages = ( - torch.ops.vllm.try_get_optimal_moe_config_list( - w1_shape, - w2_shape, - top_k, - dtype, - M, - is_marlin, - block_shape, - )) - return dict(BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - GROUP_SIZE_M=group_m, - num_warps=num_warps, - num_stages=num_stages) + values = torch.ops.vllm.try_get_optimal_moe_config_list( + w1_shape, + w2_shape, + top_k, + dtype, + M, + is_marlin, + block_shape, + ) + + config = dict() + + keys = ["BLOCK_SIZE_M", "BLOCK_SIZE_N", + "BLOCK_SIZE_K", "GROUP_SIZE_M", + "num_warps", "num_stages"] + + assert len(keys) == len(values) + + config = dict() + for k, v in zip(keys, values): + if v != 0: + config[k] = v + + return config def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, @@ -1225,31 +1233,6 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif True: - fn = modular_triton_fused_moe(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape) - - return fn( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, From 695203df924b9f337c8de7976f1d2b5180cdf3d9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 17 Jun 2025 13:42:36 +0000 Subject: [PATCH 13/18] lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 6 ++++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 1ffa1c1a8f22..039a65225675 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,7 +6,6 @@ import torch -import vllm.model_executor.layers.quantization.deepgemm import vllm.model_executor.layers.fused_moe.modular_kernel as mk 
from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( @@ -15,6 +14,9 @@ MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.deepgemm import ( # noqa: E501 + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm as + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm) from vllm.utils import round_up logger = init_logger(__name__) @@ -49,7 +51,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): - logger.debug("DeepGemm disabled: unalinged problem size.") + logger.debug("DeepGemm disabled: unaligned problem size.") return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2b8704862846..f739c1894c07 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -878,11 +878,12 @@ def try_get_optimal_moe_config( block_shape, ) - config = dict() + config: dict[str, int] = dict() - keys = ["BLOCK_SIZE_M", "BLOCK_SIZE_N", - "BLOCK_SIZE_K", "GROUP_SIZE_M", - "num_warps", "num_stages"] + keys = [ + "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M", + "num_warps", "num_stages" + ] assert len(keys) == len(values) From 287a204703a47327c8421717e12f76d0cff9c184 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 18 Jun 2025 19:59:53 +0000 Subject: [PATCH 14/18] opify at a higher level Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 8 +- tests/kernels/moe/test_pplx_moe.py | 5 +- vllm/envs.py | 1 + .../layers/fused_moe/cutlass_moe.py | 70 ++++----- .../layers/fused_moe/deep_gemm_moe.py | 5 +- .../fused_moe/deepep_ll_prepare_finalize.py | 2 +- .../layers/fused_moe/fused_moe.py | 142 ++++++------------ vllm/model_executor/layers/fused_moe/layer.py | 19 +-- .../layers/fused_moe/modular_kernel.py | 5 + .../layers/fused_moe/pplx_prepare_finalize.py | 2 +- 10 files changed, 107 insertions(+), 152 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index df825ea65e9a..6b5b99e84108 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -77,6 +77,8 @@ def run_moe_test( if use_compile: moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(score, 0) test_output = moe_fn(a, w1, @@ -204,9 +206,9 @@ def m_fused_moe( padding=padding, ) - use_compile = (m >= chunk_size and n >= 1024 and k >= 1024 - and current_platform.is_cuda_alike()) - use_cudagraph = use_compile + use_compile = False + use_cudagraph = (n >= 1024 and k >= 1024 + and current_platform.is_cuda_alike()) with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 945fd391fba5..929067266ece 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -386,7 +386,7 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_compile: bool = True, + use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( @@ -451,6 +451,9 @@ def pplx_moe( 
_fused_experts = torch.compile(fused_experts, backend='inductor', fullgraph=True) + torch._dynamo.mark_dynamic(a_chunk, 0) + torch._dynamo.mark_dynamic(chunk_topk_weight, 0) + torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts diff --git a/vllm/envs.py b/vllm/envs.py index 921052821ee3..41a1cfb5e872 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -942,6 +942,7 @@ def factorize(name: str): "VLLM_DP_RANK", "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", + "VLLM_FUSED_MOE_CHUNK_SIZE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 3f9ceac8b6e3..73d169a84808 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -41,24 +41,24 @@ def run_cutlass_moe_fp8( assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn if expert_num_tokens is None: - assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1" + assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1" else: - assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1" - assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[1], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[1], "W2 scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim( - ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ - 0], "Input scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim( - ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[ - 0], "Intermediate scale shape mismatch" + assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1" + assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.size( + 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.size( + 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert w1.size(0) == w2.size(0), "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( + 0) == 1 or a1q_scale.size( + 0) == a1q.shape[0], "Input scale shape mismatch" + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" + assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( + 0) == 1 or a2_scale.size( + 0) == a1q.shape[0], "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -75,12 +75,12 @@ def run_cutlass_moe_fp8( # their tokens are already contiguous for each expert as a result of # the dispatch function. 
- M = a1q.shape[0] # non batched expert M - padded_M = a1q.shape[1] # batched expert M + M = a1q.size(0) # non batched expert M + padded_M = a1q.size(1) # batched expert M _, K, N = w2.shape device = a1q.device - assert w1.shape[2] == K + assert w1.size(2) == K assert global_num_experts != -1 assert a1q_scale is not None @@ -91,8 +91,8 @@ def run_cutlass_moe_fp8( else: local_topk_ids = topk_ids - topk = local_topk_ids.shape[1] - local_E = w1.shape[0] + topk = local_topk_ids.size(1) + local_E = w1.size(0) if use_batched_format: assert expert_num_tokens is not None @@ -111,10 +111,10 @@ def run_cutlass_moe_fp8( problem_sizes2, expert_num_tokens, local_E, padded_M, N, K) - w1_scale = w1_scale.reshape(w1_scale.shape[0], -1) - w2_scale = w2_scale.reshape(w2_scale.shape[0], -1) - a1q = a1q.reshape(-1, a1q.shape[2]) - a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous() + w1_scale = w1_scale.reshape(w1_scale.size(0), -1) + w2_scale = w2_scale.reshape(w2_scale.size(0), -1) + a1q = a1q.reshape(-1, a1q.size(2)) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() else: expert_offsets = torch.empty((global_num_experts + 1), @@ -151,19 +151,19 @@ def run_cutlass_moe_fp8( a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.shape[0], ), + ab_strides1 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) - c_strides1 = torch.full((w1.shape[0], ), + c_strides1 = torch.full((w1.size(0), ), 2 * N, device=device, dtype=torch.int64) - ab_strides2 = torch.full((w1.shape[0], ), + ab_strides2 = torch.full((w1.size(0), ), N, device=device, dtype=torch.int64) - c_strides2 = torch.full((w1.shape[0], ), + c_strides2 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) @@ -237,7 +237,7 @@ def workspace_shapes( workspace2: tuple[int, ...] = () output: tuple[int, ...] = () if self.use_batched_format: - padded_M = aq.shape[1] + padded_M = aq.size(1) workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output = (self.max_experts_per_worker, padded_M, K) @@ -332,7 +332,7 @@ def cutlass_moe_fp8( """ per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.shape[0] + per_out_ch = w1_scale.numel() != w1_q.size(0) out_dtype = a.dtype @@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.shape[0] == m and topk_ids.shape[0] + assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") out_dtype = a.dtype - num_topk = topk_ids.shape[1] + num_topk = topk_ids.size(1) expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) @@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out_dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
- intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 039a65225675..3f0f9c36e05f 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -14,10 +14,11 @@ MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.deepgemm import ( # noqa: E501 +from vllm.utils import round_up + +from vllm.model_executor.layers.quantization.deepgemm import ( # isort:skip m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm as m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm) -from vllm.utils import round_up logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 3484a7a8a496..5a8accd80463 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f739c1894c07..f22884b8a1a5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert A_scale is None assert B_scale is None - M = A.shape[0] + M = A.size(0) num_tokens = M * top_k - EM = sorted_token_ids.shape[0] - if A.shape[0] < config["BLOCK_SIZE_M"]: + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. # We assume that top_ids of each token is unique, so # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
- EM = min(sorted_token_ids.shape[0], - A.shape[0] * top_k * config['BLOCK_SIZE_M']) + EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.shape[1], META['BLOCK_SIZE_N']), ) + B.size(1), META['BLOCK_SIZE_N']), ) if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: @@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], - num_experts=B.shape[0], + num_experts=B.size(0), bit=4 if use_int4_w4a16 else 8) config = config.copy() config.update( get_moe_wna16_block_config(config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, num_valid_tokens=num_tokens, - size_k=A.shape[1], - size_n=B.shape[1], - num_experts=B.shape[1], + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), group_size=block_shape[1], real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"])) @@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - A.shape[1], + B.size(1), + A.size(1), EM, num_tokens, A.stride(0), @@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0, - block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, @@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - B.shape[2], + B.size(1), + B.size(2), EM, num_tokens, A.stride(0), @@ -810,15 +810,15 @@ def get_default_config( return config -def try_get_optimal_moe_config_list( - w1_shape: list[int], - w2_shape: list[int], +def try_get_optimal_moe_config( + w1_shape: tuple[int, ...], + w2_shape: tuple[int, ...], top_k: int, dtype: Optional[str], M: int, is_marlin: bool = False, block_shape: Optional[list[int]] = None, -) -> tuple[int, int, int, int, int, int]: +) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: @@ -840,58 +840,6 @@ def try_get_optimal_moe_config_list( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - - return ( - config.get('BLOCK_SIZE_M', 0), - config.get('BLOCK_SIZE_N', 0), - config.get('BLOCK_SIZE_K', 0), - config.get('GROUP_SIZE_M', 0), - config.get('num_warps', 0), - config.get('num_stages', 0), - ) - - -direct_register_custom_op( - op_name="try_get_optimal_moe_config_list", - op_func=try_get_optimal_moe_config_list, - mutates_args=[], - dispatch_key="", -) - - -def try_get_optimal_moe_config( - w1_shape: list[int], - w2_shape: list[int], - top_k: int, - dtype: Optional[str], - M: int, - is_marlin: bool = False, - block_shape: Optional[list[int]] = None, -) -> dict[str, int]: - values = torch.ops.vllm.try_get_optimal_moe_config_list( - w1_shape, - w2_shape, - top_k, - dtype, - M, - is_marlin, - block_shape, - ) - - config: dict[str, int] = dict() - - keys = [ - "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M", - "num_warps", "num_stages" - ] - - assert len(keys) == len(values) - - config = dict() - for k, v in zip(keys, values): - if v != 0: - config[k] = v - return config @@ -925,10 +873,10 @@ 
def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") - M, _ = hidden_states.shape + M, _ = hidden_states.size() topk_weights = torch.empty(M, topk, @@ -967,7 +915,7 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") if scoring_func == "softmax": @@ -977,7 +925,7 @@ def grouped_topk( else: raise ValueError(f"Unsupported scoring function: {scoring_func}") - num_token = scores.shape[0] + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. We use biased # scores for expert selection but original scores for routing weights @@ -994,7 +942,7 @@ def grouped_topk( group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] @@ -1214,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor, allow_deep_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. - N = w1.shape[1] + N = w1.size(1) if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False @@ -1285,13 +1233,13 @@ def fused_experts_impl( ) -> torch.Tensor: # Check constraints. 
if use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ - 2], "Hidden size mismatch" + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") else: - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" @@ -1299,12 +1247,12 @@ def fused_experts_impl( torch.float32, torch.float16, torch.bfloat16 ] - num_tokens = hidden_states.shape[0] - E, N, _ = w1.shape - K = w2.shape[1] + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] + top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -1321,8 +1269,8 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, block_shape=block_shape, @@ -1362,7 +1310,7 @@ def fused_experts_impl( min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape + tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break @@ -1374,7 +1322,7 @@ def fused_experts_impl( # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] + topk_ids.size(1)] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1450,7 +1398,7 @@ def fused_experts_impl( per_channel_quant=per_channel_quant, block_shape=block_shape) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states @@ -1663,8 +1611,8 @@ def apply( dtype=hidden_states.dtype) config = try_get_optimal_moe_config( - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f2175886..bd5eeaa849db 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -853,13 +853,11 @@ def __init__( self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op - self.use_direct_call = self.dp_size == 1 - if not self.use_direct_call: - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError("Duplicate layer name: {}".format(prefix)) - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix # Determine expert maps if self.use_ep: @@ -1353,11 +1351,8 @@ def maybe_all_reduce_tensor_model_parallel( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ed3b6b8a1af4..adac6da2dc8e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -507,3 +507,8 @@ def forward( topk_ids, apply_router_weight_on_input) return output + + def compile(self, *args, **kwargs) -> None: + print(f"ARGS {args}") + print(f"KWARGS {kwargs}") + #super().compile(*args, **kwargs) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5bc01dbf2025..2ff8ef99b2ec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -69,7 +69,7 @@ def prepare( a1 = a1 * rank_topk_weights.to(a1.dtype) repeat_cols = 4 - repeat_rows = 1 if self.per_act_token else a1.shape[0] + repeat_rows = 1 if self.per_act_token else a1.size(0) a1q, a1q_scale = moe_kernel_quantize_input( a1, (None if self.per_act_token else a1_scale), self.quant_dtype, self.per_act_token, self.block_shape) From 1c9fd39606745f6a2efc2c9a9174ae5a3a80665d Mon Sep 17 00:00:00 2001 From: Bill Nell 
Date: Wed, 18 Jun 2025 20:16:07 +0000 Subject: [PATCH 15/18] de-opify deepgemm kernels Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 6 +++--- .../layers/fused_moe/deep_gemm_moe.py | 14 ++++++------- .../layers/quantization/deepgemm.py | 21 ------------------- 3 files changed, 9 insertions(+), 32 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 0d6654aad6d4..158100a09879 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,10 +29,10 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), + (32768, 1024, 1024), # These sizes trigger wrong answers. - # (7232, 2048, 5120), - # (32768, 1024, 1024), - # (40000, 2048, 5120), + #(7232, 2048, 5120), + #(40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 3f0f9c36e05f..c04aaacc9d19 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -16,10 +16,6 @@ _resize_cache, per_token_group_quant_fp8) from vllm.utils import round_up -from vllm.model_executor.layers.quantization.deepgemm import ( # isort:skip - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm as - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm) - logger = init_logger(__name__) has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None @@ -112,6 +108,8 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ): + import deep_gemm as dg + a1q = hidden_states _, N, K = w1.size() @@ -146,8 +144,8 @@ def apply( (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - torch.ops.vllm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( - a1q, a1q_scale, w1, w1_scale, mm1_out, expert_ids) + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -157,8 +155,8 @@ def apply( column_major_scales=True, out_q=quant_out) - torch.ops.vllm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( - a2q, a2q_scale, w2, w2_scale, mm2_out, expert_ids) + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) torch.index_select(mm2_out, 0, inv_perm, out=output) diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index d5079e3ba405..1d40f4915a1b 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util import logging -from typing import Optional import torch @@ -76,19 +75,6 @@ def w8a8_block_fp8_matmul_deepgemm_fake( return C -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( - a: torch.Tensor, - a_scale: Optional[torch.Tensor], - b: torch.Tensor, - b_scale: Optional[torch.Tensor], - output: torch.Tensor, - expert_ids: torch.Tensor, -) -> None: - import deep_gemm as dg - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), - output, expert_ids) - - direct_register_custom_op( op_name="w8a8_block_fp8_matmul_deepgemm", op_func=w8a8_block_fp8_matmul_deepgemm, @@ -96,10 +82,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm( fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, dispatch_key=current_platform.dispatch_key, ) - 
-direct_register_custom_op( - op_name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm", - op_func=m_grouped_gemm_fp8_fp8_bf16_nt_contiguous_deepgemm, - mutates_args=["output"], - dispatch_key=current_platform.dispatch_key, -) From 79a1962799324e46b00315809936b2719b14d8a7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 18 Jun 2025 20:18:59 +0000 Subject: [PATCH 16/18] remove cruft Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 4 ++++ tests/kernels/moe/test_pplx_moe.py | 3 +++ vllm/model_executor/layers/fused_moe/modular_kernel.py | 5 ----- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 6b5b99e84108..12b3bf51f792 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -206,7 +206,11 @@ def m_fused_moe( padding=padding, ) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. use_compile = False + use_cudagraph = (n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 929067266ece..c4ad3af6802d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -447,6 +447,9 @@ def pplx_moe( w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. if use_compile: _fused_experts = torch.compile(fused_experts, backend='inductor', diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index adac6da2dc8e..ed3b6b8a1af4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -507,8 +507,3 @@ def forward( topk_ids, apply_router_weight_on_input) return output - - def compile(self, *args, **kwargs) -> None: - print(f"ARGS {args}") - print(f"KWARGS {kwargs}") - #super().compile(*args, **kwargs) From 8b5492c3959b7e2e47703310929c5967bd5de50e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 18 Jun 2025 21:48:19 +0000 Subject: [PATCH 17/18] fix up deep gemm tests Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c2fe75f510d6..1ca0a80ab9a9 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -456,9 +456,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - use_compile = (chunk_size < M and N >= 1024 and K >= 1024 - and current_platform.is_cuda_alike()) - use_cudagraph = use_compile + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): @@ -476,6 +480,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, backend="inductor", fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(topk_weights, 0) + torch._dynamo.mark_dynamic(topk_ids, 0) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 From de1d0963140daa45d3102389ebbe8db277eecf09 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 24 Jun 2025 21:47:39 +0000 Subject: [PATCH 18/18] fix test_mixtral_moe test Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 85 +++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 12b3bf51f792..0c31168566e2 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -17,6 +17,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) @@ -365,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, pytest.skip("AITER ROCm test skip for float32") # Instantiate our and huggingface's MoE blocks - config = MixtralConfig() - hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") - vllm_moe = MixtralMoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=dtype, - tp_size=1, - dp_size=1, - ).cuda() - - # Load the weights - vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data - for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) - vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) - vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data - - # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") - # vLLM uses 1D query [num_tokens, hidden_dim] - vllm_inputs = hf_inputs.flatten(0, 1) - - # Pad the weight if moe padding is enabled - if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - - # Run forward passes for both MoE blocks - hf_states, _ = hf_moe.forward(hf_inputs) - vllm_states = vllm_moe.forward(vllm_inputs) + vllm_config.compilation_config.static_forward_context = dict() + with (set_current_vllm_config(vllm_config), + set_forward_context(None, vllm_config)): + config = MixtralConfig() + hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") + vllm_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=dtype, + tp_size=1, + dp_size=1, + ).cuda() + + # Load the weights + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data + for i in range(config.num_local_experts): + 
weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) + vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data + + # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] + hf_inputs = torch.randn( + (1, 64, config.hidden_size)).to(dtype).to("cuda") + # vLLM uses 1D query [num_tokens, hidden_dim] + vllm_inputs = hf_inputs.flatten(0, 1) + + # Pad the weight if moe padding is enabled + if padding: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + + # Run forward passes for both MoE blocks + hf_states, _ = hf_moe.forward(hf_inputs) + vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3,