@@ -274,12 +274,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     if not _valid_deep_gemm_shape(M, N, K):
         pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
 
+    chunk_size = 1024
+
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
+
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
+
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
@@ -315,6 +319,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
 
+    use_compile = (chunk_size < M and N >= 1024 and K >= 1024
+                   and current_platform.is_cuda_alike())
+    use_cudagraph = use_compile
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
         if M >= 128:
@@ -327,7 +335,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
         topk_weights, topk_ids, token_expert_indices = fused_topk(
             a, score.float(), topk, False)
 
-        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
+        if use_compile:
+            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
+                                                 backend="inductor",
+                                                 fullgraph=True)
+        else:
+            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
+
+        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
+                                   topk_ids)
+
+        if use_cudagraph:
+            out.fill_(0)
+            stream = torch.cuda.Stream()
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph, stream=stream):
+                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s,
+                                           topk_weights, topk_ids)
+            torch.cuda.synchronize()
+            graph.replay()
+            torch.cuda.synchronize()
 
         #print(f"{out.sum()=}")
         #print(f"{ref_out.sum()=}")
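Second, the capture-then-replay block follows the standard CUDA graph recipe: run once eagerly to warm up kernels and the caching allocator, discard the stale output, record the call on a side stream (capture records kernels without executing them), then replay. A minimal sketch under the same CUDA-device assumption; fn is a hypothetical stand-in:

import torch

def fn(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return a @ w

a = torch.randn(64, 64, device="cuda")
w = torch.randn(64, 64, device="cuda")

out = fn(a, w)   # warm-up eager call primes kernels and the allocator
out.fill_(0)     # discard the eager result

stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    out = fn(a, w)  # capture: kernels are recorded, not executed

torch.cuda.synchronize()
graph.replay()      # actually executes the recorded kernels
torch.cuda.synchronize()
# `out` (the tensor allocated during capture) now holds the result.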