Commit 58a5c18

fix merge

Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 203dece commit 58a5c18

File tree

1 file changed: +29 -2 lines changed


tests/kernels/moe/test_block_fp8.py

Lines changed: 29 additions & 2 deletions
@@ -274,12 +274,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     if not _valid_deep_gemm_shape(M, N, K):
         pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
 
+    chunk_size = 1024
+
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
+
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
+
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
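Dropping the chunk size from 8192 to a named chunk_size = 1024 means any case with M > 1024 now actually exercises the multi-chunk fused-MoE path. A minimal sketch of the arithmetic this enables, assuming VLLM_FUSED_MOE_CHUNK_SIZE caps the number of tokens processed per chunk (the helper name below is ours, not vLLM's):

import math

def num_moe_chunks(m: int, chunk_size: int) -> int:
    # A batch of m tokens is processed in ceil(m / chunk_size) pieces
    # when chunk_size caps the per-chunk token count.
    return math.ceil(m / chunk_size)

assert num_moe_chunks(4096, 1024) == 4  # large batch -> multi-chunk path
assert num_moe_chunks(512, 1024) == 1   # small batch stays single-chunk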
@@ -315,6 +319,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
 
+    use_compile = (chunk_size < M and N >= 1024 and K >= 1024
+                   and current_platform.is_cuda_alike())
+    use_cudagraph = use_compile
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
         if M >= 128:
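The new use_compile flag gates the heavier variants to cases where the batch actually spans multiple chunks (chunk_size < M), the GEMM dims are non-trivial, and the platform is CUDA-like; use_cudagraph simply follows it. The same predicate, pulled out as a standalone helper for illustration (the function name is ours; current_platform.is_cuda_alike() is the check the diff itself uses):

from vllm.platforms import current_platform

def should_exercise_compile(m: int, n: int, k: int,
                            chunk_size: int = 1024) -> bool:
    # Only compile and graph-capture when m spans several chunks and
    # n, k are at least 1024, on CUDA or a CUDA-like backend (e.g. ROCm).
    return (chunk_size < m and n >= 1024 and k >= 1024
            and current_platform.is_cuda_alike())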
@@ -327,7 +335,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
         topk_weights, topk_ids, token_expert_indices = fused_topk(
             a, score.float(), topk, False)
 
-        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
+        if use_compile:
+            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
+                                                 backend="inductor",
+                                                 fullgraph=True)
+        else:
+            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
+
+        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
+                                   topk_ids)
+
+        if use_cudagraph:
+            out.fill_(0)
+            stream = torch.cuda.Stream()
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph, stream=stream):
+                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s,
+                                           topk_weights, topk_ids)
+            torch.cuda.synchronize()
+            graph.replay()
+            torch.cuda.synchronize()
 
         #print(f"{out.sum()=}")
         #print(f"{ref_out.sum()=}")
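The added block follows the usual warmup-capture-replay recipe for CUDA graphs around a torch.compile'd callable: run once so compilation happens outside capture, capture a second call on a side stream, then replay. A self-contained sketch of that pattern with a stand-in for deep_gemm_moe_fp8 (requires a CUDA device; fn is illustrative):

import torch

def fn(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1  # stand-in for deep_gemm_moe_fp8

compiled = torch.compile(fn, backend="inductor", fullgraph=True)

x = torch.randn(1024, device="cuda")
out = compiled(x)  # warmup: triggers compilation outside graph capture

out.fill_(0)  # zero the warmup output, as the test does
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    out = compiled(x)  # rebinds out to the graph-owned output tensor
torch.cuda.synchronize()
graph.replay()  # refreshes out without relaunching Python
torch.cuda.synchronize()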
