# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/sgl-project/sglang/pull/2575
import itertools

import pytest
import torch

from tests.kernels.quant_utils import (native_w8a8_block_matmul,
                                       native_per_token_group_quant_fp8,
                                       per_block_cast_to_fp8)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

dg_available = False
try:
    import deep_gemm
    dg_available = True
except ImportError:
    pass

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires GPU compute capability 9.0 or higher",
                allow_module_level=True)

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

# Test configurations
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 2050]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 512]
# DeepSeek-V3's intermediate size is 18432, so N is 18432 * 2 / 8 = 4608 at
# TP8, and its hidden size is 7168.
M = [1, 2, 83, 128, 2048, 1024 * 128]
M_dg = [128, 192, 1335, 2048]
N = [128, 256, 1024, 4608]  # [13824]
K = [256, 512, 7168]  # [13824]
BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16, 24]  # [128, 256]
TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]

def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Fused moe with block-wise quantization using native torch."""
    B, D = a.shape
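    # Replicate each token topk times so that every (token, expert) pair
    # occupies its own row before the per-expert loop below.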
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
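    # Quantize the activations per token in groups of block_k along the
    # hidden dimension; the scales feed the native block matmul reference.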
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_fp8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)

# Skip all tests if CUDA is not available.
if not torch.cuda.is_available():
    pytest.skip("Requires CUDA.", allow_module_level=True)


@pytest.fixture(autouse=True)
def setup_cuda():
    torch.set_default_device("cuda")


@pytest.mark.parametrize(
    "M,N,K,E,topk,block_size,dtype,seed",
    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_bf16 = (torch.rand(
        (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    del w1_bf16

    w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    del w2_bf16

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1_s = torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
    w2_s = torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale

    score = torch.randn((M, E), dtype=dtype)

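    # Build the modular Triton fused-MoE kernel with the same block-wise FP8
    # configuration so its output can be checked against fused_moe and the
    # native torch reference.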
    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_act_token_quant=False,
                                           block_shape=block_size)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_fp8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                           block_size)

        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        m_out = m_fused_moe(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            global_num_experts=E,
                            w1_scale=w1_s,
                            w2_scale=w2_s)

    # print(f"{out.sum()=}")
    # print(f"{ref_out.sum()=}")

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.03

    rel_diff = (torch.mean(
        torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.03

def fp8_perm(m, idx):
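    # FP8 tensors may not support fancy indexing on all torch versions, so
    # gather rows through a uint8 view and reinterpret back to the FP8 dtype.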
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]


def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
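    # Sort the (token, expert) rows by expert id and pad each expert's chunk
    # to a multiple of block_m, matching the contiguous grouped layout that
    # DeepGEMM expects.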
    M, K = a.shape

    sorted_token_ids, m_indices, num_pad = moe_align_block_size(
        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)

    num_tokens = topk * M

    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]

    a = fp8_perm(a, sorted_token_ids // topk)
    if a_s is not None:
        a_s = a_s[sorted_token_ids // topk]

    return a, a_s, m_indices, inv_perm


def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
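    # Restore the original token order, then apply the top-k routing weights
    # and reduce over the top-k dimension.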
    M = topk_weight.shape[0]
    out = out[inv_perm, ...]
    tmp_out = out.view(-1, topk, K)
    return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
                                 block_shape):
    """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
    num_groups = w1.shape[0]
    M, K = a.shape
    N = w2.shape[-1]

    topk_weight, topk_ids, token_expert_indices = fused_topk(
        a, score.float(), topk, False)

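    # Row alignment required by DeepGEMM's contiguous grouped GEMM layout;
    # it also serves as the activation quantization group size below.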
    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()

    _, block_k = block_shape[0], block_shape[1]

    a_q, a_s = per_token_group_quant_fp8(a, block_m)

    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
                                                 num_groups, topk, block_m)

    inter_out = torch.zeros((a_q.shape[0], N * 2),
                            dtype=torch.bfloat16,
                            device=a.device)

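    # First grouped GEMM: FP8 activations x FP8 w1 -> bf16 gate/up
    # projections, grouped per expert via m_indices.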
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
                                                        inter_out, m_indices)

    act_out = SiluAndMul().forward_native(inter_out)
    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)

    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)

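    # Second grouped GEMM: re-quantized activations x FP8 w2 -> bf16 down
    # projection.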
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)

    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)

    return final_out


@pytest.mark.parametrize(
    "M,N,K,E,topk,seed",
    itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    score = torch.randn((M, E), dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)

    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)

    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()

    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

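    # Quantize each expert's weights into 128x128 FP8 blocks and record the
    # per-block scales.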
    for i in range(E):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        if M >= 128:
            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
                                                   score, topk, block_size)
        else:
            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
                                               topk, block_size)

        topk_weights, topk_ids, token_expert_indices = fused_topk(
            a, score.float(), topk, False)

        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

    # print(f"{out.sum()=}")
    # print(f"{ref_out.sum()=}")

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))

    assert rel_diff < 0.03