Skip to content

Commit 654cbea

Browse files
committed
fix merge. split up int8/fp8 moe tests
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 4e66291 commit 654cbea

File tree

13 files changed

+578
-486
lines changed

13 files changed

+578
-486
lines changed

tests/kernels/moe/test_block_fp8.py

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
# Adapted from https://github.com/sgl-project/sglang/pull/2575
5+
import itertools
6+
7+
import pytest
8+
import torch
9+
10+
from tests.kernels.quant_utils import (native_w8a8_block_matmul,
11+
native_per_token_group_quant_fp8,
12+
per_block_cast_to_fp8)
13+
from vllm.config import VllmConfig, set_current_vllm_config
14+
from vllm.model_executor.layers.activation import SiluAndMul
15+
from vllm.model_executor.layers.fused_moe import fused_moe
16+
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
17+
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
18+
from vllm.model_executor.layers.fused_moe.fused_moe import (
19+
fused_topk, modular_triton_fused_moe)
20+
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
21+
moe_align_block_size)
22+
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
23+
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
24+
from vllm.platforms import current_platform
25+
26+
dg_available = False
27+
try:
28+
import deep_gemm
29+
dg_available = True
30+
except ImportError:
31+
pass
32+
33+
if current_platform.get_device_capability() < (9, 0):
34+
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
35+
allow_module_level=True)
36+
37+
vllm_config = VllmConfig()
38+
vllm_config.scheduler_config.max_num_seqs = 128
39+
vllm_config.scheduler_config.max_model_len = 8192
40+
41+
# Test configurations
42+
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
43+
NUM_TOKENS = [7, 2050]
44+
D = [512, 4096, 5120, 13824]
45+
GROUP_SIZE = [64, 128, 512]
46+
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
47+
# and its hidden size is 7168.
48+
M = [1, 2, 83, 128, 2048, 1024 * 128]
49+
M_dg = [128, 192, 1335, 2048]
50+
N = [128, 256, 1024, 4608] # [13824]
51+
K = [256, 512, 7168] # [13824]
52+
BLOCK_SIZE = [[128, 128]]
53+
E = [2, 8, 16, 24] # [128, 256]
54+
TOP_KS = [1, 2, 6]
55+
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
56+
SEEDS = [0]
57+
58+
59+
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
60+
"""Fused moe with block-wise quantization using native torch."""
61+
B, D = a.shape
62+
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
63+
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
64+
score = torch.softmax(score, dim=-1, dtype=torch.float32)
65+
topk_weight, topk_ids = torch.topk(score, topk)
66+
topk_weight = topk_weight.view(-1)
67+
topk_ids = topk_ids.view(-1)
68+
69+
_, block_k = block_shape[0], block_shape[1]
70+
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
71+
a_q = a_q.to(torch.float32)
72+
for i in range(w1.shape[0]):
73+
mask = topk_ids == i
74+
if mask.sum():
75+
inter_out = native_w8a8_block_matmul(a_q[mask],
76+
w1[i],
77+
a_s[mask],
78+
w1_s[i],
79+
block_shape,
80+
output_dtype=a.dtype)
81+
act_out = SiluAndMul().forward_native(inter_out)
82+
act_out_q, act_out_s = native_per_token_group_quant_fp8(
83+
act_out, block_k)
84+
out[mask] = native_w8a8_block_matmul(act_out_q,
85+
w2[i],
86+
act_out_s,
87+
w2_s[i],
88+
block_shape,
89+
output_dtype=a.dtype)
90+
return (out.view(B, -1, w2.shape[1]) *
91+
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
92+
93+
94+
# Skip all tests if CUDA is not available
95+
pytest.importorskip("torch.cuda")
96+
97+
98+
@pytest.fixture(autouse=True)
99+
def setup_cuda():
100+
torch.set_default_device("cuda")
101+
102+
103+
@pytest.mark.parametrize(
104+
"M,N,K,E,topk,block_size,dtype,seed",
105+
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
106+
@torch.inference_mode()
107+
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
108+
if topk > E:
109+
pytest.skip(f"Skipping test; topk={topk} > E={E}")
110+
111+
torch.manual_seed(seed)
112+
factor_for_scale = 1e-2
113+
fp8_info = torch.finfo(torch.float8_e4m3fn)
114+
fp8_max, fp8_min = fp8_info.max, fp8_info.min
115+
116+
a = torch.randn((M, K), dtype=dtype) / 10
117+
118+
w1_bf16 = (torch.rand(
119+
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
120+
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
121+
del w1_bf16
122+
123+
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
124+
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
125+
del w2_bf16
126+
127+
block_n, block_k = block_size[0], block_size[1]
128+
n_tiles_w1 = (2 * N + block_n - 1) // block_n
129+
n_tiles_w2 = (K + block_n - 1) // block_n
130+
k_tiles_w1 = (K + block_k - 1) // block_k
131+
k_tiles_w2 = (N + block_k - 1) // block_k
132+
133+
w1_s = torch.rand(
134+
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
135+
w2_s = torch.rand(
136+
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
137+
138+
score = torch.randn((M, E), dtype=dtype)
139+
140+
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
141+
use_int8_w8a8=False,
142+
use_int8_w8a16=False,
143+
use_int4_w4a16=False,
144+
per_act_token_quant=False,
145+
block_shape=block_size)
146+
147+
# Set the context to avoid lots of warning spam.
148+
with set_current_vllm_config(vllm_config):
149+
out = fused_moe(
150+
a,
151+
w1,
152+
w2,
153+
score,
154+
topk,
155+
renormalize=False,
156+
use_fp8_w8a8=True,
157+
w1_scale=w1_s,
158+
w2_scale=w2_s,
159+
block_shape=block_size,
160+
)
161+
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
162+
block_size)
163+
164+
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
165+
m_out = m_fused_moe(a,
166+
w1,
167+
w2,
168+
topk_weights,
169+
topk_ids,
170+
global_num_experts=E,
171+
w1_scale=w1_s,
172+
w2_scale=w2_s)
173+
174+
#print(f"{out.sum()=}")
175+
#print(f"{ref_out.sum()=}")
176+
177+
rel_diff = (torch.mean(
178+
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
179+
torch.mean(torch.abs(ref_out.to(torch.float32))))
180+
assert rel_diff < 0.03
181+
182+
rel_diff = (torch.mean(
183+
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
184+
torch.mean(torch.abs(ref_out.to(torch.float32))))
185+
assert rel_diff < 0.03
186+
187+
188+
def fp8_perm(m, idx):
189+
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
190+
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
191+
else:
192+
return m[idx, ...]
193+
194+
195+
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
196+
M, K = a.shape
197+
198+
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
199+
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
200+
201+
num_tokens = topk * M
202+
203+
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
204+
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
205+
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
206+
207+
a = fp8_perm(a, sorted_token_ids // topk)
208+
if a_s is not None:
209+
a_s = a_s[sorted_token_ids // topk]
210+
211+
return a, a_s, m_indices, inv_perm
212+
213+
214+
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
215+
M = topk_weight.shape[0]
216+
out = out[inv_perm, ...]
217+
tmp_out = out.view(-1, topk, K)
218+
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
219+
220+
221+
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
222+
block_shape):
223+
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
224+
num_groups = w1.shape[0]
225+
M, K = a.shape
226+
N = w2.shape[-1]
227+
228+
topk_weight, topk_ids, token_expert_indices = fused_topk(
229+
a, score.float(), topk, False)
230+
231+
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
232+
233+
_, block_k = block_shape[0], block_shape[1]
234+
235+
a_q, a_s = per_token_group_quant_fp8(a, block_m)
236+
237+
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
238+
num_groups, topk, block_m)
239+
240+
inter_out = torch.zeros((a_q.shape[0], N * 2),
241+
dtype=torch.bfloat16,
242+
device=a.device)
243+
244+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
245+
inter_out, m_indices)
246+
247+
act_out = SiluAndMul().forward_native(inter_out)
248+
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
249+
250+
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
251+
252+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
253+
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
254+
255+
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
256+
257+
return final_out
258+
259+
260+
@pytest.mark.parametrize(
261+
"M,N,K,E,topk,seed",
262+
itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
263+
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
264+
@torch.inference_mode()
265+
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
266+
267+
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
268+
block_size = [block_m, block_m]
269+
dtype = torch.bfloat16
270+
271+
if topk > E:
272+
pytest.skip(f"Skipping test: topk={topk} > E={E}")
273+
274+
if not _valid_deep_gemm_shape(M, N, K):
275+
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
276+
277+
torch.manual_seed(seed)
278+
fp8_info = torch.finfo(torch.float8_e4m3fn)
279+
fp8_max, fp8_min = fp8_info.max, fp8_info.min
280+
281+
a = torch.randn((M, K), dtype=dtype) / 10
282+
283+
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
284+
fp8_max).clamp(min=fp8_min, max=fp8_max)
285+
286+
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
287+
fp8_max).clamp(min=fp8_min, max=fp8_max)
288+
289+
score = torch.randn((M, E), dtype=dtype)
290+
291+
block_n, block_k = block_size[0], block_size[1]
292+
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
293+
k_tiles_w1 = (K + block_k - 1) // block_k
294+
n_tiles_w2 = (K + block_n - 1) // block_n
295+
k_tiles_w2 = (N + block_k - 1) // block_k
296+
297+
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
298+
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
299+
300+
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
301+
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
302+
303+
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
304+
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
305+
306+
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
307+
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
308+
309+
for i in range(E):
310+
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
311+
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
312+
313+
# Set the context to avoid lots of warning spam.
314+
with set_current_vllm_config(vllm_config):
315+
if M >= 128:
316+
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
317+
score, topk, block_size)
318+
else:
319+
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
320+
topk, block_size)
321+
322+
topk_weights, topk_ids, token_expert_indices = fused_topk(
323+
a, score.float(), topk, False)
324+
325+
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
326+
327+
#print(f"{out.sum()=}")
328+
#print(f"{ref_out.sum()=}")
329+
330+
rel_diff = (torch.mean(
331+
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
332+
torch.mean(torch.abs(ref_out.to(torch.float32))))
333+
334+
assert rel_diff < 0.03

0 commit comments

Comments
 (0)