diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ce420901e317..158100a09879 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,7 +29,10 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), - (1024 * 128, 1024, 1024), + (32768, 1024, 1024), + # These sizes trigger wrong answers. + #(7232, 2048, 5120), + #(40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( @@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): dtype = torch.half @@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP( per_act_token: bool, per_out_channel: bool, ep_size: int, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index bed374cf4d56..0c31168566e2 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,9 @@ Run `pytest tests/kernels/test_moe.py`. """ +import functools +from typing import Callable, Optional, Union + import pytest import torch from torch.nn import Parameter @@ -14,6 +17,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) @@ -40,7 +44,76 @@ vllm_config.scheduler_config.max_model_len = 8192 -@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) +def run_moe_test( + baseline: Union[Callable, torch.Tensor], + moe_fn: Callable, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + padding: bool = False, + use_compile: bool = False, + use_cudagraph: bool = False, + atol: float = 2e-2, + rtol: float = 0, +) -> torch.Tensor: + if isinstance(baseline, torch.Tensor): + baseline_output = baseline + else: + baseline_output = baseline(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + + if use_compile: + moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(score, 0) + + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + if use_cudagraph: + test_output.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + 
with torch.cuda.graph(graph, stream=stream): + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(test_output, + baseline_output, + atol=atol, + rtol=rtol) + + return baseline_output + + +@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -48,6 +121,7 @@ @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) +@pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( m: int, n: int, @@ -57,7 +131,17 @@ def test_fused_moe( ep_size: int, dtype: torch.dtype, padding: bool, + chunk_size: int, + monkeypatch, ): + current_platform.seed_everything(7) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + # + # Setup test data + # + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -77,58 +161,70 @@ def test_fused_moe( else: e_map = None - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None) - - with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() + # + # Setup test functions + # + + m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + + def m_fused_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + return m_fused_moe_fn(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map) + + fused_moe_fn = functools.partial(fused_moe, renormalize=False) + + # + # Run tests + # + runner = functools.partial( + run_moe_test, + a=a, + w1=w1, + w2=w2, + score=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + padding=padding, + ) - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. 
+ use_compile = False - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_triton_output = m_fused_moe(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map) + use_cudagraph = (n >= 1024 and k >= 1024 + and current_platform.is_cuda_alike()) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(m_triton_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(iterative_output, - torch_output, - atol=2e-2, - rtol=0) + with set_current_vllm_config(vllm_config): + baseline_output = runner(torch_moe, iterative_moe) + runner(baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph) + runner(baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph) @pytest.mark.parametrize("m", [1, 32, 222]) @@ -238,7 +334,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch_output = torch_moe(a, + w1_ref, + w2_ref, + score, + topk, + expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -265,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, pytest.skip("AITER ROCm test skip for float32") # Instantiate our and huggingface's MoE blocks - config = MixtralConfig() - hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") - vllm_moe = MixtralMoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=dtype, - tp_size=1, - dp_size=1, - ).cuda() - - # Load the weights - vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data - for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) - vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) - vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data - - # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") - # vLLM uses 1D query [num_tokens, hidden_dim] - vllm_inputs = hf_inputs.flatten(0, 1) + vllm_config.compilation_config.static_forward_context = dict() + with (set_current_vllm_config(vllm_config), + set_forward_context(None, vllm_config)): + config = MixtralConfig() + hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") + vllm_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=dtype, + tp_size=1, + dp_size=1, + ).cuda() + + # Load the weights + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data + for i in range(config.num_local_experts): + weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) + vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data + + # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] + hf_inputs = torch.randn( + (1, 64, config.hidden_size)).to(dtype).to("cuda") + # vLLM uses 1D query [num_tokens, hidden_dim] + vllm_inputs = hf_inputs.flatten(0, 1) - # Pad the weight if moe padding is enabled - if padding: - vllm_moe.experts.w13_weight 
= Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - - # Run forward passes for both MoE blocks - hf_states, _ = hf_moe.forward(hf_inputs) - vllm_states = vllm_moe.forward(vllm_inputs) + # Pad the weight if moe padding is enabled + if padding: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + + # Run forward passes for both MoE blocks + hf_states, _ = hf_moe.forward(hf_inputs) + vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, @@ -546,7 +653,12 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + torch_output = torch_moe(a, + w_ref1, + w_ref2, + score, + topk, + expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 22482d9ca85a..76b560e1bb41 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, device=w2.device, block_size=quant_blocksize) - torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch.testing.assert_close(torch_output, cutlass_output, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index d90202dfcb3b..0caf14f040bb 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -6,9 +6,9 @@ import pytest import torch +from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -164,22 +164,6 @@ def pplx_cutlass_moe( vllm_config.scheduler_config.max_model_len = 8192 -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -210,8 +194,8 @@ def _pplx_moe( group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, - topk_ids) + torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights, + 
topk_ids) pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 2d6a8f39cec5..c4ad3af6802d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,8 +18,8 @@ except ImportError: has_pplx = False +from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) @@ -163,29 +163,6 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) -# Note: same as torch_moe but with fused_topk factored out. -def torch_moe2( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -209,7 +186,7 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) @@ -409,7 +386,7 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_compile: bool = True, + use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( @@ -470,10 +447,16 @@ def pplx_moe( w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. 
if use_compile: _fused_experts = torch.compile(fused_experts, backend='inductor', fullgraph=True) + torch._dynamo.mark_dynamic(a_chunk, 0) + torch._dynamo.mark_dynamic(chunk_topk_weight, 0) + torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts @@ -576,7 +559,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792d..1ca0a80ab9a9 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -403,19 +403,24 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] - dtype = torch.bfloat16 - +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, + monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") if not _valid_deep_gemm_shape(M, N, K): pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") + chunk_size = 1024 + torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -451,6 +456,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) + # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): if M >= 128: @@ -463,7 +476,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + if use_compile: + deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, + backend="inductor", + fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(topk_weights, 0) + torch._dynamo.mark_dynamic(topk_ids, 0) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + + if use_cudagraph: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index d1db6a8eb1ba..dcda8e479b29 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1054,12 +1054,21 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk, expert_map): +def torch_experts(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + assert (global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None + and global_num_experts == expert_map.shape[0])) + topk = topk_ids.shape[1] B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) if expert_map is not None: @@ -1073,6 +1082,19 @@ def torch_moe(a, w1, w2, score, topk, expert_map): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, + expert_map) + + def torch_moe_single(a, w, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) diff --git a/vllm/envs.py b/vllm/envs.py index 921052821ee3..41a1cfb5e872 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -942,6 +942,7 @@ def factorize(name: str): "VLLM_DP_RANK", "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", + "VLLM_FUSED_MOE_CHUNK_SIZE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 3f9ceac8b6e3..73d169a84808 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -41,24 +41,24 @@ def run_cutlass_moe_fp8( assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn if expert_num_tokens is None: - assert a1q.shape[1] == w1.shape[2], 
"Hidden size mismatch w1" + assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1" else: - assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1" - assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[1], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[1], "W2 scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim( - ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ - 0], "Input scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim( - ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[ - 0], "Intermediate scale shape mismatch" + assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1" + assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.size( + 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.size( + 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert w1.size(0) == w2.size(0), "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( + 0) == 1 or a1q_scale.size( + 0) == a1q.shape[0], "Input scale shape mismatch" + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" + assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( + 0) == 1 or a2_scale.size( + 0) == a1q.shape[0], "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -75,12 +75,12 @@ def run_cutlass_moe_fp8( # their tokens are already contiguous for each expert as a result of # the dispatch function. 
- M = a1q.shape[0] # non batched expert M - padded_M = a1q.shape[1] # batched expert M + M = a1q.size(0) # non batched expert M + padded_M = a1q.size(1) # batched expert M _, K, N = w2.shape device = a1q.device - assert w1.shape[2] == K + assert w1.size(2) == K assert global_num_experts != -1 assert a1q_scale is not None @@ -91,8 +91,8 @@ def run_cutlass_moe_fp8( else: local_topk_ids = topk_ids - topk = local_topk_ids.shape[1] - local_E = w1.shape[0] + topk = local_topk_ids.size(1) + local_E = w1.size(0) if use_batched_format: assert expert_num_tokens is not None @@ -111,10 +111,10 @@ def run_cutlass_moe_fp8( problem_sizes2, expert_num_tokens, local_E, padded_M, N, K) - w1_scale = w1_scale.reshape(w1_scale.shape[0], -1) - w2_scale = w2_scale.reshape(w2_scale.shape[0], -1) - a1q = a1q.reshape(-1, a1q.shape[2]) - a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous() + w1_scale = w1_scale.reshape(w1_scale.size(0), -1) + w2_scale = w2_scale.reshape(w2_scale.size(0), -1) + a1q = a1q.reshape(-1, a1q.size(2)) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() else: expert_offsets = torch.empty((global_num_experts + 1), @@ -151,19 +151,19 @@ def run_cutlass_moe_fp8( a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.shape[0], ), + ab_strides1 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) - c_strides1 = torch.full((w1.shape[0], ), + c_strides1 = torch.full((w1.size(0), ), 2 * N, device=device, dtype=torch.int64) - ab_strides2 = torch.full((w1.shape[0], ), + ab_strides2 = torch.full((w1.size(0), ), N, device=device, dtype=torch.int64) - c_strides2 = torch.full((w1.shape[0], ), + c_strides2 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) @@ -237,7 +237,7 @@ def workspace_shapes( workspace2: tuple[int, ...] = () output: tuple[int, ...] = () if self.use_batched_format: - padded_M = aq.shape[1] + padded_M = aq.size(1) workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output = (self.max_experts_per_worker, padded_M, K) @@ -332,7 +332,7 @@ def cutlass_moe_fp8( """ per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.shape[0] + per_out_ch = w1_scale.numel() != w1_q.size(0) out_dtype = a.dtype @@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.shape[0] == m and topk_ids.shape[0] + assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") out_dtype = a.dtype - num_topk = topk_ids.shape[1] + num_topk = topk_ids.size(1) expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) @@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out_dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
- intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b4473b907381..c04aaacc9d19 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -48,7 +48,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): - logger.debug("DeepGemm disabled: unalinged problem size.") + logger.debug("DeepGemm disabled: unaligned problem size.") return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 3484a7a8a496..5a8accd80463 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 437e80696ac6..f22884b8a1a5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert A_scale is None assert B_scale is None - M = A.shape[0] + M = A.size(0) num_tokens = M * top_k - EM = sorted_token_ids.shape[0] - if A.shape[0] < config["BLOCK_SIZE_M"]: + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. # We assume that top_ids of each token is unique, so # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
- EM = min(sorted_token_ids.shape[0], - A.shape[0] * top_k * config['BLOCK_SIZE_M']) + EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.shape[1], META['BLOCK_SIZE_N']), ) + B.size(1), META['BLOCK_SIZE_N']), ) if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: @@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], - num_experts=B.shape[0], + num_experts=B.size(0), bit=4 if use_int4_w4a16 else 8) config = config.copy() config.update( get_moe_wna16_block_config(config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, num_valid_tokens=num_tokens, - size_k=A.shape[1], - size_n=B.shape[1], - num_experts=B.shape[1], + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), group_size=block_shape[1], real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"])) @@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - A.shape[1], + B.size(1), + A.size(1), EM, num_tokens, A.stride(0), @@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0, - block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, @@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - B.shape[2], + B.size(1), + B.size(2), EM, num_tokens, A.stride(0), @@ -818,7 +818,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[list[int]] = None, -): +) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: @@ -873,10 +873,10 @@ def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") - M, _ = hidden_states.shape + M, _ = hidden_states.size() topk_weights = torch.empty(M, topk, @@ -915,7 +915,7 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") if scoring_func == "softmax": @@ -925,7 +925,7 @@ def grouped_topk( else: raise ValueError(f"Unsupported scoring function: {scoring_func}") - num_token = scores.shape[0] + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights @@ -942,7 +942,7 @@ def grouped_topk( group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] @@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor, allow_deep_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. - N = w1.shape[1] + N = w1.size(1) if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False @@ -1233,13 +1233,13 @@ def fused_experts_impl( ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ - 2], "Hidden size mismatch" + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") else: - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" @@ -1247,12 +1247,12 @@ def fused_experts_impl( torch.float32, torch.float16, torch.bfloat16 ] - num_tokens = hidden_states.shape[0] - E, N, _ = w1.shape - K = w2.shape[1] + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] + top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -1269,8 +1269,8 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, block_shape=block_shape, @@ -1310,7 +1310,7 @@ def fused_experts_impl( min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape + tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break @@ -1322,7 +1322,7 @@ def fused_experts_impl( # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] + topk_ids.size(1)] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1398,7 +1398,7 @@ def fused_experts_impl( per_channel_quant=per_channel_quant, block_shape=block_shape) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states @@ -1611,8 +1611,8 @@ def apply( dtype=hidden_states.dtype) config = try_get_optimal_moe_config( - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f2175886..bd5eeaa849db 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -853,13 +853,11 @@ def __init__( self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op - self.use_direct_call = self.dp_size == 1 - if not self.use_direct_call: - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError("Duplicate layer name: {}".format(prefix)) - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix # Determine expert maps if self.use_ep: @@ -1353,11 +1351,8 @@ def maybe_all_reduce_tensor_model_parallel( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5bc01dbf2025..2ff8ef99b2ec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -69,7 +69,7 @@ def prepare( a1 = a1 * rank_topk_weights.to(a1.dtype) repeat_cols = 4 - repeat_rows = 1 if self.per_act_token else a1.shape[0] + repeat_rows = 1 if self.per_act_token else a1.size(0) a1q, a1q_scale = moe_kernel_quantize_input( a1, (None if self.per_act_token else a1_scale), self.quant_dtype, self.per_act_token, self.block_shape)
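
The duplicated torch_moe2 helpers in test_pplx_cutlass_moe.py and test_pplx_moe.py are replaced by the shared torch_experts in tests/kernels/utils.py, with torch_moe reduced to softmax + top-k routing on top of it. Below is a minimal, self-contained sketch of that split; the names are standalone stand-ins, not the vLLM test utilities themselves.

from typing import Optional

import torch


def ref_experts(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                topk_weight: torch.Tensor, topk_ids: torch.Tensor,
                expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
    # a: [M, K], w1: [E, 2N, K], w2: [E, K, N]; routing tensors are [M, topk].
    M, K = a.shape
    topk = topk_ids.shape[1]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]  # global -> local ids, -1 for non-local
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # SiluAndMul-style MLP: the first half of w1's output gates the second.
            h = a[mask] @ w1[i].transpose(0, 1)
            gate, up = h.chunk(2, dim=-1)
            out[mask] = (torch.nn.functional.silu(gate) * up) @ w2[i].transpose(0, 1)
    return (out.view(M, -1, w2.shape[1]) *
            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


def ref_moe(a, w1, w2, score, topk, expert_map=None):
    # The torch_moe wrapper now only does the routing and defers to the experts helper.
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    return ref_experts(a, w1, w2, topk_weight, topk_ids, expert_map)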
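
run_moe_test in test_moe.py accepts either a callable or an already-computed tensor as the baseline, so the slow torch reference is evaluated once and reused for every implementation under test. A condensed sketch of that pattern, using generic names rather than the real test signature:

from typing import Callable, Union

import torch


def check_against_baseline(baseline: Union[Callable, torch.Tensor],
                           impl: Callable,
                           *args,
                           atol: float = 2e-2,
                           rtol: float = 0.0) -> torch.Tensor:
    # Compute the reference only when a callable is passed in.
    expected = baseline if isinstance(baseline, torch.Tensor) else baseline(*args)
    torch.testing.assert_close(impl(*args), expected, atol=atol, rtol=rtol)
    return expected  # callers reuse this tensor when checking the next implementation


# Usage: expected = check_against_baseline(ref_fn, impl_a, x)
#        check_against_baseline(expected, impl_b, x)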
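
The use_cudagraph branch in run_moe_test and in the new DeepGemm test follows the usual capture-then-replay recipe: run eagerly once, zero the output, capture on a side stream, then replay and compare against the eager baseline. A bare-bones sketch of that recipe, assuming a CUDA device and a function whose inputs stay at fixed addresses across the capture and replay:

import torch


def run_under_cudagraph(fn, *args) -> torch.Tensor:
    out = fn(*args)                 # eager warm-up
    out.fill_(0)                    # prove the replay really writes the output
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        out = fn(*args)             # captured, not executed yet
    torch.cuda.synchronize()
    graph.replay()                  # runs the captured kernels
    torch.cuda.synchronize()
    return out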
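
Both test files keep a use_compile flag that is currently forced to False because compilation errors out once the problem is large enough to trigger chunking. The setup it guards is the standard inductor plus dynamic-dimension recipe: mark the token dimension dynamic before the first compiled call so that changing batch sizes do not force recompiles. A tiny, generic example of that recipe (not the MoE kernels themselves):

import torch


def mlp(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x @ w)


compiled_mlp = torch.compile(mlp, backend="inductor", fullgraph=True)

w = torch.randn(128, 256)
for m in (33, 222, 1024):
    x = torch.randn(m, 128)
    torch._dynamo.mark_dynamic(x, 0)  # dim 0 (tokens) varies between calls
    y = compiled_mlp(x, w)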
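
Several tests now monkeypatch VLLM_FUSED_MOE_CHUNK_SIZE down to 8192 (or 1024 for the DeepGemm test) so that the 32768 and 40000 token shapes actually exercise the chunked path in fused_experts_impl, which walks the tokens CHUNK_SIZE at a time and sizes its intermediate caches per chunk. A rough, standalone illustration of that loop shape; chunked_apply is a made-up helper for this sketch, not the vLLM function, and the default value is only a placeholder:

import os

import torch


def chunked_apply(x: torch.Tensor, fn, default_chunk_size: int = 8192) -> torch.Tensor:
    # The real code reads this via vllm.envs; reading os.environ here is just for the sketch.
    chunk_size = int(os.environ.get("VLLM_FUSED_MOE_CHUNK_SIZE", default_chunk_size))
    num_tokens = x.shape[0]
    out = torch.empty_like(x)
    for begin in range(0, num_tokens, chunk_size):
        end = min(begin + chunk_size, num_tokens)
        # The final chunk can be smaller, which is why sizes like 40000 that
        # are not multiples of the chunk size are worth testing.
        out[begin:end] = fn(x[begin:end])
    return out


# In a pytest test: monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")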
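
The .shape to .size() cleanup in invoke_fused_moe_kernel leaves the launch-grid math unchanged: EM comes from sorted_token_ids, but for small batches it is clamped because each token can occupy at most top_k expert slots. A pure-Python restatement of that sizing, under the assumption that the Triton kernel tiles EM by BLOCK_SIZE_M and the output columns by BLOCK_SIZE_N:

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def moe_kernel_grid(M: int, top_k: int, EM_padded: int, N: int,
                    block_m: int, block_n: int) -> int:
    EM = EM_padded  # len(sorted_token_ids), already padded per expert block
    if M < block_m:
        # Small batch: at most M * top_k routed tokens exist, so only that
        # many (padded) blocks can contain valid work.
        EM = min(EM_padded, M * top_k * block_m)
    return cdiv(EM, block_m) * cdiv(N, block_n)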
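
The layer.py change drops use_direct_call: every FusedMoE layer now registers itself in the compilation config's static_forward_context, and forward() always dispatches through torch.ops.vllm.moe_forward by layer name. The sketch below shows only the underlying name-registry pattern, with a plain dict and a free function standing in for the custom op; it is an analogy, not vLLM's actual registration code.

import torch
from torch import nn

_FORWARD_CONTEXT: dict[str, nn.Module] = {}


def moe_forward_by_name(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Stand-in for torch.ops.vllm.moe_forward: look the layer up by name.
    return _FORWARD_CONTEXT[layer_name].forward_impl(hidden_states)


class ToyMoELayer(nn.Module):

    def __init__(self, prefix: str, hidden_size: int):
        super().__init__()
        if prefix in _FORWARD_CONTEXT:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        _FORWARD_CONTEXT[prefix] = self
        self.layer_name = prefix
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Always go through the indirection; no direct-call fast path.
        return moe_forward_by_name(hidden_states, self.layer_name)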