diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index acabe6c1ddb0..1d4e730f99ae 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -113,6 +113,7 @@ def run_cutlass_moe( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + per_act_token: bool, num_repeats: int, ): for _ in range(num_repeats): @@ -124,7 +125,8 @@ def run_cutlass_moe( topk_ids, w1_scale, w2_scale, - a1_scale=a_scale, + per_act_token, + a1_scale=None, ) def run_cutlass_from_graph( @@ -148,7 +150,8 @@ def run_cutlass_from_graph( topk_ids, w1_scale, w2_scale, - a1_scale=a_scale, + per_act_token, + a1_scale=None, ) def run_triton_from_graph( @@ -227,6 +230,7 @@ def replay_graph(graph, num_repeats): "w2_q": w2_q, "w1_scale": w1_scale, "w2_scale": w2_scale, + "per_act_token": per_act_token, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -287,12 +291,13 @@ def replay_graph(graph, num_repeats): w2_scale, topk_weights, topk_ids, + per_act_token, num_warmup, ) results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py new file mode 100644 index 000000000000..7797e4f0c9c0 --- /dev/null +++ b/tests/kernels/moe/parallel_utils.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +DeepEP test utilities +""" +import dataclasses +import importlib +import os +import traceback +from typing import Callable, Optional + +import torch +from torch.distributed import ProcessGroup +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.utils import get_open_port + +has_deep_ep = importlib.util.find_spec("deep_ep") is not None +if has_deep_ep: + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + +## Parallel Processes Utils + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() 
+        raise
+    finally:
+        torch.distributed.destroy_process_group()
+
+
+def parallel_launch(
+    world_size: int,
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    assert not kwargs
+    spawn(
+        _worker_parallel_launch,
+        args=(
+            world_size,
+            world_size,
+            0,
+            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
+            worker,
+        ) + args,
+        nprocs=world_size,
+        join=True,
+    )
+
+
+## DeepEP specific utils
+
+
+@dataclasses.dataclass
+class DeepEPHTArgs:
+    num_local_experts: int
+
+
+@dataclasses.dataclass
+class DeepEPLLArgs:
+    max_tokens_per_rank: int
+    hidden_size: int
+    num_experts: int
+    use_fp8_dispatch: bool
+
+
+def make_deepep_ht_a2a(pg: ProcessGroup,
+                       pgi: ProcessGroupInfo,
+                       dp_size: int,
+                       ht_args: DeepEPHTArgs,
+                       q_dtype: Optional[torch.dtype] = None,
+                       block_shape: Optional[list[int]] = None):
+
+    import deep_ep
+
+    # high throughput a2a
+    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
+    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
+    buffer = deep_ep.Buffer(group=pg,
+                            num_nvl_bytes=num_nvl_bytes,
+                            num_rdma_bytes=num_rdma_bytes,
+                            low_latency_mode=low_latency_mode,
+                            num_qps_per_rank=num_qps_per_rank)
+    return DeepEPHTPrepareAndFinalize(buffer=buffer,
+                                      world_size=pgi.world_size,
+                                      rank=pgi.rank,
+                                      dp_size=dp_size,
+                                      rank_expert_offset=pgi.rank *
+                                      ht_args.num_local_experts)
+
+
+def make_deepep_ll_a2a(pg: ProcessGroup,
+                       pgi: ProcessGroupInfo,
+                       dp_size: int,
+                       deepep_ll_args: DeepEPLLArgs,
+                       q_dtype: Optional[torch.dtype] = None,
+                       block_shape: Optional[list[int]] = None):
+
+    import deep_ep
+
+    # low-latency a2a
+    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
+        pgi.world_size, deepep_ll_args.num_experts)
+
+    buffer = deep_ep.Buffer(group=pg,
+                            num_rdma_bytes=num_rdma_bytes,
+                            low_latency_mode=True,
+                            num_qps_per_rank=deepep_ll_args.num_experts //
+                            pgi.world_size)
+
+    return DeepEPLLPrepareAndFinalize(
+        buffer=buffer,
+        world_size=pgi.world_size,
+        dp_size=dp_size,
+        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
+        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
+    )
+
+
+def make_deepep_a2a(pg: ProcessGroup,
+                    pgi: ProcessGroupInfo,
+                    dp_size: int,
+                    deepep_ht_args: Optional[DeepEPHTArgs],
+                    deepep_ll_args: Optional[DeepEPLLArgs],
+                    q_dtype: Optional[torch.dtype] = None,
+                    block_shape: Optional[list[int]] = None):
+    if deepep_ht_args is not None:
+        assert deepep_ll_args is None
+        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
+                                  block_shape)
+
+    assert deepep_ll_args is not None
+    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
+                              block_shape)
diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index b0e0feab4689..779fa1df086d 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -2,18 +2,59 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 import torch
 import triton.language as tl
 
+from tests.kernels.moe.utils import (batched_moe,
+                                     make_quantized_test_activations,
+                                     make_test_weights, triton_moe)
+from tests.kernels.quant_utils import native_batched_masked_quant_matmul
+from tests.kernels.utils import torch_experts
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +MNK_FACTORS = [ + (1, 128, 128), + (1, 128, 2048), + (1, 512, 512), + (1, 1024, 128), + (1, 1024, 2048), + (32, 128, 128), + (32, 512, 512), + (32, 1024, 2048), + (45, 128, 128), + (45, 128, 2048), + (45, 512, 512), + (45, 1024, 128), + (45, 1024, 2048), + (64, 128, 128), + (64, 512, 512), + (64, 1024, 2048), + (222, 128, 128), + (222, 128, 2048), + (222, 512, 512), + (222, 1024, 128), + (222, 1024, 2048), +] +NUM_EXPERTS = [8, 64] +TOP_KS = [1, 2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 @dataclass class BatchedMMConfig: - dtype: torch.dtype + in_dtype: torch.dtype + quant_dtype: Optional[torch.dtype] + out_dtype: torch.dtype num_experts: int max_tokens_per_expert: int K: int @@ -32,79 +73,127 @@ def make_tensors(config: BatchedMMConfig): A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", - dtype=config.dtype) / 10 + dtype=config.in_dtype) / 10 B = torch.randn((config.num_experts, config.N, config.K), device="cuda", - dtype=config.dtype) + dtype=config.in_dtype) C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.dtype) + dtype=config.out_dtype) + num_expert_tokens = torch.randint(low=0, high=config.max_tokens_per_expert, size=(config.num_experts, ), device="cuda", dtype=torch.int32) - return BatchedMMTensors(A, B, C, num_expert_tokens) - -def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: - - num_expert_tokens_cpu = num_expert_tokens.clone() - num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") - num_experts = num_expert_tokens.size(0) - - for e in range(num_experts): - num_tokens = num_expert_tokens_cpu[e] - C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) - - return C + return BatchedMMTensors(A, B, C, num_expert_tokens) -@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("num_experts", [8, 16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("block_shape", [None]) +@pytest.mark.parametrize("per_act_token_quant", [False]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, - N: int, dtype: torch.dtype): + N: int, dtype: torch.dtype, + block_shape: Optional[list[int]], + per_act_token_quant: bool): + current_platform.seed_everything(7) - config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) - tensors = BatchedMMTensors.make_tensors(config) + use_fp8_w8a8 = dtype == torch.float8_e4m3fn - test_output = tensors.C - ref_output = test_output.clone() + if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8: + pytest.skip("Don't test blocking for non-quantized types.") + + if per_act_token_quant and block_shape is not None: + pytest.skip("Skip illegal quantization test.") + + if dtype.itemsize == 1: + act_dtype = torch.bfloat16 + quant_dtype = dtype + else: + act_dtype = dtype + quant_dtype = None + + num_expert_tokens = torch.randint(low=0, + high=max_tokens_per_expert, + size=(num_experts, ), + device="cuda", + dtype=torch.int32) + + A, A_q, A_scale = 
make_quantized_test_activations( + num_experts, + max_tokens_per_expert, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant) + + B, B_q, B_scale, _, _, _ = make_test_weights( + num_experts, + N // 2, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + ) + + out_shape = (num_experts, max_tokens_per_expert, N) + test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") + ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") + q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda") compute_tl_dtype = { torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 }[test_output.dtype] + + assert A_q.dtype == B_q.dtype + invoke_moe_batched_triton_kernel( - tensors.A, - tensors.B, + A_q, + B_q, test_output, - tensors.num_expert_tokens, + num_expert_tokens, compute_tl_dtype, # Quantization data - None, - None, + A_scale, + B_scale, None, # Quantization schemes - False, + use_fp8_w8a8, False, False, config={ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 - }) + "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 + }, + block_shape=block_shape, + ) - ref_output = ref_impl(tensors.A, tensors.B, ref_output, - tensors.num_expert_tokens) + ref_output = native_batched_masked_quant_matmul( + A, + B, + ref_output, + num_expert_tokens, + None, + None, + None, + ) + + q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, + num_expert_tokens, + A_scale, B_scale, + block_shape) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, torch.float32: (1e-2, 1e-2), }[test_output.dtype] - torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol) + torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("per_act_token_quant", [False]) +@pytest.mark.parametrize("block_shape", [None]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + per_act_token_quant: bool, + block_shape: Optional[list[int]], +): + current_platform.seed_everything(7) + + use_fp8_w8a8 = dtype == torch.float8_e4m3fn + + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): + pytest.skip("Skip quantization test for non-quantized type") + + if per_act_token_quant and block_shape is not None or topk > e: + pytest.skip("Skip illegal quantization test.") + + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + if dtype.itemsize == 1: + act_dtype = torch.bfloat16 + quant_dtype = dtype + else: + act_dtype = dtype + quant_dtype = None + + _, w1, w1_s, _, w2, w2_s = make_test_weights(e, + n, + k, + block_shape=block_shape, + in_dtype=act_dtype, + quant_dtype=quant_dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + batched_output = batched_moe( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + 
baseline_output = torch_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) + + triton_output = triton_moe( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + + torch.testing.assert_close(triton_output, + baseline_output, + atol=2e-2, + rtol=2e-2) + + torch.testing.assert_close(triton_output, + batched_output, + atol=2e-2, + rtol=2e-2) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py new file mode 100644 index 000000000000..c187542205a5 --- /dev/null +++ b/tests/kernels/moe/test_block_fp8.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.moe.utils import make_test_weights +from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, + native_w8a8_block_matmul) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm_shape, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) +from vllm.platforms import current_platform + +dg_available = False +try: + import deep_gemm + dg_available = True +except ImportError: + pass + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +# Test configurations +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] +# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 +# and its hidden size is 7168. 
+MNK_FACTORS = [ + (1, 128, 128), + (1, 512, 512), + (1, 128, 7168), + (1, 1024, 7168), + (1, 4608, 128), + (1, 4608, 512), + (1, 4608, 7168), + (83, 128, 128), + (83, 512, 512), + (83, 1024, 7168), + (83, 4608, 512), + (83, 4608, 7168), + (128, 128, 128), + (128, 512, 512), + (128, 1024, 7168), + (128, 4608, 512), + (128, 4608, 7168), + (2048, 128, 128), + (2048, 1024, 7168), + (2048, 4608, 512), + (2048, 4608, 7168), + (8192, 128, 128), + (8192, 512, 512), + (8192, 128, 7168), + (8192, 1024, 7168), + (8192, 4608, 512), + (8192, 4608, 7168), +] + +MNK_FACTORS_DG = [ + (128, 128, 128), + (128, 512, 512), + (128, 128, 7168), + (128, 1024, 7168), + (128, 4608, 128), + (128, 4608, 512), + (128, 4608, 7168), + (192, 128, 128), + (192, 512, 512), + (192, 1024, 7168), + (192, 4608, 512), + (192, 4608, 7168), + (1335, 128, 128), + (1335, 1024, 7168), + (1335, 4608, 512), + (1335, 4608, 7168), + (2048, 128, 128), + (2048, 512, 512), + (2048, 128, 7168), + (2048, 1024, 7168), + (2048, 4608, 128), + (2048, 4608, 512), + (2048, 4608, 7168), +] + +BLOCK_SIZE = [[128, 128]] +E = [2, 8, 16] # [128, 256] +TOP_KS = [1, 2, 6] +SEEDS = [0] + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, + block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + topk = topk_ids.size(1) + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + out[mask] = native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS) +@pytest.mark.parametrize("E", E) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, + monkeypatch): + if topk > E: + pytest.skip(f"Skipping test; topk={topk} > E={E}") + + torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048") + + a = torch.randn((M, K), dtype=dtype) / 10 + score = torch.randn((M, E), dtype=dtype) + + _, w1, w1_s, _, w2, w2_s = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) + + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_act_token_quant=False, + block_shape=block_size) + + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) + + # Set the context to avoid lots of warning spam. 
+ with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_block_fp8_moe( + a, + w1, + w2, + w1_s, + w2_s, + topk_weights, + topk_ids, + block_size, + ) + + out = fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + m_out = m_fused_moe( + a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + ) + + # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0] + tol = 0.035 if M < 40000 else 0.039 + torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol) + torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol) + + +@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG) +@pytest.mark.parametrize("E", E) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, + monkeypatch): + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") + + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") + + chunk_size = 1024 + + torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + + a = torch.randn((M, K), dtype=dtype) / 10 + score = torch.randn((M, E), dtype=dtype) + + _, w1, w1_s, _, w2, w2_s = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) + + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) + + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) + + # Set the context to avoid lots of warning spam. 
+ with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids, block_size) + + if use_compile: + deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, + backend="inductor", + fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(topk_weights, 0) + torch._dynamo.mark_dynamic(topk_ids, 0) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + + if use_cudagraph: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035) diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py new file mode 100644 index 000000000000..8e680c722935 --- /dev/null +++ b/tests/kernels/moe/test_block_int8.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.moe.utils import make_test_weights +from tests.kernels.quant_utils import (native_per_token_group_quant_int8, + native_w8a8_block_matmul) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (7, 0): + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", + allow_module_level=True) + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +DTYPES = [torch.half, torch.bfloat16] + +MNK_FACTORS = [ + (1, 128, 128), + (1, 512, 512), + (1, 128, 7168), + (1, 1024, 7168), + (1, 4096, 128), + (1, 4096, 512), + (1, 4096, 7168), + (33, 128, 128), + (33, 512, 512), + (33, 128, 7168), + (33, 1024, 7168), + (33, 4096, 128), + (33, 4096, 512), + (33, 4096, 7168), + (128, 128, 128), + (128, 512, 512), + (128, 1024, 7168), + (128, 4096, 512), + (128, 4096, 7168), + (222, 128, 128), + (222, 512, 512), + (222, 1024, 7168), + (222, 4096, 512), + (222, 4096, 7168), + (2048, 128, 128), + (2048, 1024, 7168), + (2048, 4096, 512), + (2048, 4096, 7168), +] + +E = [8, 24] +TOP_KS = [2, 6] +# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] +BLOCK_SIZE = [[128, 128]] +SEEDS = [0] + + +# For test +def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe with block-wise quantization using + native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_int8(a, block_k) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = 
native_per_token_group_quant_int8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.fixture(autouse=True, scope="module") +def setup_cuda(): + """Sets the default CUDA device for all tests in this module.""" + torch.set_default_device("cuda") + + +@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS) +@pytest.mark.parametrize("E", E) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + """Tests the fused_moe kernel with W8A8 INT8 block quantization against a + native torch reference.""" + torch.manual_seed(seed) + + a = torch.randn((M, K), dtype=dtype) / 10 + score = torch.randn((M, E), dtype=dtype) + + _, w1, w1_s, _, w2, w2_s = make_test_weights(E, + N, + K, + dtype, + torch.int8, + per_act_token_quant=False, + block_shape=block_size) + + # Set the context to avoid lots of warning spam. + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + # Check results + torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 158100a09879..929db9177537 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -97,11 +97,9 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, n_b_scales = 2 * n if per_out_channel else 1 k_b_scales = k if per_out_channel else 1 # Get the right scale for tests. - _, a_scale = ops.scaled_fp8_quant( - moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a, - a_scale, - use_per_token_if_dynamic=per_act_token) + a_q, a_scale = ops.scaled_fp8_quant( + moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) @@ -187,6 +185,7 @@ def slice_experts(): def run_8_bit(moe_tensors: MOETensors8Bit, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + per_act_token: bool, num_local_experts: Optional[int] = None) -> torch.Tensor: assert not any([ t is None for t in [ @@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, - 'a1_scale': moe_tensors.a_scale + 'per_act_token': per_act_token, + 'a1_scale': None #moe_tensors.a_scale } num_experts = moe_tensors.w1.size(0) @@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph( triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids) - cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token) + # Note 5.5 only needed for larger problem sizes, 5 works ok for + # the rest. 
torch.testing.assert_close(triton_output, cutlass_output, - atol=5e-2, + atol=5.5e-2, rtol=1e-2) @@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph( stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, + per_act_token) torch.cuda.synchronize() graph.replay() @@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP( cutlass_output = run_8_bit(mt, topk_weights, topk_ids, + per_act_token, num_local_experts=e // ep_size) torch.testing.assert_close(triton_output, diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 008406c3f159..9b861d4ebc23 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -Test DeepEP + DeepGEMM integration +Test DeepEP + DeepGEMM integration DeepGEMM are gemm kernels specialized for the fp8 block-quantized case. """ @@ -17,12 +17,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from .utils import ProcessGroupInfo, parallel_launch +from .parallel_utils import ProcessGroupInfo, parallel_launch +from .utils import make_test_weights if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 @@ -30,10 +29,9 @@ from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm(): - import deep_gemm from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) @@ -60,25 +58,6 @@ def next_power_of_2(x): return 2**math.ceil(math.log2(x)) -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - def make_block_quant_fp8_weights( e: int, n: int, @@ -86,43 +65,11 @@ def make_block_quant_fp8_weights( block_size: list[int], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Return weights w1, w2, w1q, w2q, w1_scale, w2_scale + Return weights w1q, w2q, w1_scale, w2_scale """ - dtype = torch.bfloat16 - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 - w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - - w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10 - w2_bf16 = 
w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * n) + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w2 = (n + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - - w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), - device="cuda", - dtype=torch.float32) - w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), - device="cuda", - dtype=torch.float32) - - assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - - return w1, w2, w1_s, w2_s + w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( + e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) + return w1q, w2q, w1_scale, w2_scale @dataclasses.dataclass @@ -132,6 +79,7 @@ class TestConfig: k: int n: int num_experts: int + per_act_token_quant: bool block_size: list[int] # configs for testing low-latency kernels low_latency: bool @@ -150,8 +98,7 @@ class TestTensors: def make(config: TestConfig, rank) -> "TestTensors": dtype = torch.bfloat16 - topk, m, k, block_size = (config.topk, config.m, config.k, - config.block_size) + topk, m, k = (config.topk, config.m, config.k) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -159,9 +106,7 @@ def make(config: TestConfig, rank) -> "TestTensors": rank_tokens = torch.randn( (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) - - block_k = block_size[1] - _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k) + rank_token_scales = None topk_ids = torch.randint( low=0, @@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, q_dtype=q_dtype, block_shape=test_config.block_size) - fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank, - world_size=pgi.world_size, - dp_size=dp_size, - block_shape=test_config.block_size) + fused_experts = BatchedDeepGemmExperts( + max_num_tokens=max_tokens_per_rank, + world_size=pgi.world_size, + dp_size=dp_size, + block_shape=test_config.block_size, + per_act_token_quant=test_config.per_act_token_quant) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, """ Tests for High-Throughput DeepEP + DeepGemm integration. 
""" + import deep_gemm m, n, k = mnk current_platform.seed_everything(7) @@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, k=k, n=n, num_experts=num_experts, + per_act_token_quant=False, block_size=block_size, low_latency=False, use_fp8_dispatch=None) @@ -474,10 +423,14 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm -def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, - int], num_experts: int, topk: int, - use_fp8_dispatch: bool, block_size: list[int], - world_dp_size: tuple[int, int]): +def test_ll_deepep_deepgemm_moe( + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + use_fp8_dispatch: bool, + block_size: list[int], + world_dp_size: tuple[int, int], +): """ Tests for Low-Latency DeepEP + DeepGemm integration. """ @@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, k=k, n=n, num_experts=num_experts, + per_act_token_quant=False, block_size=block_size, low_latency=True, use_fp8_dispatch=use_fp8_dispatch, diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 94947c809e3a..d7df5bf77035 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -23,7 +23,7 @@ from vllm.platforms import current_platform from vllm.utils import has_deep_ep -from .utils import ProcessGroupInfo, parallel_launch +from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 @@ -31,7 +31,7 @@ from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a requires_deep_ep = pytest.mark.skipif( not has_deep_ep(), @@ -102,10 +102,6 @@ def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": rank_tokens = torch.randn( (config.m, config.k), device="cuda", dtype=token_dtype) / 10 rank_token_scales = None - if config.dtype == torch.float8_e4m3fn: - # low_latency_mode kernels dont support per-token quant. 
- _, rank_token_scales = ops.scaled_fp8_quant( - rank_tokens, use_per_token_if_dynamic=not low_latency_mode) topk = torch.randint(low=0, high=config.num_experts, @@ -121,11 +117,18 @@ def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": config=config) -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - low_latency_mode: bool, hidden_size: int, dp_size: int, - num_experts: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], - use_fp8_dispatch: bool) -> FusedMoEModularKernel: +def make_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + low_latency_mode: bool, + hidden_size: int, + dp_size: int, + num_experts: int, + num_local_experts: int, + q_dtype: Optional[torch.dtype], + use_fp8_dispatch: bool, + per_act_token_quant: bool, +) -> FusedMoEModularKernel: is_quantized = q_dtype is not None @@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, deepep_ll_args = ll_args) if low_latency_mode: + assert not per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, world_size=pgi.world_size, @@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, use_fp8_w8a8=is_quantized, use_int8_w8a8=False, use_int8_w8a16=False, - use_int4_w4a16=False) + use_int4_w4a16=False, + per_act_token_quant=False, + ) else: - fused_experts = TritonExperts(use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False) + fused_experts = TritonExperts( + use_fp8_w8a8=is_quantized, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_act_token_quant=per_act_token_quant, + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, - low_latency_mode: bool, dp_size: int, - test_tensors: TestTensors, w1: torch.Tensor, - w2: torch.Tensor, w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], num_experts: int, - use_fp8_dispatch: bool) -> torch.Tensor: +def deep_ep_moe_impl( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + low_latency_mode: bool, + dp_size: int, + test_tensors: TestTensors, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + num_experts: int, + use_fp8_dispatch: bool, + per_act_token_quant: bool, +) -> torch.Tensor: num_local_experts = w1.size(0) @@ -199,11 +215,9 @@ def build_expert_map(): q_dtype = torch.float8_e4m3fn # Make modular kernel - mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode, - hidden_size, dp_size, - num_experts, - num_local_experts, q_dtype, - use_fp8_dispatch) + mk: FusedMoEModularKernel = make_modular_kernel( + pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, + num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant) out_hidden_states = torch.empty_like(test_tensors.rank_tokens) total_num_tokens = test_tensors.rank_tokens.size(0) @@ -257,9 +271,15 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): return out_hidden_states -def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, - w2: torch.Tensor, w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool): +def torch_moe_impl( + test_tensors: TestTensors, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + using_fp8_dispatch: bool, + 
per_act_token_quant: bool, +): a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, test_tensors.topk_weights) @@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, # The DeepEP implementation is requested to dispatch using FP8. # For numerical stability for testing, emulate the fp8 dispatch by # blockwise quant and de-quant. + assert not per_act_token_quant a = test_tensors.rank_tokens aq, aq_scale = per_token_group_quant_fp8(a, 128) a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( @@ -310,6 +331,7 @@ def _deep_ep_moe( w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], use_fp8_dispatch: bool, + per_act_token_quant: bool, ): if not low_latency_mode: @@ -331,7 +353,8 @@ def _deep_ep_moe( with set_current_vllm_config(VllmConfig()): # Reference torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, - w2_scale, use_fp8_dispatch) + w2_scale, use_fp8_dispatch, + per_act_token_quant) # Splice experts for this rank. num_local_experts = config.num_experts // pgi.world_size @@ -356,6 +379,7 @@ def _deep_ep_moe( w2_scale_ep, config.num_experts, use_fp8_dispatch, + per_act_token_quant, ) torch.testing.assert_close( @@ -384,10 +408,16 @@ def _deep_ep_moe( @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) @requires_deep_ep -def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], - num_experts: int, topk: int, world_dp_size: tuple[int, - int]): +def test_deep_ep_moe( + dtype: torch.dtype, + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], + per_act_token_quant: bool, +): low_latency_mode = False use_fp8_dispatch = False m, n, k = mnk @@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) + config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, + per_act_token_quant) MNKs = [ @@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) + config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, + False) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0c31168566e2..96e3f29b3d79 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -17,6 +17,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( @@ -142,6 +143,10 @@ def test_fused_moe( # Setup test data # + # + # Setup test data + # + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -169,7 +174,7 @@ def test_fused_moe( use_int8_w8a8=False, 
use_int8_w8a16=False, use_int4_w4a16=False, - per_channel_quant=False, + per_act_token_quant=False, block_shape=None) def m_fused_moe( @@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") + monkeypatch.setenv('RANK', "0") + monkeypatch.setenv('LOCAL_RANK', "0") + monkeypatch.setenv('WORLD_SIZE', "1") + monkeypatch.setenv('MASTER_ADDR', 'localhost') + monkeypatch.setenv('MASTER_PORT', '12345') + init_distributed_environment() + # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() with (set_current_vllm_config(vllm_config), diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 76b560e1bb41..3f5412e75821 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + pytest.skip("Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True) MNK_FACTORS = [ diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index ee2bdc838b0d..184c2dd2f904 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -15,7 +15,7 @@ FusedMoEModularKernel) from vllm.platforms import current_platform -from .utils import ProcessGroupInfo, parallel_launch +from .parallel_utils import ProcessGroupInfo, parallel_launch try: from pplx_kernels import AllToAll @@ -93,7 +93,7 @@ def pplx_cutlass_moe( num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1 @@ -118,8 +118,6 @@ def pplx_cutlass_moe( pgi.world_size, rank, dp_size, - quant_dtype=torch.float8_e4m3fn, - per_act_token=per_act_token, ) experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 1da14eddff31..186e00800a17 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,18 +18,20 @@ except ImportError: has_pplx = False +from tests.kernels.moe.utils import make_test_weights, naive_batched_moe from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import override_config +from vllm.model_executor.layers.fused_moe import fused_topk, override_config +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, - get_default_config) + BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform +from vllm.utils import round_up -from .utils import ProcessGroupInfo, parallel_launch +from .parallel_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( not has_pplx, @@ -144,25 
+146,6 @@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) -def batched_moe( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - num_experts = w1.shape[0] - - fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens=a.shape[0], - world_size=1, - dp_size=1, - rank=0), - BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) - - return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) - - @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -188,7 +171,7 @@ def test_fused_moe_batched_experts( topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) - batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids) torch.testing.assert_close(baseline_output, torch_output, @@ -226,7 +209,6 @@ def pplx_prepare_finalize( topk = topk_ids.shape[1] num_tokens, hidden_dim = a.shape - block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size @@ -241,9 +223,7 @@ def pplx_prepare_finalize( dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else - ((hidden_dim + block_size - 1) // block_size * - torch.float32.itemsize)), + hidden_dim_scale_bytes=0, ) if group_name is None: @@ -260,7 +240,6 @@ def pplx_prepare_finalize( world_size, rank, dp_size, - a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) @@ -276,6 +255,7 @@ def pplx_prepare_finalize( num_experts, None, False, + FusedMoEQuantConfig(), ) b_a = b_a * 1.5 @@ -350,6 +330,7 @@ def _pplx_prepare_finalize( # TODO (bnell): this test point does not work for odd M due to how the test is # written, not due to limitations of the pplx kernels. The pplx_moe # test below is able to deal with odd M. 
+# TODO (bnell) add fp8 tests @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @@ -386,18 +367,31 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) device = torch.device("cuda", rank) hidden_dim = a.shape[1] num_experts = w1.shape[0] - block_size = 128 topk = topk_ids.shape[1] - max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64) + + hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes( + max_num_tokens, + hidden_dim, + a.dtype, + qtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) args = dict( max_num_tokens=max_num_tokens, @@ -407,10 +401,8 @@ def pplx_moe( world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else - ((hidden_dim + block_size - 1) // block_size * - torch.float32.itemsize)), + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=scale_bytes, ) if group_name is None: @@ -429,9 +421,11 @@ def pplx_moe( dp_size, ) - experts = BatchedTritonExperts(max_num_tokens=a.shape[0], + experts = BatchedTritonExperts(max_num_tokens=max_num_tokens, world_size=world_size, - dp_size=dp_size) + dp_size=dp_size, + use_fp8_w8a8=qtype == torch.float8_e4m3fn, + block_shape=block_shape) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -447,6 +441,13 @@ def pplx_moe( w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + if w1_scale is not None: + w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device) + w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device) + else: + w1_scale_chunk = None + w2_scale_chunk = None + # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. 
@@ -465,6 +466,8 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, global_num_experts=num_experts) if use_cudagraphs: @@ -477,6 +480,8 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, global_num_experts=num_experts) torch.cuda.synchronize() @@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): rank=rank, ) - experts = BatchedExperts(max_num_tokens=a.shape[0], - world_size=1, - dp_size=1) + experts = NaiveBatchedExperts(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1) fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -539,7 +544,12 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, - use_internode: bool, + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + qtype: Optional[torch.dtype] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, + use_internode: bool = False, ): if use_internode: uid = nvshmem_get_unique_id( @@ -557,11 +567,28 @@ def _pplx_moe( moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + device = torch.device("cuda", pgi.rank) + a = a.to(device) + w1 = w1.to(device) + w2 = w2.to(device) + w1_s = w1_s.to(device) if w1_s is not None else None + w2_s = w2_s.to(device) if w2_s is not None else None + with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + quant_dtype=qtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape) pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, - a, w1, w2, topk_weight, topk_ids) + a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, + qtype, per_act_token_quant, block_shape) # TODO (bnell): fix + re-enable #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, # topk_ids) @@ -581,6 +608,8 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("per_act_token_quant", [False, True]) +@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( @@ -589,15 +618,33 @@ def test_pplx_moe( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]], use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk world_size, dp_size = world_dp_size - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + + if dtype == torch.float8_e4m3fn: + use_fp8_w8a8 = True + quant_dtype = dtype + else: + use_fp8_w8a8 = False + quant_dtype = None + + if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None: + pytest.skip("Skip quantization test for non-quantized type") + + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + + _, w1, w1_s, _, w2, w2_s = make_test_weights(e, + n, + k, + quant_dtype=quant_dtype, + block_shape=block_shape) 
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, use_internode) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index e317ccbdb4a7..5b1048797447 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -1,194 +1,249 @@ # SPDX-License-Identifier: Apache-2.0 -""" -DeepEP test utilities -""" -import dataclasses -import importlib -import os -import traceback -from typing import Callable, Optional +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional import torch -from torch.distributed import ProcessGroup -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec - -from vllm.utils import get_open_port - -has_deep_ep = importlib.util.find_spec("deep_ep") is not None -if has_deep_ep: - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) - -## Parallel Processes Utils - -P = ParamSpec("P") - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exc() - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", - worker, - ) + args, - nprocs=world_size, - join=True, - ) - -## DeepEP specific utils - - -@dataclasses.dataclass -class DeepEPHTArgs: - num_local_experts: int - - -@dataclasses.dataclass -class DeepEPLLArgs: - max_tokens_per_rank: int - hidden_size: int - num_experts: int - use_fp8_dispatch: bool - - -def make_deepep_ht_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - ht_args: DeepEPHTArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - - import deep_ep - - # high throughput a2a - num_nvl_bytes = 1024 * 1024 * 1024 # 1GB - num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 - buffer = deep_ep.Buffer(group=pg, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=num_qps_per_rank) - return 
DeepEPHTPrepareAndFinalize(buffer=buffer, - world_size=pgi.world_size, - rank=pgi.rank, - dp_size=dp_size, - rank_expert_offset=pgi.rank * - ht_args.num_local_experts, - quant_dtype=q_dtype, - block_shape=block_shape) - - -def make_deepep_ll_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - deepep_ll_args: DeepEPLLArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - - import deep_ep - - # low-latency a2a - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, - pgi.world_size, deepep_ll_args.num_experts) - - buffer = deep_ep.Buffer(group=pg, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=deepep_ll_args.num_experts // - pgi.world_size) - - return DeepEPLLPrepareAndFinalize( - buffer=buffer, - world_size=pgi.world_size, - dp_size=dp_size, - max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, - quant_dtype=q_dtype, - block_shape=block_shape, - use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, +import vllm._custom_ops as ops +from tests.kernels.quant_utils import (per_block_cast_to_fp8, + per_block_cast_to_int8) +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) +from vllm.utils import round_up + + +def triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_channel_quant=per_act_token_quant, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape) + + +def batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(max_num_tokens, + world_size=1, + dp_size=1, + rank=0), + BatchedTritonExperts( + max_num_tokens=max_num_tokens, + world_size=1, + dp_size=1, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ), ) + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) + + +def naive_batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: 
Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(max_num_tokens, + world_size=1, + dp_size=1, + rank=0), + NaiveBatchedExperts( + max_num_tokens=max_num_tokens, + dp_size=1, + world_size=1, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ), + ) -def make_deepep_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - deepep_ht_args: Optional[DeepEPHTArgs], - deepep_ll_args: Optional[DeepEPLLArgs], - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - if deepep_ht_args is not None: - assert deepep_ll_args is None - return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, - block_shape) - - assert deepep_ll_args is not None - return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, - block_shape) + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) + + +def chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + +def make_quantized_test_activations( + E: int, + m: int, + k: int, + in_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10 + a_q = a + a_scale = None + + if quant_dtype is not None: + assert (quant_dtype == torch.float8_e4m3fn + or quant_dtype == torch.int8), "only fp8/int8 supported" + a_q = torch.zeros_like(a, dtype=quant_dtype) + a_scale_l = [None] * E + for e in range(E): + a_q[e], a_scale_l[e] = moe_kernel_quantize_input( + a[e], None, quant_dtype, per_act_token_quant, block_shape) + a_scale = torch.stack(a_scale_l) + + if not per_act_token_quant and block_shape is None: + a_scale = a_scale.view(E, 1, 1) + + return a, a_q, a_scale + + +def moe_quantize_weights( + w: torch.Tensor, + w_s: Optional[torch.Tensor], + quant_dtype: Optional[torch.dtype], + per_token_quant: bool, + block_shape: Optional[list[int]], +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert (quant_dtype == torch.float8_e4m3fn + or quant_dtype == torch.int8), "only fp8/int8 supported" + + if block_shape is not None: + assert not per_token_quant + if quant_dtype == torch.int8: + w, w_s = per_block_cast_to_int8(w, block_shape) + else: + w, w_s = per_block_cast_to_fp8(w, block_shape) + else: + if quant_dtype == torch.int8: + w, w_s = ops.scaled_int8_quant( + w, w_s, use_per_token_if_dynamic=per_token_quant) + else: + w, w_s = ops.scaled_fp8_quant( + w, w_s, use_per_token_if_dynamic=per_token_quant) + + return w, w_s + + +def make_test_weight( + e: int, + rows: int, + cols: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 + + if quant_dtype is not None: + w_l = 
[None] * e + w_s_l = [None] * e + for idx in range(e): + w_l[idx], w_s_l[idx] = moe_quantize_weights( + w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) + + w = torch.stack(w_l) + w_s = torch.stack(w_s_l) + if w_s.ndim == 2: + assert w_s.shape[-1] == 1 + w_s = w_s.view(-1, 1, 1) + + if block_shape is not None: + block_n, block_k = block_shape + n_tiles = (rows + block_n - 1) // block_n + k_tiles = (cols + block_k - 1) // block_k + assert w_s.shape == (e, n_tiles, k_tiles) + else: + w = w_16 + w_s = None + + return w_16, w, w_s + + +def make_test_weights( + e: int, + n: int, + k: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + return ( + *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, + per_act_token_quant), + *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, + per_act_token_quant), + ) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 0840cc7b54fc..d0dc85f25755 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -5,7 +5,10 @@ import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + group_broadcast) from vllm.platforms import current_platform +from vllm.utils import round_up # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. @@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ return ref_out, ref_scale.view((1, )) -def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, - As: torch.Tensor, Bs: torch.Tensor, block_size, - output_dtype): +def native_w8a8_block_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, + compute_type: torch.dtype = torch.float32, +) -> torch.Tensor: """This function performs matrix multiplication with block-wise quantization using native torch. It is agnostic to the input data type and can be used for both int8 and @@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, `Bs` (float32). The output is returned in the specified `output_dtype`. 
""" - A = A.to(torch.float32) - B = B.to(torch.float32) + A = A.to(compute_type) + B = B.to(compute_type) assert A.shape[-1] == B.shape[-1] assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 assert len(block_size) == 2 @@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, As = As.reshape(M, As.shape[-1]) n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] + assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}" + assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}" C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + C = torch.zeros(C_shape, dtype=compute_type, device=A.device) A_tiles = [ A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) @@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, C = C.reshape(origin_C_shape).to(output_dtype) return C + + +def native_per_token_group_quant_fp8(x, + group_size, + eps=1e-10, + dtype=torch.float8_e4m3fn): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch.""" + assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must " + "be divisible by `group_size`") + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def native_per_token_group_quant_int8(x, + group_size, + eps=1e-10, + dtype=torch.int8): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch. + + It converts the tensor values into int8 values and returns the + quantized tensor along with the scaling factor used for quantization. 
+ """ + assert (x.shape[-1] % group_size == 0 + ), "the last dimension of `x` must be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_min = iinfo.min + int8_max = iinfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + # Use float32 for scale calculation for stability + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / int8_max + x_q = (x_.to(torch.float32) / x_s).round().clamp( + min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +DEFAULT_BLOCK_SHAPE = [128, 128] + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_shape: list[int] = DEFAULT_BLOCK_SHAPE, +) -> tuple[torch.Tensor, torch.Tensor]: + block_m, block_n = block_shape + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def per_block_cast_to_int8( + x: torch.Tensor, + block_shape: list[int] = DEFAULT_BLOCK_SHAPE, +) -> tuple[torch.Tensor, torch.Tensor]: + block_m, block_n = block_shape + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def dequant( + t: torch.Tensor, + scale: Optional[torch.Tensor], + block_shape: Optional[list[int]], + per_act_token_quant: bool, + out_dtype: Optional[torch.dtype] = torch.float32, +) -> torch.Tensor: + if scale is not None: + f32 = torch.float32 + if per_act_token_quant or block_shape is None: + return (t.to(f32) * scale).to(out_dtype) + else: + return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype) + else: + return t.to(out_dtype) + + +def native_batched_masked_quant_matmul( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + num_expert_tokens: torch.Tensor, + A_scale: Optional[torch.Tensor] = None, + B_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, +) -> torch.Tensor: + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + if A.dtype.itemsize == 1 and block_shape is not None: + assert A_scale is not None and B_scale is not None + tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e], + block_shape, C.dtype) + C[e, :num_tokens, :] = tmp[:num_tokens, :] + elif A.dtype.itemsize == 1 and block_shape is None: + assert A_scale is not None and B_scale is not None + A_dq = dequant(A[e], A_scale[e], 
block_shape, per_act_token_quant) + B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant) + C[e, :num_tokens, :] = ( + A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype) + else: + assert A_scale is None + assert B_scale is None + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 1ca0a80ab9a9..42d5526dc21f 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -7,16 +7,10 @@ import pytest import torch -from tests.kernels.quant_utils import native_w8a8_block_matmul -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) +from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, + native_w8a8_block_matmul, + per_block_cast_to_fp8) +from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -46,78 +40,10 @@ K = [256, 3884, 4096, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] -M_moe_dg = [128, 192, 1335, 2048] -N_moe = [128, 256, 1024, 4608] # [13824] -K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] # [128, 256] -TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] - -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " - "be divisible by `group_size`") - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - 
output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - if topk > E: - pytest.skip(f"Skipping test; topk={topk} > E={E}") - - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=block_size) - - # Set the context to avoid lots of warning spam. 
- with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_out = m_fused_moe(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=E, - w1_scale=w1_s, - w2_scale=w2_s) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - rel_diff = (torch.mean( - torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 - - -def fp8_perm(m, idx): - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) - else: - return m[idx, ...] - - -def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): - M, K = a.shape - - sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None, pad_sorted_ids=True) - - num_tokens = topk * M - - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - - a = fp8_perm(a, sorted_token_ids // topk) - if a_s is not None: - a_s = a_s[sorted_token_ids // topk] - - return a, a_s, m_indices, inv_perm - - -def _moe_unpermute(out, inv_perm, topk, K, topk_weight): - M = topk_weight.shape[0] - out = out[inv_perm, ...] 
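The block-quantized MoE reference removed in this hunk relied on per-token-group activation quantization, which this diff consolidates into native_per_token_group_quant_fp8 in tests/kernels/quant_utils.py. Below is a small pure-PyTorch round trip showing the layout it produces (one float32 scale per token per group_size-wide slice of the hidden dimension); it restates that helper for illustration and is not a drop-in replacement.

import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int):
    # One scale per contiguous group of `group_size` elements in the last
    # dimension, mirroring native_per_token_group_quant_fp8.
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_ = x.reshape(-1, group_size)
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-10).float()
    scale = amax / finfo.max
    x_q = (x_ / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return (x_q.reshape(x.shape),
            scale.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )))

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8_sketch(x, group_size=128)
assert x_s.shape == (4, 2)  # one scale per token per 128-wide group
# Dequantize by broadcasting each group's scale back over its 128 columns.
x_dq = x_q.float().reshape(4, 2, 128) * x_s.unsqueeze(-1)
assert torch.allclose(x_dq.reshape(4, 256), x.float(), atol=0.1, rtol=0.1)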
- tmp_out = out.view(-1, topk, K) - return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] - M, K = a.shape - N = w2.shape[-1] - - topk_weight, topk_ids, token_expert_indices = fused_topk( - a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - - a_q, a_s = per_token_group_quant_fp8(a, block_m) - - a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) - - inter_out = torch.zeros((a_q.shape[0], N * 2), - dtype=torch.bfloat16, - device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), - inter_out, m_indices) - - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - - out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - - final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) - - return final_out - - -@pytest.mark.parametrize( - "M,N,K,E,topk,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, - monkeypatch): - if topk > E: - pytest.skip(f"Skipping test: topk={topk} > E={E}") - - if not _valid_deep_gemm_shape(M, N, K): - pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - - chunk_size = 1024 - - torch.manual_seed(seed) - - monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] - dtype = torch.bfloat16 - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - score = torch.randn((M, E), dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * N) + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w2 = (N + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - - w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(E): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - - # Note: for now use_compile will error out if the problem size is - # large enough to trigger chunking. I'm leaving the flag and - # setup code in case we are able to revisit this later. 
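The hand-rolled DeepGemm/torch references deleted here are superseded by the extended torch_experts in tests/kernels/utils.py further down. For the non-block w8a8 branch that torch_experts gains, the per-expert math reduces to dequantize, gated SiLU, second GEMM; the sketch below works under the simplifying assumption of per-tensor scales (torch_experts itself also handles per-token and per-channel scales) and uses torch.nn.functional.silu in place of vLLM's SiluAndMul.

import torch

def dequant_expert_sketch(a_q, a_s, w1_q, w1_s, w2_q, w2_s):
    # Per-tensor-scale reference for one expert: dequantize to float32,
    # apply the gated SiLU, then the down projection.
    f32 = torch.float32
    a = a_q.to(f32) * a_s                       # (M, K)
    w1 = w1_q.to(f32) * w1_s                    # (2N, K)
    w2 = w2_q.to(f32) * w2_s                    # (K, N)
    h = a @ w1.t()                              # (M, 2N)
    gate, up = h.chunk(2, dim=-1)
    act = torch.nn.functional.silu(gate) * up   # SiluAndMul equivalent
    return act @ w2.t()                         # (M, K)

M, K, N = 8, 128, 64
a_q = torch.randn(M, K).to(torch.float8_e4m3fn)
w1_q = torch.randn(2 * N, K).to(torch.float8_e4m3fn)
w2_q = torch.randn(K, N).to(torch.float8_e4m3fn)
scale = torch.tensor(0.01)
out = dequant_expert_sketch(a_q, scale, w1_q, scale, w2_q, scale)
assert out.shape == (M, K)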
- use_compile = False - - use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 - and current_platform.is_cuda_alike()) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - if M >= 128: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score.float(), topk, False) - - if use_compile: - deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, - backend="inductor", - fullgraph=True) - torch._dynamo.mark_dynamic(a, 0) - torch._dynamo.mark_dynamic(topk_weights, 0) - torch._dynamo.mark_dynamic(topk_ids, 0) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) - - if use_cudagraph: - out.fill_(0) - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - - assert rel_diff < 0.03 diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index fa2c9f890d6f..fac82cf9c8b5 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -8,9 +8,7 @@ import torch from tests.kernels.quant_utils import native_w8a8_block_matmul -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( w8a8_block_int8_matmul) from vllm.platforms import current_platform @@ -23,82 +21,10 @@ vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 - -# For test -def native_per_token_group_quant_int8(x, - group_size, - eps=1e-10, - dtype=torch.int8): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch. - - It converts the tensor values into int8 values and returns the - quantized tensor along with the scaling factor used for quantization. 
- """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - iinfo = torch.iinfo(dtype) - int8_min = iinfo.min - int8_max = iinfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - # Use float32 for scale calculation for stability - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / int8_max - x_q = (x_.to(torch.float32) / x_s).round().clamp( - min=int8_min, max=int8_max).to(dtype) # Round before clamping - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -# For test -def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """This function performs fused moe with block-wise quantization using - native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_int8(a, block_k) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_int8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - DTYPES = [torch.half, torch.bfloat16] M = [1, 33, 64, 222] N = [128, 1024] K = [256, 4096] -E = [8, 24] -TOP_KS = [2, 6] # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] BLOCK_SIZE = [[128, 128]] SEEDS = [0] @@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 - - -@pytest.mark.parametrize( - "M, N, K, E, topk, block_size, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - """Tests the fused_moe kernel with W8A8 INT8 block quantization against a - native torch reference.""" - torch.manual_seed(seed) - # Use a smaller factor for scale initialization to prevent large - # values/overflow especially when output dtype might be float16 - factor_for_scale = 1e-2 - int8_info = torch.iinfo(torch.int8) - int8_max, int8_min = int8_info.max, int8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_fp32 = (torch.rand( - (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max - w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) - - w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max - w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - 
w1_s = (torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale) - w2_s = (torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale) - - score = torch.randn((M, E), dtype=dtype) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.06 diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index dcda8e479b29..84cf87d71d88 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,11 @@ import torch from torch._prims_common import TensorLikeType +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_experts(a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + per_act_token_quant=False, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: assert (global_num_experts == -1 or (global_num_experts == w1.shape[0] and expert_map is None) or (expert_map is not None and global_num_experts == expert_map.shape[0])) + + M, K = a.shape topk = topk_ids.shape[1] - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - topk_weight = topk_weight.view(-1) + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype, + per_act_token_quant, block_shape) + + num_experts = w1.shape[0] + topk_ids = topk_ids.view(-1) if expert_map is not None: topk_ids = expert_map[topk_ids] - for i in range(w1.shape[0]): + + for i in range(num_experts): mask = topk_ids == i if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + if quant_dtype is None: + tmp1 = a[mask] @ w1[i].transpose(0, 1) + tmp2 = SiluAndMul()(tmp1) + out[mask] = tmp2 @ w2[i].transpose(0, 1) + elif block_shape is not None: + assert (a_scale is not None and w1_scale is not None + and 
w2_scale is not None) + tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], + w1_scale[i], block_shape, + out.dtype) + tmp2 = SiluAndMul()(tmp1) + tmp2, b_scale = moe_kernel_quantize_input( + tmp2, None, quant_dtype, per_act_token_quant, block_shape) + + out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, + w2_scale[i], block_shape, + out.dtype) + else: + assert (a_scale is not None and w1_scale is not None + and w2_scale is not None) + f32 = torch.float32 + scales = a_scale if a_scale.numel() == 1 else a_scale[mask] + tmp1 = a[mask].to(f32) * scales + w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) + tmp1 = tmp1 @ w1_dq + tmp2 = SiluAndMul()(tmp1) + w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) + out[mask] = (tmp2 @ w2_dq).to(out.dtype) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) def torch_moe(a: torch.Tensor, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 36a0395ccdc9..6b1b3f787c23 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1274,7 +1274,7 @@ def scaled_fp8_quant( scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - assert scale.numel() == 1 + assert scale.numel() == 1, f"{scale.shape}" torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 2bdc96e297c1..3d40879b4ccb 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,8 +4,12 @@ from contextlib import contextmanager from typing import Any, Optional +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]: __all__ = [ "FusedMoE", + "FusedMoEConfig", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "FusedMoEPermuteExpertsUnpermute", + "FusedMoEActivationFormat", + "FusedMoEPrepareAndFinalize", "override_config", "get_config", ] @@ -36,11 +44,21 @@ def get_config() -> Optional[dict[str, Any]]: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4, cutlass_moe_fp8) + CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import ( TritonExperts, fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) __all__ += [ "fused_moe", @@ -50,5 +68,11 
@@ def get_config() -> Optional[dict[str, Any]]: "grouped_topk", "cutlass_moe_fp8", "cutlass_moe_fp4", + "CutlassExpertsFp8", "TritonExperts", + "BatchedTritonExperts", + "DeepGemmExperts", + "BatchedDeepGemmExperts", + "TritonOrDeepGemmExperts", + "BatchedTritonOrDeepGemmExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index b54ac80535a4..6b08f32dff18 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -5,6 +5,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton @@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 - DEEPGEMM_BLOCK_SHAPE = 128 - - def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, - block_shape: list[int]): + DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + + def __init__(self, + max_num_tokens: int, + world_size: int, + dp_size: int, + block_shape: list[int], + per_act_token_quant=False): """ max_num_tokens: Maximum number of tokens from a DP Rank world_size: Number of EP ranks dp_size: Number of data-parallel ranks block_shape: Block quantization block shape """ - super().__init__() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - self.block_shape = block_shape - assert (len(self.block_shape) == 2 and all( - [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -248,6 +265,7 @@ def apply( ): import deep_gemm as dg assert hidden_states.ndim == 3 + assert self.block_shape is not None a1q = hidden_states _, N, K = w1.size() diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 822cda8205bf..3682a536cb5c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) @@ -20,43 +21,45 @@ def __init__(self, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - per_channel_quant: bool = False, block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, allow_deep_gemm: bool = False): - 
super().__init__() assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + )) self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_int4_w4a16 = use_int4_w4a16 - self.per_channel_quant = per_channel_quant - self.block_shape = block_shape self.allow_deep_gemm = allow_deep_gemm # BatchedTritonKernel doesn't support block quantization # at the moment. self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=self.max_num_tokens, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape, world_size=self.world_size, - dp_size=self.dp_size) if self.block_shape is None else None + dp_size=self.dp_size, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape, + ) if self.block_shape is None else None + + is_fp8_128_block_quantized = ( + use_fp8_w8a8 and self.block_shape + == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) - is_fp8_128_block_quantized = (self.use_fp8_w8a8 - and self.block_shape is not None - and len(self.block_shape) == 2 and all( - [b == 128 - for b in self.block_shape])) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=self.max_num_tokens, world_size=self.world_size, @@ -67,12 +70,31 @@ def __init__(self, assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.batched_triton_experts is not None: + assert (self.batched_deep_gemm_experts is None + or self.batched_deep_gemm_experts.activation_formats + == self.batched_triton_experts.activation_formats) + return self.batched_triton_experts.activation_formats + else: + assert self.batched_deep_gemm_experts is not None + return self.batched_deep_gemm_experts.activation_formats + def supports_chunking(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts return ((bdge is None or bdge.supports_chunking()) and (bte is None or bte.supports_chunking())) + def supports_expert_map(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_expert_map()) + and (bte is None or bte.supports_expert_map())) + def workspace_shapes( self, a: torch.Tensor, @@ -87,7 +109,8 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
- if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: + if self.allow_deep_gemm: + assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) else: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py new file mode 100644 index 000000000000..9a678406b8f3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) + +import vllm.envs as envs +from vllm.config import ParallelConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) + + +def _get_quant_config_quantization_args( + quant_config: Optional[QuantizationConfig], + prop_name: str, +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get(prop_name) + else: + return None + + +def get_quant_config_input_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, + "input_activations") + + +def get_quant_config_weight_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, "weights") + + +# TODO (bnell): use scalar_type instead of bools? +def get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + +@dataclass +class FusedMoEQuantConfig: + # The post quantization activation type. + quant_dtype: Optional[torch.dtype] = None + per_act_token_quant: bool = False + per_out_ch_quant: bool = False + block_shape: Optional[list[int]] = None + + # TODO: add col major flag? + # add detailed quant info for input, intermediates, weights, etc? + + @staticmethod + def make( + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, + ) -> "FusedMoEQuantConfig": + assert sum([ + int(flag) for flag in [ + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ] + ]) <= 1, "Quantization flags are mutually exclusive." 
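As a usage sketch, assuming this config module lands at the path shown in this diff: the per-kernel boolean flags collapse into a single FusedMoEQuantConfig via make(), and the experts classes then read quant_dtype, per_act_token_quant and block_shape from it instead of carrying separate attributes.

import torch
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

# fp8 weights/activations with 128x128 block quantization, as used by the
# DeepGemm experts; per_act_token_quant stays False for block-quantized paths.
cfg = FusedMoEQuantConfig.make(use_fp8_w8a8=True, block_shape=[128, 128])
assert cfg.quant_dtype == torch.float8_e4m3fn
assert cfg.block_shape == [128, 128] and not cfg.per_act_token_quant

# Unquantized (bf16/fp16) MoE: leaving every flag False yields a no-op config.
assert FusedMoEQuantConfig.make().quant_dtype is None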
+ + quant_dtype = get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) + return FusedMoEQuantConfig( + quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + world_size: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @staticmethod + def make(tp_size_: int, dp_size_: int, world_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + world_size_ (int): the world size of the current All2All manager. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. 
+ + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + world_size=world_size_, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + world_size=world_size_, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class FusedMoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + # The activation type. 
+ in_dtype: torch.dtype + + quant_config: Optional[FusedMoEQuantConfig] = None + + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + + def __post_init__(self): + if self.dp_size > 1: + logger.debug("Using FusedMoEConfig::max_num_tokens=%d", + self.max_num_tokens) + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is not None: + return self.quant_config.quant_dtype + else: + return None + + @property + def block_shape(self) -> Optional[list[int]]: + if self.quant_config is not None: + return self.quant_config.block_shape + else: + return None + + @property + def per_act_token_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_act_token_quant + else: + return False + + @property + def per_out_ch_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_out_ch_quant + else: + return False + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def world_size(self): + return self.moe_parallel_config.world_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + + @staticmethod + def make( + num_experts: int, + experts_per_token: int, + hidden_dim: int, + num_local_experts: int, + moe_parallel_config: FusedMoEParallelConfig, + in_dtype: torch.dtype, + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config: Optional[Union[FusedMoEQuantConfig, + QuantizationConfig]] = None + ) -> "FusedMoEConfig": + + _quant_config: Optional[FusedMoEQuantConfig] = None + + if quant_config is not None and isinstance(quant_config, + QuantizationConfig): + if hasattr(quant_config, 'weight_block_size'): + block_shape = quant_config.weight_block_size + else: + block_shape = None + per_act_token_quant = False + per_out_ch_quant = False + quant_dtype: Optional[torch.dtype] = None + + input_quant = get_quant_config_input_quant(quant_config) + weight_quant = get_quant_config_weight_quant(quant_config) + + if input_quant is not None: + per_act_token_quant = (input_quant.strategy + == QuantizationStrategy.TOKEN + if input_quant is not None else False) + + if input_quant.num_bits == 8: + if input_quant.type == QuantizationType.FLOAT: + quant_dtype = torch.float8_e4m3fn + elif input_quant.type == QuantizationType.INT: + quant_dtype = torch.int8 + + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + if quant_dtype is None and isinstance(quant_config, Fp8Config): + quant_dtype = torch.float8_e4m3fn + + if weight_quant is not None: + per_out_ch_quant = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL) + + if quant_dtype is not None: + _quant_config = FusedMoEQuantConfig( + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + else: + _quant_config = 
FusedMoEQuantConfig() + logger.warning_once("MoE DP setup unable to determine " + "quantization scheme or unsupported " + "quantization type. This model will " + "not run with DP enabled.") + else: + _quant_config = quant_config + + return FusedMoEConfig( + num_experts=num_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=in_dtype, + quant_config=_quant_config, + max_num_tokens=max_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 73d169a84808..0ef4e4f767e3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache @@ -202,26 +203,47 @@ def run_cutlass_moe_fp8( # TODO (bnell): split class batched vs. non-batched? +# maybe remove need for passing aq to workspace_shapes class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_experts_per_worker: int, - out_dtype: torch.dtype, - per_act_token: bool, - per_out_ch: bool, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, use_batched_format: bool = False, ): - super().__init__() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + )) + assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype - self.per_act_token = per_act_token - self.per_out_ch = per_out_ch self.use_batched_format = use_batched_format + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.use_batched_format: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + else: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + def supports_chunking(self) -> bool: return not self.use_batched_format + def supports_expert_map(self) -> bool: + return not self.use_batched_format + def workspace_shapes( self, a: torch.Tensor, @@ -245,7 +267,8 @@ def workspace_shapes( workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M * topk, K) - return (workspace1, workspace2, output, self.out_dtype) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) def apply( self, @@ -270,13 +293,14 @@ def apply( assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" activation_callable = lambda i, o: self.activation(activation, i, o) - run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids, - activation_callable, global_num_experts, - expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, - expert_num_tokens, self.out_dtype, - self.per_act_token, self.per_out_ch, - self.use_batched_format) + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, 
hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + self.use_batched_format) def cutlass_moe_fp8( @@ -287,6 +311,7 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + per_act_token: bool, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, @@ -330,22 +355,18 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) per_out_ch = w1_scale.numel() != w1_q.size(0) - out_dtype = a.dtype + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( + 0) fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - quant_dtype=torch.float8_e4m3fn, - per_channel_quant=per_act_token, - ), + MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - max_experts_per_worker=global_num_experts, - out_dtype=out_dtype, - per_act_token=per_act_token, - per_out_ch=per_out_ch, + max_experts_per_worker=num_experts, + out_dtype=a.dtype, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, use_batched_format=False, ), ) @@ -358,7 +379,7 @@ def cutlass_moe_fp8( topk_ids, False, activation, - global_num_experts if global_num_experts != -1 else w1_q.size(0), + num_experts, expert_map, w1_scale, w2_scale, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 818f6d345ba6..8ad57c237fed 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,13 +7,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, per_token_group_quant_fp8) from vllm.utils import has_deep_gemm, round_up logger = init_logger(__name__) @@ -65,16 +65,31 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): - super().__init__() - self.block_shape = deep_gemm_block_shape() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=deep_gemm_block_shape(), + )) + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: return True + def supports_expert_map(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, 
...], torch.dtype]: + assert self.block_shape is not None # We use global_num_experts due to how moe_align_block_size handles # expert_maps. num_experts = global_num_experts @@ -107,6 +122,7 @@ def apply( expert_num_tokens: Optional[torch.Tensor], ): import deep_gemm as dg + assert self.block_shape is not None a1q = hidden_states _, N, K = w1.size() @@ -213,8 +229,7 @@ def deep_gemm_moe_fp8( - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn, - block_shape=deep_gemm_block_shape()), + MoEPrepareAndFinalizeNoEP(), DeepGemmExperts(), ) return fn( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 8c21d8aa53a6..d8ddec9554f0 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -15,22 +16,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ - def __init__(self, - buffer: deep_ep.Buffer, - world_size: int, - rank: int, - dp_size: int, - rank_expert_offset: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int, + dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer self.world_size = world_size self.rank = rank self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset - self.quant_dtype = quant_dtype - self.block_shape = block_shape # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
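
Note on the refactor running through the hunks above and below: quantization parameters (quant_dtype, block_shape, per-act-token and per-out-channel flags) are no longer constructor state on the prepare/finalize classes; they live in a FusedMoEQuantConfig owned by the experts implementation, and the modular kernel forwards that config to prepare() at call time. The following is a minimal, self-contained sketch of that ownership and call flow under simplified stand-in classes; the Toy* names and their methods are hypothetical illustrations, not vLLM APIs.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyQuantConfig:
    # Mirrors the fields carried by FusedMoEQuantConfig in this diff.
    quant_dtype: Optional[torch.dtype] = None
    per_act_token_quant: bool = False
    per_out_ch_quant: bool = False
    block_shape: Optional[list[int]] = None


class ToyExperts:
    """Owns the quant config, like FusedMoEPermuteExpertsUnpermute here."""

    def __init__(self, quant_config: Optional[ToyQuantConfig] = None):
        self.quant_config = quant_config or ToyQuantConfig()

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for the fused expert GEMMs.
        return x


class ToyPrepareFinalize:
    """Holds no quant state; receives the config per call, as in this PR."""

    def prepare(self, a1: torch.Tensor,
                quant_config: ToyQuantConfig) -> torch.Tensor:
        if quant_config.quant_dtype is not None:
            # Real code would quantize here (moe_kernel_quantize_input).
            a1 = a1.to(quant_config.quant_dtype)
        return a1


class ToyModularKernel:
    """Wires the two pieces together, like FusedMoEModularKernel.forward."""

    def __init__(self, prepare_finalize: ToyPrepareFinalize,
                 experts: ToyExperts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def forward(self, a1: torch.Tensor) -> torch.Tensor:
        # The kernel, not the prepare/finalize object, supplies the config.
        a1q = self.prepare_finalize.prepare(a1, self.experts.quant_config)
        return self.experts.apply(a1q)


if __name__ == "__main__":
    mk = ToyModularKernel(
        ToyPrepareFinalize(),
        ToyExperts(ToyQuantConfig(block_shape=[128, 128])))
    print(mk.forward(torch.randn(4, 8)).shape)

The same pattern explains why DeepEPHTPrepareAndFinalize and DeepEPLLPrepareAndFinalize drop their quant_dtype/block_shape constructor arguments in the hunks that follow: they now read those values from the quant_config argument passed to prepare().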
@@ -39,6 +32,10 @@ def __init__(self, # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + def max_num_tokens_per_rank(self) -> Optional[int]: return None @@ -55,13 +52,6 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_quant(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], per_act_token: bool): - tokens, token_scales = moe_kernel_quantize_input( - tokens, token_scales, self.quant_dtype, per_act_token, - self.block_shape) - return tokens, token_scales - def _do_dispatch(self, tokens: torch.Tensor, token_scales: Optional[torch.Tensor], rank_topk_ids: torch.Tensor, @@ -130,43 +120,51 @@ def prepare( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - rank_topk_weights: torch.Tensor, - rank_topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: if apply_router_weight_on_input: - topk = rank_topk_ids.size(1) + topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * rank_topk_weights.to(a1.dtype) + a1 = a1 * topk_weights.to(a1.dtype) # Check if there is a block_shape / or if we can infer the quantization # schemes from the scales. per_token_quant = None - if all([x is None for x in [self.block_shape, a1_scale, a2_scale] - ]) and self.quant_dtype is not None: + if all([ + x is None + for x in [quant_config.block_shape, a1_scale, a2_scale] + ]) and quant_config.quant_dtype is not None: # Quantization required despite none of the inputs suggesting # quantization. Fallback to per_token_dynamic quant. 
per_token_quant = True else: - per_token_quant = ((self.block_shape is not None) or - (a1_scale is not None and a1_scale.numel() != 1) - or (a2_scale is not None - and a2_scale.numel() != 1)) + per_token_quant = False if per_token_quant: - a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=True, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1q, token_scales=a1q_scale, - rank_topk_ids=rank_topk_ids, - rank_topk_weights=rank_topk_weights, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, num_experts=num_experts) else: # DeepEP kernels only support dispatching per-token-quant @@ -175,15 +173,18 @@ def prepare( expert_topk_weights) = self._do_dispatch( tokens=a1, token_scales=None, - rank_topk_ids=rank_topk_ids, - rank_topk_weights=rank_topk_weights, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, num_experts=num_experts) # quantize now expert_x_scale = None if expert_x.numel() != 0: - expert_x, expert_x_scale = self._do_quant(expert_x, - a1_scale, - per_act_token=False) + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, expert_topk_weights) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 5a8accd80463..b315b4a97f04 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -5,11 +5,13 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + maybe_fix_scales, moe_kernel_quantize_input) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 +DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] def dequant_fp8(expert_x_fp8: torch.Tensor, @@ -35,30 +37,30 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP low-latency kernels are compiled only for certain # specific hidden sizes. - SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168] + SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168] def __init__(self, buffer: deep_ep.Buffer, + max_tokens_per_rank: int, world_size: int, dp_size: int, - max_tokens_per_rank: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None, use_fp8_dispatch: bool = False): super().__init__() self.buffer = buffer + self.max_tokens_per_rank = max_tokens_per_rank self.world_size = world_size self.dp_size = dp_size - self.quant_dtype = quant_dtype - self.block_shape = block_shape - self.max_tokens_per_rank = max_tokens_per_rank self.use_fp8_dispatch = use_fp8_dispatch # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
self.handle = None + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_tokens_per_rank @@ -66,12 +68,17 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]: return torch.int64 def _do_quant( - self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - a1_dtype: torch.dtype + self, + x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - block_k = self.block_shape[1] if self.block_shape is not None else None + block_k = block_shape[1] if block_shape is not None else None if self.use_fp8_dispatch: if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. @@ -84,32 +91,20 @@ def _do_quant( assert isinstance(x, torch.Tensor) - # Check if there is a block_shape / or if we can infer the quantization - # schemes from the scales. - per_token_quant = None - if all([v is None for v in [self.block_shape, a1_scale, a2_scale] - ]) and self.quant_dtype is not None: - # Quantization required despite none of the inputs suggesting - # quantization. Fallback to per_token_dynamic quant. - per_token_quant = True - else: - per_token_quant = ((self.block_shape is not None) or - (a1_scale is not None and a1_scale.numel() != 1) - or (a2_scale is not None - and a2_scale.numel() != 1)) + assert not per_act_token_quant num_experts, max_tokens, hidden_dim = x.size() # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype, - per_token_quant, - self.block_shape) + x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, + per_act_token_quant, + block_shape) x = x.view((num_experts, -1, hidden_dim)) - if per_token_quant: + if quant_dtype is not None: assert x_scales is not None - x_scales = x_scales.view(num_experts, max_tokens, -1) + x_scales = maybe_fix_scales(x_scales, num_experts) return x, x_scales @@ -118,11 +113,12 @@ def prepare( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - rank_topk_weights: torch.Tensor, - rank_topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -142,24 +138,25 @@ def prepare( "low_latency kernels doesn't support dispatching per-token scales") if apply_router_weight_on_input: - topk = rank_topk_ids.size(1) + topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * rank_topk_weights.to(a1.dtype) + a1 = a1 * topk_weights.to(a1.dtype) # Dispatch expert_x, expert_num_tokens, self.handle, event, hook = \ self.buffer.low_latency_dispatch(a1, - rank_topk_ids, + topk_ids, self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, return_recv_hook=False) - expert_x, expert_x_scale = self._do_quant(expert_x, 
a1_scale, a2_scale, - a1.dtype) + expert_x, expert_x_scale = self._do_quant( + expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) return (expert_x, expert_x_scale, expert_num_tokens, None, None) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a12cfafd42ab..37a109857ac3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -8,6 +8,7 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( @@ -317,8 +318,8 @@ def invoke_moe_batched_triton_kernel( expert_num_tokens: torch.Tensor, # [E] compute_type: tl.dtype, # Quantization data - A_scale: torch.Tensor, - B_scale: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], B_zp: torch.Tensor, # Quantization schemes use_fp8_w8a8: bool, @@ -387,14 +388,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. """ - def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, - rank: int): + def __init__( + self, + max_num_tokens: int, + world_size: int, + dp_size: int, + rank: int, + ): super().__init__() self.world_size = world_size self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens @@ -411,6 +421,7 @@ def prepare( num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: assert a1.dim() == 2 @@ -435,22 +446,35 @@ def prepare( num_local_experts = num_experts // self.world_size + if quant_config.quant_dtype is None: + b_type = a1.dtype + else: + b_type = quant_config.quant_dtype + b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), - dtype=a1.dtype, + dtype=b_type, device=a1.device) + b_a1_scale = None + + assert quant_config.quant_dtype is None, "quantization NYI" + first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) - b_a1[expert_id - - first_expert, :rows, :] = a1[:topks.numel()][topks] - tokens_per_expert[expert_id - first_expert] = rows + if rows == 0: + continue + idx = expert_id - first_expert + b_a1[idx, :rows, :] = a1[:topks.numel()][topks] + tokens_per_expert[idx] = rows - return b_a1, a1_scale, tokens_per_expert, None, None + assert b_a1_scale is None or b_a1_scale.ndim == 3 + + return b_a1, b_a1_scale, tokens_per_expert, None, None def finalize( self, @@ -480,7 +504,7 @@ def finalize( output[topks] = output[topks] + rhs -class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): +class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): """ A reference MoE expert class that operates on 
expert batched format, i.e. E x max_num_tokens x K. This is the format that the pplx @@ -497,11 +521,17 @@ def __init__( use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, block_shape: Optional[list[int]] = None, - block_m: Optional[int] = None, + per_act_token_quant: bool = False, ): - super().__init__() - assert block_shape is None - assert block_m is None + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" @@ -510,9 +540,19 @@ def __init__( self.world_size = world_size self.dp_size = dp_size + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -554,20 +594,12 @@ def apply( assert hidden_states.dim() == 3 assert expert_num_tokens is not None - max_num_tokens = self.max_num_tokens - num_dp = self.world_size // self.dp_size num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") N = w1.size(1) // 2 - # Not cudagraph friendly - assert (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing() - or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), ( - f"{expert_num_tokens} <= {max_num_tokens * num_dp}") - for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor if (torch.compiler.is_compiling() @@ -575,6 +607,10 @@ def apply( num = hidden_states.shape[1] else: num = int(expert_num_tokens[expert].item()) + + if num == 0: + continue + tmp = _resize_cache(workspace2, (num, N)) input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) self.activation(activation, tmp, input) @@ -590,34 +626,53 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: Optional[int] = None, + max_num_tokens: int, + world_size: int, + dp_size: int, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - per_channel_quant: bool = False, + per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, - world_size: int = 1, - dp_size: int = 1, ): - super().__init__() + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int8_w8a8 = use_int8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 - self.block_shape = block_shape - self.per_channel_quant = per_channel_quant self.max_num_tokens = max_num_tokens self.world_size = world_size self.dp_size = dp_size - - assert not use_int8_w8a8, "NYI" - assert not use_int4_w4a16, "NYI" - assert self.block_shape is None, "NYI" + assert world_size > 0 + assert dp_size > 0 + assert dp_size <= world_size + assert max_num_tokens > 0 + + @property + 
def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) def supports_chunking(self) -> bool: return False + def supports_expert_map(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -630,10 +685,9 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - num_dp = self.world_size // self.dp_size + num_dp = self.world_size num_experts = local_num_experts - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens + max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) @@ -708,7 +762,6 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, @@ -734,6 +787,8 @@ def apply( config=config, block_shape=self.block_shape) + intermediate_cache2.fill_(0) + # TODO: would be nice to use expert_num_tokens here to reduce # garbage compute self.activation(activation, intermediate_cache2.view(-1, N // 2), @@ -745,8 +800,8 @@ def apply( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, - per_channel_quant=self.per_channel_quant, + quant_dtype=self.quant_dtype, + per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) qintermediate_cache2 = qintermediate_cache2.view( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f22884b8a1a5..75712b8e3a4d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -12,6 +12,10 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, get_config_quant_dtype) +# yapf: enable from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( @@ -980,20 +984,6 @@ def get_config_dtype_str( return None -# TODO (bnell): use scalar_type instead of bools? 
-def get_config_qtype( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, -) -> Optional[torch.dtype]: - if use_fp8_w8a8: - return torch.float8_e4m3fn - elif use_int8_w8a8: - return torch.int8 - return None - - def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1262,10 +1252,10 @@ def fused_experts_impl( use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) - qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) + qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1332,8 +1322,8 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( @@ -1373,8 +1363,8 @@ def fused_experts_impl( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, @@ -1521,30 +1511,41 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, - block_m: Optional[int] = None, ): - super().__init__() + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a16 = use_int8_w8a16 - self.block_shape = block_shape - self.block_m = block_m - self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) - self.per_channel_quant = per_channel_quant + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: return True + def supports_expert_map(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -1660,7 +1661,7 @@ def apply( use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, + per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape) self.activation(activation, intermediate_cache2, @@ -1669,8 +1670,8 @@ def apply( a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, - self.block_shape) + intermediate_cache2, 
a2_scale, self.quant_dtype, + self.per_act_token_quant, self.block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, @@ -1690,7 +1691,7 @@ def apply( use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, + per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape) @@ -1699,27 +1700,17 @@ def modular_triton_fused_moe( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, - per_channel_quant: bool, + per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> mk.FusedMoEModularKernel: - qtype = get_config_qtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - ) return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - quant_dtype=qtype, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - ), + MoEPrepareAndFinalizeNoEP(), TritonExperts( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, + per_act_token_quant=per_act_token_quant, block_shape=block_shape, ), ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 65a46ba5554b..6f9770262856 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,27 +3,30 @@ from abc import abstractmethod from collections.abc import Iterable -from dataclasses import dataclass from enum import Enum -from typing import Callable, Literal, Optional, Union, overload +from typing import Callable, Literal, Optional, overload import torch import torch.nn.functional as F -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_world_group, tensor_model_parallel_all_reduce) from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig) +# yapf: enable +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import ( @@ -36,14 +39,12 @@ if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts - from .modular_kernel import (FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) if has_pplx(): - from .pplx_prepare_finalize import PplxPrepareAndFinalize + from .pplx_prepare_finalize import (PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes) if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import 
(DEEPEP_QUANT_BLOCK_SIZE, + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) else: fused_experts = None # type: ignore @@ -60,207 +61,8 @@ from .moe_pallas import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore -logger = init_logger(__name__) - - -@dataclass -class FusedMoEParallelConfig: - tp_size: int - dp_size: int - ep_size: int - tp_rank: int - dp_rank: int - ep_rank: int - - use_ep: bool # whether to use EP or not - - @property - def use_all2all_kernels(self): - return self.dp_size > 1 and self.use_ep - - @property - def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") - - @property - def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") - - @property - def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - - @staticmethod - def make(tp_size_: int, dp_size_: int, - vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": - """ - Determine MoE parallel configuration. Based on the input tp_size_, - dp_size_, ep_size_ and vllm's parallel config, determine what - level's of parallelism to use in the fused moe layer. - - Args: - tp_size_ (int): tp_size passed into the FusedMoE constructor. - dp_size_ (int): dp_size passed into the FusedMoE constructor. - ep_size_ (int): ep_size passed into the FusedMoE constructor. - vllm_parallel_config (ParallelConfig): vllm's parallel config - object. - - Examples: - When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, - we simply return the sizes unaltered and the ranks set to 0. - - Expert Parallelism is considered only when either dp_size_ or tp_size_ - is non trivial. - - When TP = 2, DP = 1 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // - legend : {size, rank} - - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - - Comment : Tensors are sharded across 2 devices. - - When TP = 1, DP = 2 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} - - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 2 decvices. - - When TP = 2, DP = 2 and EP = False, the configuration on different - devices, - - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} - - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} - - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} - - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 4 devices. - - When, TP = 2, DP = 1 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - - Comment: The experts are split between the 2 devices. - - When, TP = 1, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} - - Comment: There are 2 engine instances and the experts are split - between the 2 devices. 
- - When TP = 2, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} - - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} - - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} - - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} - - Comment: There are 2 engine instances and the experts are split - between the 4 devices. - """ - - def flatten_tp_across_dp(dp_rank: int): - tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank - return tp_size, tp_rank - - use_ep = (dp_size_ * tp_size_ > 1 - and vllm_parallel_config.enable_expert_parallel) - - dp_size = dp_size_ - dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) - - if not use_ep: - return FusedMoEParallelConfig(tp_size=tp_size, - tp_rank=tp_rank, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=1, - ep_rank=0, - use_ep=False) - # DP + EP / TP + EP / DP + TP + EP - assert use_ep - # In EP, each device owns a set of experts fully. There is no tensor - # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. - ep_size = tp_size - ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size=1, - tp_rank=0, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - use_ep=True) - - -# Adapted from pplx-kernels tests/all_to_all_utils.py -@dataclass -class MoEConfig: - num_experts: int - experts_per_token: int - hidden_dim: int - - num_local_experts: int - moe_parallel_config: FusedMoEParallelConfig - - in_dtype: torch.dtype # The activation type. - quant_dtype: torch.dtype = None - - # TODO: add more quantization params, blocked, per-token, etc. 
- block_size: int = 128 - - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE - - def __post_init__(self): - if self.dp_size > 1: - logger.debug("Using MOEConfig::max_num_tokens=%d", - self.max_num_tokens) - - @property - def tp_size(self): - return self.moe_parallel_config.tp_size - - @property - def dp_size(self): - return self.moe_parallel_config.dp_size - - @property - def ep_size(self): - return self.moe_parallel_config.ep_size - - @property - def tp_rank(self): - return self.moe_parallel_config.tp_rank - - @property - def dp_rank(self): - return self.moe_parallel_config.dp_rank - - @property - def ep_rank(self): - return self.moe_parallel_config.ep_rank - - @property - def use_ep(self): - return self.moe_parallel_config.use_ep - @property - def use_pplx_kernels(self): - return self.moe_parallel_config.use_pplx_kernels - - @property - def use_deepep_ht_kernels(self): - return self.moe_parallel_config.use_deepep_ht_kernels - - @property - def use_deepep_ll_kernels(self): - return self.moe_parallel_config.use_deepep_ll_kernels +logger = init_logger(__name__) class FusedMoeWeightScaleSupported(Enum): @@ -270,21 +72,9 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -def get_quant_config_input_activations( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') - and "Linear" in quant_config.target_scheme_map and - "input_activations" in quant_config.target_scheme_map["Linear"]): - return quant_config.target_scheme_map["Linear"].get( - "input_activations") - else: - return None - - class FusedMoEMethodBase(QuantizeMethodBase): - moe: MoEConfig + moe: FusedMoEConfig @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -292,23 +82,25 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def init_prepare_finalize(self, moe: MoEConfig, + def init_prepare_finalize(self, moe: FusedMoEConfig, quant_config: Optional[QuantizationConfig]): all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None self.moe = moe - quant_dtype = None - act_quant_block_size = None - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if isinstance(quant_config, Fp8Config): - act_quant_block_size = quant_config.weight_block_size - quant_dtype = torch.float8_e4m3fn - - prepare_finalize: Optional[Union[PplxPrepareAndFinalize, - DeepEPHTPrepareAndFinalize, - DeepEPLLPrepareAndFinalize]] = None + + prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + if moe.use_pplx_kernels: + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( + moe.max_num_tokens, + moe.hidden_dim, + moe.in_dtype, + moe.quant_dtype, + per_act_token_quant=moe.per_act_token_quant, + block_shape=moe.block_shape, + ) + all_to_all_args = dict( max_num_tokens=moe.max_num_tokens, num_experts=moe.num_experts, @@ -318,14 +110,8 @@ def init_prepare_finalize(self, moe: MoEConfig, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 if moe.quant_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // moe.block_size * - 
torch.float32.itemsize)), + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=hidden_scale_bytes, ) # Intranode pplx a2a takes a group name while internode does not. @@ -335,9 +121,6 @@ def init_prepare_finalize(self, moe: MoEConfig, handle = all2all_manager.get_handle(all_to_all_args) - input_activations = get_quant_config_input_activations( - quant_config) - prepare_finalize = PplxPrepareAndFinalize( handle, max_num_tokens=moe.max_num_tokens, @@ -345,10 +128,6 @@ def init_prepare_finalize(self, moe: MoEConfig, rank=all2all_manager.rank, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, - quant_dtype=moe.quant_dtype, - per_act_token=(input_activations.strategy - == QuantizationStrategy.TOKEN - if input_activations is not None else False), ) elif moe.use_deepep_ht_kernels: assert moe.dp_size == all2all_manager.dp_world_size @@ -362,8 +141,6 @@ def init_prepare_finalize(self, moe: MoEConfig, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, - quant_dtype=quant_dtype, - block_shape=act_quant_block_size, ) elif moe.use_deepep_ll_kernels: @@ -380,25 +157,25 @@ def init_prepare_finalize(self, moe: MoEConfig, # Note : We may want to use FP8 dispatch even otherwise just to # reduce datamovement - assert act_quant_block_size is not None - use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() - and act_quant_block_size[1] - == DEEPEP_QUANT_BLOCK_SIZE) + use_fp8_dispatch = (moe.quant_config is not None + and moe.quant_config.quant_dtype + == current_platform.fp8_dtype() + and moe.quant_config.block_shape + == DEEPEP_QUANT_BLOCK_SHAPE) # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( handle, + max_tokens_per_rank=moe.max_num_tokens, world_size=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, - max_tokens_per_rank=moe.max_num_tokens, - quant_dtype=quant_dtype, - block_shape=act_quant_block_size, use_fp8_dispatch=use_fp8_dispatch, ) self.topk_indices_dtype = None if prepare_finalize is not None: + logger.debug("%s", prepare_finalize.__class__.__name__) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, moe) self.fused_experts = FusedMoEModularKernel( @@ -407,13 +184,15 @@ def init_prepare_finalize(self, moe: MoEConfig, ) def select_gemm_impl( - self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute: + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation raise NotImplementedError( - "Subclass must select appropriate gemm implementation" - " based on the prepare_finalize") + f"{self.__class__.__name__} must select appropriate gemm " + "implementation based on the prepare_finalize") @abstractmethod def apply( @@ -445,7 +224,7 @@ def apply( class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def __init__(self, moe: MoEConfig): + def __init__(self, moe: FusedMoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore self.topk_indices_dtype = None @@ -458,44 +237,30 @@ def __init__(self, moe: MoEConfig): else: self.rocm_aiter_fused_experts = None # type: ignore - def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize, - 
moe: Optional[MoEConfig]): + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: assert self.fused_experts == fused_experts all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - use_batched_experts = prepare_finalize.max_num_tokens_per_rank( - ) is not None - if use_batched_experts: + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) assert self.moe.dp_size == all2all_manager.dp_world_size - experts = BatchedTritonExperts( + return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, world_size=all2all_manager.world_size, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, ) else: logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - return experts + return TritonExperts() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -883,13 +648,18 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + tp_size_ = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + dp_size_ = (dp_size + if dp_size is not None else get_dp_group().world_size) + world_size_ = get_world_group().world_size + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size if dp_size is not None else - get_dp_group().world_size), + tp_size_=tp_size_, + dp_size_=dp_size_, + world_size_=world_size_, vllm_parallel_config=vllm_config.parallel_config)) self.global_num_experts = num_experts + num_redundant_experts @@ -948,25 +718,22 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - # Only support float8 for now. - quant_dtype = params_dtype - if quant_config is not None: - input_activations = get_quant_config_input_activations( - quant_config) - if (input_activations is not None - and input_activations.num_bits == 8 - and input_activations.type == QuantizationType.FLOAT): - quant_dtype = torch.float8_e4m3fn - - moe = MoEConfig( + if vllm_config.model_config is not None: + model_dtype = vllm_config.model_config.dtype + else: + # TODO (bnell): This is a hack to get test_mixtral_moe to work + # since model_config is not set in the pytest test. 
+ model_dtype = params_dtype + + moe = FusedMoEConfig.make( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, - in_dtype=params_dtype, - quant_dtype=quant_dtype, + in_dtype=model_dtype, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config=quant_config, ) self.moe_config = moe self.quant_config = quant_config @@ -1017,16 +784,15 @@ def __init__( self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels): - act_dtype = vllm_config.model_config.dtype self.batched_hidden_states = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size), - dtype=act_dtype, + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, device=torch.cuda.current_device()) # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts), - dtype=act_dtype, + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, device=torch.cuda.current_device()) @property @@ -1588,7 +1354,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): assert (self.batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (self.batched_router_logits.size(0) # type: ignore >= chunk_size) staged_hidden_states = self.batched_hidden_states[: chunk_size, :] # type: ignore diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index d25d70d3eff1..2ffb4d328eca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from enum import Enum from math import prod -from typing import Optional +from typing import Optional, final import torch import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.utils import cdiv @@ -82,6 +84,18 @@ def _moe_problem_size( return E, M, N, K, topk +class FusedMoEActivationFormat(Enum): + """ + The standard activation format (num_tokens, hidden dim). + """ + Standard = "standard", + """ + The batched experts format (num experts, max tokens per expert, hidden dim) + """ + BatchedExperts = "batched_experts", + + +# TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ An abstract base class for the [Quantize-Prepare] and [Finalize] steps @@ -99,6 +113,7 @@ def prepare( num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -148,6 +163,15 @@ def finalize( """ raise NotImplementedError + @property + @abstractmethod + def activation_format(self) -> FusedMoEActivationFormat: + """ + A property indicating the output format of the activations for the + 'prepare' method. + """ + raise NotImplementedError + @abstractmethod def topk_indices_dtype(self) -> Optional[torch.dtype]: """ @@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC): above. 
""" + def __init__( + self, + quant_config: Optional[FusedMoEQuantConfig], + ): + if quant_config is not None: + self.quant_config = quant_config + else: + self.quant_config = FusedMoEQuantConfig() + + @property + @abstractmethod + def activation_formats( + self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + """ + A property which is a tuple of the input and output activation formats + for the 'apply' method. + """ + raise NotImplementedError + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + return self.quant_config.quant_dtype + + @property + def block_shape(self) -> Optional[list[int]]: + return self.quant_config.block_shape + + @property + def per_act_token_quant(self) -> bool: + return self.quant_config.per_act_token_quant + + @property + def per_out_ch_quant(self) -> bool: + return self.quant_config.per_out_ch_quant + # TODO (bnell): make this return a CHUNK_SIZE or None instead? @abstractmethod def supports_chunking(self) -> bool: @@ -185,6 +244,13 @@ def supports_chunking(self) -> bool: """ raise NotImplementedError + @abstractmethod + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps + """ + raise NotImplementedError + @abstractmethod def workspace_shapes( self, @@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, return None +@final class FusedMoEModularKernel(torch.nn.Module): """ This class combines a FusedMoEPrepareAndFinalize instance and @@ -318,6 +385,12 @@ def __init__( super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + assert prepare_finalize.activation_format == \ + fused_experts.activation_formats[0], ( + f"{prepare_finalize.__class__.__name__}." + f"{prepare_finalize.activation_format} == " + f"{fused_experts.__class__.__name__}." + f"{fused_experts.activation_formats[0]}") def forward( self, @@ -383,8 +456,16 @@ def forward( (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( - a1, a1_scale, a2_scale, topk_weights, topk_ids, - global_num_experts, expert_map, apply_router_weight_on_input) + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 2ff8ef99b2ec..45e813287d3f 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -6,33 +6,76 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.utils import cdiv, round_up + + +def pplx_hidden_dim_scale_bytes( + max_num_tokens: int, + hidden_dim: int, + in_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], +): + # All pplx byte sizes must be 16-byte aligned. 
+    align = 16
+
+    # For blocked quantization, the scale size is
+    #   ceil_div(hidden_dim, block_size) * sizeof(float32).
+    # For per-token and per-tensor, it is sizeof(float32), padded to 16 bytes.
+    if quant_dtype is not None:
+        assert quant_dtype.itemsize == 1
+        hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
+        elem_size = torch.float32.itemsize
+
+        if per_act_token_quant:
+            # per-token
+            assert block_shape is None
+            hidden_scale_bytes = elem_size
+        elif block_shape is not None:
+            # per-group
+            block_size = block_shape[1]
+            num_blocks = cdiv(hidden_dim, block_size)
+            hidden_scale_bytes = num_blocks * elem_size
+        else:
+            # per-tensor
+            hidden_scale_bytes = elem_size
+    else:
+        hidden_dim_bytes = hidden_dim * in_dtype.itemsize
+        hidden_scale_bytes = 0
+
+    return (
+        round_up(hidden_dim_bytes, align),
+        round_up(hidden_scale_bytes, align),
+    )


 # The max_num_tokens, world_size and dp_size must be the same
 # as the ones used to create the AllToAll.
 class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

-    def __init__(self,
-                 a2a: pplx.AllToAll,
-                 max_num_tokens: int,
-                 world_size: int,
-                 rank: int,
-                 dp_size: int,
-                 quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None,
-                 per_act_token: bool = False):
+    def __init__(
+        self,
+        a2a: pplx.AllToAll,
+        max_num_tokens: int,
+        world_size: int,
+        rank: int,
+        dp_size: int,
+    ):
         super().__init__()
         assert max_num_tokens > 0
         self.a2a = a2a
-        self.block_shape = block_shape
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.rank = rank
         self.dp_size = dp_size
-        self.quant_dtype = quant_dtype
-        self.per_act_token = per_act_token
+
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
@@ -45,36 +88,43 @@ def prepare(
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
-        rank_topk_weights: torch.Tensor,
-        rank_topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K

-        assert rank_topk_ids.size(0) == num_tokens
+        assert topk_ids.size(0) == num_tokens
         # assert expert_map is None, "NYI"

         # Is this always going to be a1.device?
         device = a1.device

         if apply_router_weight_on_input:
-            topk = rank_topk_ids.size(1)
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, (
                 "apply_router_weight_on_input is only implemented for topk=1")
-            a1 = a1 * rank_topk_weights.to(a1.dtype)
+            a1 = a1 * topk_weights.to(a1.dtype)

         repeat_cols = 4
-        repeat_rows = 1 if self.per_act_token else a1.size(0)
+        repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
         a1q, a1q_scale = moe_kernel_quantize_input(
-            a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
-            self.per_act_token, self.block_shape)
+            a1, (None if quant_config.per_act_token_quant else a1_scale),
+            quant_dtype=quant_config.quant_dtype,
+            per_act_token_quant=quant_config.per_act_token_quant,
+            block_shape=quant_config.block_shape)

         if a1q_scale is not None:
+            if a1q_scale.numel() == 1:
+                orig_a_scale_block_shape = 1
+            else:
+                orig_a_scale_block_shape = a1q_scale.shape[-1]
             a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

         # rem_experts need to be 0 for pplx to work properly.
@@ -98,15 +148,12 @@ def prepare(

         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
-            float32_size = torch.float32.itemsize
-            block_size = (self.block_shape[0] if self.block_shape is not None
-                          else 1) * float32_size
+            block_size = (quant_config.block_shape[1]
+                          if quant_config.block_shape is not None else 1)
             expert_x_scale = torch.empty(
-                (
-                    num_local_experts,
-                    expert_x.size(1),
-                    (expert_x.size(2) + block_size - 1) // block_size,
-                ),
+                (num_local_experts, expert_x.size(1),
+                 round_up(
+                     (expert_x.size(2) + block_size - 1) // block_size, 4)),
                 dtype=torch.float32,
                 device=device,
             )
@@ -121,11 +168,11 @@ def prepare(
             out_expert_x_scale=expert_x_scale,
             dp_x=a1q,
             dp_x_scale=a1q_scale,
-            indices=rank_topk_ids,
+            indices=topk_ids,
             bound_m=bound_m,
         )

         if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, 0:1]
+            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]

         return expert_x, expert_x_scale, expert_num_tokens, None, None
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 9ed95e1de9fe..9e4be82f6c1f 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.utils import (
@@ -13,16 +14,9 @@

 class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):

-    def __init__(
-        self,
-        quant_dtype: Optional[torch.dtype] = None,
-        per_channel_quant: bool = False,
-        block_shape: Optional[list[int]] = None,
-    ):
-        super().__init__()
-        self.per_channel_quant = per_channel_quant
-        self.block_shape = block_shape
-        self.quant_dtype = quant_dtype
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return None
@@ -39,7 +33,8 @@ def prepare(
         topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool = False,
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -50,10 +45,9 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))

-        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-                                                   self.quant_dtype,
-                                                   self.per_channel_quant,
-                                                   self.block_shape)
+        a1q, a1q_scale = moe_kernel_quantize_input(
+            a1, a1_scale, quant_config.quant_dtype,
+            quant_config.per_act_token_quant, quant_config.block_shape)

         return a1q, a1q_scale, None, None, None
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 4bbfea446e29..e660376ebe6b 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
@@ -12,34 +13,59 @@

 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

-    def __init__(self,
-                 use_fp8_w8a8: bool = False,
-                 use_int8_w8a8: bool = False,
-                 use_int8_w8a16: bool = False,
-                 use_int4_w4a16: bool = False,
-                 per_channel_quant: bool = False,
-                 block_shape: Optional[list[int]] = None,
-                 block_m: Optional[int] = None,
-                 allow_deep_gemm: bool = False):
-        super().__init__()
-        self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
-                                           use_int8_w8a8=use_int8_w8a8,
-                                           use_int4_w4a16=use_int4_w4a16,
-                                           use_int8_w8a16=use_int8_w8a16,
-                                           per_channel_quant=per_channel_quant,
-                                           block_shape=block_shape,
-                                           block_m=block_m)
-        self.allow_deep_gemm = allow_deep_gemm
-        self.use_fp8_w8a8 = use_fp8_w8a8
+    def __init__(
+        self,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_act_token_quant: bool = False,
+        block_shape: Optional[list[int]] = None,
+        allow_deep_gemm: bool = False,
+    ):
+        super().__init__(
+            FusedMoEQuantConfig.make(
+                use_fp8_w8a8=use_fp8_w8a8,
+                use_int8_w8a8=use_int8_w8a8,
+                use_int8_w8a16=use_int8_w8a16,
+                use_int4_w4a16=use_int4_w4a16,
+                per_act_token_quant=per_act_token_quant,
+                block_shape=block_shape,
+            ))
+        self.triton_expert = TritonExperts(
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+        self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
+                                and use_fp8_w8a8)
         self.deep_gemm_expert = DeepGemmExperts(
         ) if self.allow_deep_gemm else None

+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        assert (self.deep_gemm_expert is None
+                or self.triton_expert.activation_formats
+                == self.deep_gemm_expert.activation_formats)
+        return self.triton_expert.activation_formats
+
     def supports_chunking(self) -> bool:
         dge = self.deep_gemm_expert
         te = self.triton_expert
         return ((dge is None or dge.supports_chunking())
                 and (te is None or te.supports_chunking()))

+    def supports_expert_map(self) -> bool:
+        dge = self.deep_gemm_expert
+        te = self.triton_expert
+        return ((dge is None or dge.supports_expert_map())
+                and (te is None or te.supports_expert_map()))
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -83,9 +109,7 @@ def apply(
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
     ):
-        N = w1.size(1)
-
-        use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
+        use_deep_gemm = (self.allow_deep_gemm
                          and _valid_deep_gemm(hidden_states, w1, w2))

         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 692482c2ea69..52346f797440 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -37,6 +37,7 @@ def _fp8_quantize(
         A, A_scale = ops.scaled_fp8_quant(
             A, A_scale, use_per_token_if_dynamic=per_act_token)
     else:
+        assert not per_act_token
         assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
         A, A_scale = per_token_group_quant_fp8(A, block_k)
@@ -64,6 +65,7 @@ def _int8_quantize(
             "int8 quantization only supports block or channel-wise"
         A, A_scale = per_token_quant_int8(A)
     else:
+        assert not per_act_token
         assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
         A, A_scale = per_token_group_quant_int8(A, block_k)
@@ -75,16 +77,15 @@ def _int8_quantize(
 def moe_kernel_quantize_input(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
-    qtype: Optional[torch.dtype],
-    per_channel_quant: bool,
+    quant_dtype: Optional[torch.dtype],
+    per_act_token_quant: bool,
     block_shape: Optional[list[int]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    if qtype == torch.float8_e4m3fn:
-        return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
-    elif qtype == torch.int8:
-        return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
+    if quant_dtype == torch.float8_e4m3fn:
+        return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
+    elif quant_dtype == torch.int8:
+        return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
     else:
-        assert A_scale is None
         return A, A_scale


@@ -96,3 +97,17 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
         return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
     else:
         return m[idx, ...]
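# An illustrative usage sketch for the renamed helper above (hypothetical
# shapes and dtypes; dynamic per-tensor fp8 quantization assumed):

import torch

from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

A = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
A_q, A_scale = moe_kernel_quantize_input(
    A,
    None,                             # A_scale: computed dynamically
    quant_dtype=torch.float8_e4m3fn,  # renamed from `qtype`
    per_act_token_quant=False,        # renamed from `per_channel_quant`
    block_shape=None)

# With quant_dtype=None the input is now returned unchanged together with the
# provided A_scale (the old `assert A_scale is None` has been removed).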
+
+
+# TODO(bnell): better name
+def maybe_fix_scales(scales: Optional[torch.Tensor],
+                     num_experts: int) -> Optional[torch.Tensor]:
+    if scales is not None and scales.ndim < 3:
+        if scales.numel() == 1:
+            scales = scales.view(1)
+            scales = torch.repeat_interleave(scales, num_experts,
+                                             dim=0).view(num_experts, 1, 1)
+        else:
+            scales = scales.view(num_experts, -1, scales.size(-1))
+
+    return scales
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 92b82f5a02ff..fa011266cf2f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,8 +13,10 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe import (
+    CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
+    FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -32,14 +34,6 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils import has_pplx
-
-if current_platform.is_cuda_alike():
-    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-        BatchedPrepareAndFinalize)
-    if has_pplx():
-        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
-            PplxPrepareAndFinalize)

 logger = init_logger(__name__)

@@ -569,15 +563,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                                    requires_grad=False)

             self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
-        else:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-            self.fused_experts_func = fused_experts
-
-        if self.use_marlin:
+        elif self.use_marlin:
             prepare_moe_fp8_layer_for_marlin(layer, False)
             # Activations not quantized for marlin.
             del layer.w13_input_scale
             del layer.w2_input_scale
+            self.fused_experts_func = None
+        else:
+            self.fused_experts_func = fused_experts

     def apply(
         self,
@@ -653,6 +646,8 @@ def apply(
             global_num_experts=global_num_experts,
             expert_map=expert_map)

+        assert self.fused_experts_func is not None
+
         return self.fused_experts_func(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -826,28 +821,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

-    def select_gemm_impl(self, prepare_finalize, moe):
-        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-            CutlassExpertsFp8)
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:

-        assert moe is not None
+        use_batched_format = (prepare_finalize.activation_format ==
+                              FusedMoEActivationFormat.BatchedExperts)
+
+        num_experts = (moe.num_local_experts
+                       if use_batched_format else moe.num_experts)

-        max_experts_per_worker = (
-            (moe.num_experts + prepare_finalize.world_size - 1) //
-            prepare_finalize.world_size)
         experts = CutlassExpertsFp8(
-            max_experts_per_worker,
+            num_experts,
             moe.in_dtype,
             self.input_quant.strategy == QuantizationStrategy.TOKEN,
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
-            use_batched_format=True,
+            use_batched_format=use_batched_format,
         )

-        if has_pplx() and isinstance(
-                prepare_finalize,
-                (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
-            # no expert_map support in this case
-            self.disable_expert_map = True
+        self.disable_expert_map = not experts.supports_expert_map()
         return experts

     def apply(
@@ -888,7 +882,8 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=torch.uint32)
+            indices_type=self.topk_indices_dtype,
+        )

         return self.fused_experts(
             x,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index ead345c794b8..0295f5e2a1c8 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import functools
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional

 import torch
 import torch.nn.functional as F
@@ -13,8 +13,11 @@
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe import (
+    BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
+    FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
+    TritonOrDeepGemmExperts)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -777,44 +780,46 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale

-    def select_gemm_impl(self, prepare_finalize, moe):
-
-        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-            BatchedTritonOrDeepGemmExperts)
-        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
-            TritonOrDeepGemmExperts)
-
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")

-        experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
-                                TritonOrDeepGemmExperts]] = None
-
-        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
-        use_batched_experts = max_num_tokens_per_rank is not None
-
-        if use_batched_experts:
-            experts = BatchedTritonOrDeepGemmExperts(
+        if (prepare_finalize.activation_format ==
+                FusedMoEActivationFormat.BatchedExperts):
+            max_num_tokens_per_rank = (
+                prepare_finalize.max_num_tokens_per_rank())
+            assert max_num_tokens_per_rank is not None
+            logger.debug(
+                "BatchedTritonOrDeepGemmExperts(%s): "
+                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                self.__class__.__name__, max_num_tokens_per_rank,
+                self.quant_config.weight_block_size, False)
+            return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
-                world_size=prepare_finalize.world_size,
-                dp_size=prepare_finalize.dp_size,
+                world_size=prepare_finalize.
+                world_size,  # type: ignore [attr-defined]
+                dp_size=prepare_finalize.
+                dp_size,  # type: ignore [attr-defined]
                 use_fp8_w8a8=True,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                per_channel_quant=False,
                 block_shape=self.quant_config.weight_block_size,
+                per_act_token_quant=False,
                 allow_deep_gemm=self.allow_deep_gemm,
             )
         else:
-            experts = TritonOrDeepGemmExperts(
+            logger.debug(
+                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
+                self.__class__.__name__, self.quant_config.weight_block_size,
+                False)
+            return TritonOrDeepGemmExperts(
                 use_fp8_w8a8=True,
                 block_shape=self.quant_config.weight_block_size,
                 allow_deep_gemm=self.allow_deep_gemm,
             )

-        assert experts is not None
-        return experts
-
     def apply(
         self,
         layer: torch.nn.Module,
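# An illustrative end-to-end sketch of how the refactored pieces above are
# expected to fit together (hypothetical, single-GPU, no expert parallelism;
# constructor arguments beyond those shown in this patch are assumptions):

from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

# Quantization settings now live on the experts implementation (via
# FusedMoEQuantConfig), so the prepare/finalize object no longer takes
# quant_dtype / per_act_token / block_shape arguments.
prepare_finalize = MoEPrepareAndFinalizeNoEP()
experts = TritonExperts(
    use_fp8_w8a8=True,
    per_act_token_quant=False,
    block_shape=[128, 128],
)

# FusedMoEModularKernel now asserts that the prepare/finalize output format
# matches the experts' expected input format (Standard vs. BatchedExperts)
# and forwards experts.quant_config into prepare().
kernel = FusedMoEModularKernel(prepare_finalize, experts)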