
Commit 80ed942

cthi authored and facebook-github-bot committed
Fixing test_quantize_fp8_matmul for CUDA graph (pytorch#4425)
Summary:
Pull Request resolved: pytorch#4425

X-link: facebookresearch/FBGEMM#1492

With CUDA graph enabled, this test sporadically fails with errors like P1856907021. The confusing part is that the RNG errors are thrown outside of the CUDA graph (when we call `torch.randn`). I'm not sure whether the tests run in parallel under hypothesis/buck, but that could potentially be the cause.

The change is to always warm up before CUDA graph capture, even for the non-Triton paths. This is good practice anyway, since some initialization can occur beneath us in ATen.

After adding CUDA graph back, the test now runs reliably.

Reviewed By: jwfromm

Differential Revision: D77596554

fbshipit-source-id: 6a65ba530bbac5d1357ac24ca9638e28e33369c8
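For context, the pattern the diff below applies to every quantization mode is: wrap the op sequence in a small local function, run it once eagerly as a warm-up, then capture and replay it under a CUDA graph. A minimal, self-contained sketch of that pattern (the `f` here is a hypothetical stand-in for the per-mode body, not the FBGEMM ops used in the test; it requires a CUDA device):

import torch

def f(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for one per-mode body (quantize inputs, run the FP8 matmul,
    # optionally add bias, return the result).
    return x @ w.t()

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

# Warm-up outside the graph so lazy initialization (ATen state, kernel
# autotuning, workspace allocation) happens before capture.
f(x, w)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    z = f(x, w)

# Replay re-executes the captured kernels on the same input/output buffers.
g.replay()

The capture itself is unchanged by this commit; the behavioral difference is the eager warm-up call before `torch.cuda.graph(g)`.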
1 parent 8325430 commit 80ed942

File tree: 1 file changed (+65, -104 lines)

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 65 additions & 104 deletions
@@ -278,7 +278,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         ),
         QType=st.sampled_from([fp8_e4m3, fp8_e5m2]),
         Bias=st.sampled_from([True, False]),
-        CudaGraph=st.sampled_from([False]),
+        CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
         UseFastAccum=st.booleans(),
         InputMultiDim=st.booleans(),
@@ -337,78 +337,62 @@ def test_quantize_fp8_matmul(
         )

         if Mode == "tensorwise":
-            if CudaGraph:
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-                    zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
-                    if bias is not None:
-                        zq += bias
-                g.replay()
-            else:
+
+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
                 xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
                 wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
                 zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
                 if bias is not None:
                     zq += bias
-        elif Mode == "tensorwise_broadcast":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            x_scale = x_scale.item()
-            w_scale = w_scale.item()
+                return zq
+
             if CudaGraph:
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                        xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
-                    )
-                    if bias is not None:
-                        zq += bias
+                    zq = f(x, w, bias)
                 g.replay()
             else:
+                zq = f(x, w, bias)
+        elif Mode == "tensorwise_broadcast":
+
+            def f(
+                xq: torch.Tensor,
+                wq: torch.Tensor,
+                scale: float,
+                bias: Optional[torch.Tensor],
+            ) -> torch.Tensor:
                 zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                    xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                    xq, wq, scale, use_fast_accum=UseFastAccum
                 )
                 if bias is not None:
                     zq += bias
-        elif Mode == "rowwise":
+                return zq
+
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+
             if CudaGraph:
-                # Warm up triton functions before cuda graph.
-                xq, x_scale = quantize_fp8_row(x)
-                wq, w_scale = quantize_fp8_row(w)
-                if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(
-                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
-                    )
+                # Warm-up to avoid capture issues
+                f(xq, wq, x_scale * w_scale, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    if torch.version.cuda:
-                        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                            x, output_dtype=QType
-                        )
-                        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-                    else:
-                        xq, x_scale = quantize_fp8_row(x)
-                        wq, w_scale = quantize_fp8_row(w)
-                    if UseTriton and torch.version.cuda:
-                        zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
-                        if bias is not None:
-                            zq += bias
-                    else:
-                        zq = torch.ops.fbgemm.f8f8bf16_rowwise(
-                            xq,
-                            wq,
-                            x_scale,
-                            w_scale,
-                            bias=bias if torch.version.cuda else None,
-                            use_fast_accum=UseFastAccum,
-                        )
-                        # Bias fusion not yet supported on AMD.
-                        if bias is not None and torch.version.hip:
-                            zq += bias
+                    zq = f(xq, wq, x_scale * w_scale, bias)
                 g.replay()
             else:
+                zq = f(xq, wq, x_scale * w_scale, bias)
+        elif Mode == "rowwise":
+
+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
                 if torch.version.cuda:
                     xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                         x, output_dtype=QType
@@ -418,9 +402,7 @@ def test_quantize_fp8_matmul(
                     xq, x_scale = quantize_fp8_row(x)
                     wq, w_scale = quantize_fp8_row(w)
                 if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(
-                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
-                    )
+                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
                     if bias is not None:
                         zq += bias
                 else:
@@ -435,14 +417,27 @@ def test_quantize_fp8_matmul(
                     # Bias fusion not yet supported on AMD.
                     if bias is not None and torch.version.hip:
                         zq += bias
-        elif Mode == "blockwise":
-            block_m = block_n = block_k = 128
-            output_device = torch.device(self.device)
+
+                return zq
+
             if CudaGraph:
-                # Need a warmup to compile the Triton kernel before cuda graph
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    zq = f(x, w, bias)
+                g.replay()
+            else:
+                zq = f(x, w, bias)
+        elif Mode == "blockwise":

+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
+                block_m = block_n = block_k = 128
                 wq, w_scale = quantize_fp8_block(
-                    w, block_n, block_k, output_device=output_device
+                    w, block_n, block_k, output_device=torch.device(self.device)
                 )
                 xq, x_scale = quantize_fp8_block(x, block_m, block_k)
                 if UseTriton:
@@ -463,52 +458,18 @@ def test_quantize_fp8_matmul(
                 if bias is not None:
                     zq += bias

+                return zq
+
+            if CudaGraph:
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    wq, w_scale = quantize_fp8_block(
-                        w, block_n, block_k, output_device=output_device
-                    )
-                    xq, x_scale = quantize_fp8_block(x, block_m, block_k)
-                    if UseTriton:
-                        zq = matmul_fp8_block(
-                            xq,
-                            wq,
-                            x_scale,
-                            w_scale,
-                            block_m,
-                            block_n,
-                            block_k,
-                            fp8_fast_accum=UseFastAccum,
-                        )
-                    else:
-                        zq = torch.ops.fbgemm.f8f8bf16_blockwise(
-                            xq, wq, x_scale, w_scale, block_m, block_n, block_k
-                        )
-                    if bias is not None:
-                        zq += bias
+                    zq = f(x, w, bias)
                 g.replay()
             else:
-                wq, w_scale = quantize_fp8_block(
-                    w, block_n, block_k, output_device=output_device
-                )
-                xq, x_scale = quantize_fp8_block(x, block_m, block_k)
-                if UseTriton:
-                    zq = matmul_fp8_block(
-                        xq,
-                        wq,
-                        x_scale,
-                        w_scale,
-                        block_m,
-                        block_n,
-                        block_k,
-                        fp8_fast_accum=UseFastAccum,
-                    )
-                else:
-                    zq = torch.ops.fbgemm.f8f8bf16_blockwise(
-                        xq, wq, x_scale, w_scale, block_m, block_n, block_k
-                    )
-                if bias is not None:
-                    zq += bias
+                zq = f(x, w, bias)
         else:
             raise ValueError(f"Invalid mode {Mode}")

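The warm-up / capture / replay scaffolding is now repeated verbatim in each mode branch, so it could in principle be factored into one helper. A hedged sketch under that assumption (the `run_maybe_graphed` helper below is hypothetical and is not part of this PR):

from typing import Callable

import torch

def run_maybe_graphed(fn: Callable[[], torch.Tensor], use_cuda_graph: bool) -> torch.Tensor:
    # Eager path: just run the function.
    if not use_cuda_graph:
        return fn()
    # Warm-up to avoid capture issues, mirroring the change in this commit.
    fn()
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = fn()
    g.replay()
    return out

Each branch above would then reduce to something like `zq = run_maybe_graphed(lambda: f(x, w, bias), CudaGraph)`.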