|
7 | 7 | import triton.language as tl
|
8 | 8 | from typing import Optional
|
9 | 9 |
|
| 10 | +import vllm._custom_ops as ops |
| 11 | +from vllm.config import VllmConfig, set_current_vllm_config |
| 12 | +from vllm.model_executor.layers.activation import SiluAndMul |
10 | 13 | from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
11 |
| - invoke_moe_batched_triton_kernel) |
| 14 | + invoke_moe_batched_triton_kernel, |
| 15 | + BatchedExperts, |
| 16 | + BatchedPrepareAndFinalize, |
| 17 | + BatchedTritonExperts) |
| 18 | +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, |
| 19 | + get_default_config) |
| 20 | +from vllm.model_executor.layers.fused_moe.modular_kernel import ( |
| 21 | + FusedMoEModularKernel) |
| 22 | +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( |
| 23 | + per_token_group_quant_fp8, w8a8_block_fp8_matmul) |
| 24 | +from vllm.platforms import current_platform |
| 25 | +from vllm.utils import round_up |
| 26 | + |
| 27 | + |
| 28 | +NUM_EXPERTS = [8, 64] |
| 29 | +TOP_KS = [1, 2, 6] |
| 30 | + |
| 31 | +vllm_config = VllmConfig() |
| 32 | +vllm_config.scheduler_config.max_num_seqs = 128 |
| 33 | +vllm_config.scheduler_config.max_model_len = 8192 |
12 | 34 |
|
13 | 35 |
|
14 | 36 | @dataclass
|
@@ -141,14 +163,13 @@ def ref_impl(
|
141 | 163 | B[e].transpose(0, 1),
|
142 | 164 | A_scale,
|
143 | 165 | B_scale,
|
144 |
| - [1,1])#block_shape) |
| 166 | + block_shape) |
145 | 167 | else:
|
146 |
| - import vllm._custom_ops as ops |
147 | 168 | tmp = ops.cutlass_scaled_mm(A[e, :, :],
|
148 | 169 | B[e].transpose(0, 1),
|
149 | 170 | A_scale,
|
150 | 171 | B_scale,
|
151 |
| - C.dtype) |
| 172 | + torch.bfloat16) |
152 | 173 | C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
153 | 174 | else:
|
154 | 175 | C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
@@ -194,8 +215,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
194 | 215 | #print(f"tensors.B {tensors.B.shape}")
|
195 | 216 |
|
196 | 217 | if use_fp8_w8a8:
|
197 |
| - #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device) |
| 218 | + #A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device) |
198 | 219 | #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
| 220 | + #quant_block_shape = [N, K] |
199 | 221 | A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
200 | 222 | B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
201 | 223 | quant_block_shape = [1, 1]
|
@@ -251,3 +273,158 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
251 | 273 |
|
252 | 274 | torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
253 | 275 | torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
| 276 | + |
| 277 | + |
| 278 | +def batched_moe( |
| 279 | + a: torch.Tensor, |
| 280 | + w1: torch.Tensor, |
| 281 | + w2: torch.Tensor, |
| 282 | + topk_weight: torch.Tensor, |
| 283 | + topk_ids: torch.Tensor, |
| 284 | + w1_scale: Optional[torch.Tensor] = None, |
| 285 | + w2_scale: Optional[torch.Tensor] = None, |
| 286 | + use_fp8_w8a8: bool = False, |
| 287 | + block_shape: Optional[list[int]] = None, |
| 288 | +) -> torch.Tensor: |
| 289 | + max_num_tokens = round_up(a.shape[0], 64) # ? |
| 290 | + fused_experts = FusedMoEModularKernel( |
| 291 | + BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8, |
| 292 | + block_shape=block_shape), |
| 293 | + BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1, |
| 294 | + use_fp8_w8a8=use_fp8_w8a8, |
| 295 | + block_shape=block_shape)) |
| 296 | + |
| 297 | + return fused_experts(a, |
| 298 | + w1, |
| 299 | + w2, |
| 300 | + topk_weight, |
| 301 | + topk_ids, |
| 302 | + w1_scale=w1_scale, |
| 303 | + w2_scale=w2_scale) |
| 304 | + |
| 305 | + |
| 306 | +# Note: same as torch_moe but with fused_topk factored out. |
| 307 | +def torch_moe2( |
| 308 | + a: torch.Tensor, |
| 309 | + w1: torch.Tensor, |
| 310 | + w2: torch.Tensor, |
| 311 | + topk_weight: torch.Tensor, |
| 312 | + topk_ids: torch.Tensor, |
| 313 | + w1_scale: Optional[torch.Tensor] = None, |
| 314 | + w2_scale: Optional[torch.Tensor] = None, |
| 315 | + use_fp8_w8a8: bool = False, |
| 316 | + block_shape: Optional[list[int]] = None, |
| 317 | +) -> torch.Tensor: |
| 318 | + M, K = a.shape |
| 319 | + topk = topk_ids.shape[1] |
| 320 | + |
| 321 | + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) |
| 322 | + |
| 323 | + if use_fp8_w8a8: |
| 324 | + a, a_scale = per_token_group_quant_fp8(a, block_shape[1]) |
| 325 | + #print(f"a_scale {a_scale.shape}") |
| 326 | + else: |
| 327 | + a_scale = None |
| 328 | + |
| 329 | + out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) |
| 330 | + num_experts = w1.shape[0] |
| 331 | + for i in range(num_experts): |
| 332 | + mask = (topk_ids == i).view(-1) |
| 333 | + if mask.sum(): |
| 334 | + if not use_fp8_w8a8: |
| 335 | + tmp1 = a[mask] @ w1[i].transpose(0, 1) |
| 336 | + tmp2 = SiluAndMul()(tmp1) |
| 337 | + out[mask] = tmp2 @ w2[i].transpose(0, 1) |
| 338 | + else: |
| 339 | + #tmp1 = ops.cutlass_scaled_mm(a[mask], |
| 340 | + # w1[i].transpose(0, 1), |
| 341 | + # a_scale[mask], |
| 342 | + # w1_scale[i], |
| 343 | + # torch.bfloat16) |
| 344 | + tmp1 = native_w8a8_block_matmul(a[mask], |
| 345 | + w1[i], |
| 346 | + a_scale[mask], |
| 347 | + w1_scale[i], |
| 348 | + block_shape, |
| 349 | + torch.bfloat16) |
| 350 | + tmp2 = SiluAndMul()(tmp1) |
| 351 | + tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1]) |
| 352 | + |
| 353 | + # out[mask] = ops.cutlass_scaled_mm(tmp2, |
| 354 | + # w2[i].transpose(0, 1), |
| 355 | + # b_scale, |
| 356 | + # w2_scale[i], |
| 357 | + # torch.bfloat16) |
| 358 | + out[mask] = native_w8a8_block_matmul(tmp2, |
| 359 | + w2[i], |
| 360 | + b_scale, |
| 361 | + w2_scale[i], |
| 362 | + block_shape, |
| 363 | + torch.bfloat16) |
| 364 | + |
| 365 | + return (out.view(M, -1, w2.shape[1]) * |
| 366 | + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) |
| 367 | + |
| 368 | + |
| 369 | +@pytest.mark.parametrize("m", [1, 33, 64, 222]) |
| 370 | +@pytest.mark.parametrize("n", [128, 1024, 2048]) |
| 371 | +@pytest.mark.parametrize("k", [128, 512, 1024]) |
| 372 | +@pytest.mark.parametrize("e", NUM_EXPERTS) |
| 373 | +@pytest.mark.parametrize("topk", TOP_KS) |
| 374 | +@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16]) |
| 375 | +def test_fused_moe_batched_experts( |
| 376 | + m: int, |
| 377 | + n: int, |
| 378 | + k: int, |
| 379 | + e: int, |
| 380 | + topk: int, |
| 381 | + dtype: torch.dtype, |
| 382 | +): |
| 383 | + current_platform.seed_everything(7) |
| 384 | + block_shape = [128, 128] |
| 385 | + |
| 386 | + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 |
| 387 | + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10 |
| 388 | + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10 |
| 389 | + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) |
| 390 | + |
| 391 | + use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn |
| 392 | + |
| 393 | + if use_fp8_w8a8: |
| 394 | + block_n, block_k = block_shape[0], block_shape[1] |
| 395 | + n_tiles_w1 = (2 * n + block_n - 1) // block_n |
| 396 | + n_tiles_w2 = (k + block_n - 1) // block_n |
| 397 | + k_tiles_w1 = (k + block_k - 1) // block_k |
| 398 | + k_tiles_w2 = (n + block_k - 1) // block_k |
| 399 | + |
| 400 | + finfo = torch.finfo(dtype) |
| 401 | + fp8_min = finfo.min |
| 402 | + fp8_max = finfo.max |
| 403 | + |
| 404 | + w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype) |
| 405 | + w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype) |
| 406 | + |
| 407 | + factor_for_scale = 1e-2 |
| 408 | + w1_s = torch.rand( |
| 409 | + (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale |
| 410 | + w2_s = torch.rand( |
| 411 | + (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale |
| 412 | + else: |
| 413 | + w1_s = None |
| 414 | + w2_s = None |
| 415 | + |
| 416 | + with set_current_vllm_config(vllm_config): |
| 417 | + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) |
| 418 | + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape) |
| 419 | + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape) |
| 420 | + # batched_output = batched_moe(a, |
| 421 | + # w1.to(torch.bfloat16), |
| 422 | + # w2.to(torch.bfloat16), |
| 423 | + # topk_weight, topk_ids, |
| 424 | + # w1_s, w2_s, False, |
| 425 | + # block_shape) |
| 426 | + |
| 427 | + torch.testing.assert_close(baseline_output, |
| 428 | + batched_output, |
| 429 | + atol=2e-2, |
| 430 | + rtol=0) |
0 commit comments