Add CUDA kernel for MXFP8 dim1 casting #2513
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2513
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit 97c77a7 with merge base ddd4021. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
"dim0_mx", | ||
"dim1_mx", | ||
"dim1_mx_triton", | ||
"dim1_mx_cuda_floor", |
nit: remove "floor" to match all the others, or add "floor" to all the others
Added floor to others, I like the more explicit naming.
benchmarks/mx_formats/cast_bench.py (Outdated)
@@ -194,6 +208,42 @@ def run(
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    elif mode == "dim1_mx_cuda_floor":
        bench_fn = partial(
nit: either refactor the other branches to use partial, or remove partial here, to keep the file code style consistent
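One way to keep the style consistent (a rough sketch, not the PR's code; the `dim1_mx_triton` branch contents and the timing helper shown here are assumed for illustration) is to build `bench_fn` with `functools.partial` in every branch, so the timing loop always calls a zero-argument callable regardless of mode:

```python
from functools import partial

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1

# Sketch: every mode builds a zero-argument bench_fn, mirroring the new
# dim1_mx_cuda_floor branch, instead of some branches calling the kernel inline.
if mode == "dim1_mx_triton":
    bench_fn = partial(triton_to_mxfp8_dim1, x, BLOCK_SIZE)
elif mode == "dim1_mx_cuda_floor":
    # to_mxfp8_dim1_cuda is a hypothetical stand-in for the new CUDA kernel wrapper.
    bench_fn = partial(to_mxfp8_dim1_cuda, x, BLOCK_SIZE)

time_us = benchmark_fn_in_usec(bench_fn)  # hypothetical timing helper used by run()
```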
benchmarks/mx_formats/cast_bench.py (Outdated)
from torchao.prototype.mx_formats.kernels import (
    triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

try:
    import mxfp8_cuda
what would it take to just have this available in torchao instead of requiring a separate import?
Migrated to torchao/csrc/cuda/mx_kernels and updated setup.py to make it available as torchao.prototype.mxfp8_cuda. cc @drisspg as well.
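With that migration, downstream code can guard the import so environments without a Blackwell GPU (or without the extension built) degrade gracefully; a minimal sketch, assuming the module path from the reply above and an illustrative flag name:

```python
# Guarded import of the prototype extension: it only builds/runs on SM100+.
try:
    from torchao.prototype import mxfp8_cuda

    HAS_MXFP8_CUDA = True  # flag name is illustrative, not from the PR
except ImportError:
    mxfp8_cuda = None
    HAS_MXFP8_CUDA = False
```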
-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor,
+    block_size,
+    scaling_mode=None,
default to floor?
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

scale_mode = (
remove after default to floor above?
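Putting both suggestions together, the reference helper could default its scaling mode to floor so the None-handling mapping goes away; a sketch only (the function name and body are assumed, while ScaleCalculationMode and to_mx come from the diff above):

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx


def to_mxfp8_dim1_reference(
    x_hp: torch.Tensor,
    block_size,
    scaling_mode=ScaleCalculationMode.FLOOR,  # default to floor, per the review
):
    # With a real default, the scale_mode = (... if scaling_mode is None ...)
    # mapping introduced in the diff above is no longer needed; scaling_mode can
    # be passed straight through to to_mx. Body elided; the implementation lives
    # in the PR.
    ...
```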
""" | ||
setup.py - Build configuration for MXFP8 PyTorch extension | ||
|
||
This extension requires NVIDIA BlACKWELL architecture (SM100+) or newer. |
BLACKWELL instead of BlACKWELL
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved.
 *
 * See LICENSE for license information.
should there be a LICENSE file?
I updated license headers per conversation with Supriya, see update. There's no standalone license file though.
    .contiguous()
)

y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(
I think the reference should always be the native PyTorch code, and custom kernels should each match against the reference. This way if there is a mismatch, it's very clear which kernels have a mismatch vs native PyTorch code.
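A sketch of that test structure (illustrative only: `to_mxfp8_dim1_reference` stands in for the native PyTorch reference built on to_mx, and `run_dim1_kernel` is a hypothetical dispatcher over the Triton and CUDA kernels):

```python
import pytest
import torch


@pytest.mark.parametrize("kernel", ["triton", "cuda_floor"])
def test_mxfp8_dim1_matches_native_reference(kernel):
    x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
    # Single source of truth: the native PyTorch reference implementation.
    y_ref, s_ref = to_mxfp8_dim1_reference(x, block_size=32)
    # Each custom kernel is compared against the reference directly, so a failure
    # immediately identifies which kernel diverges from native PyTorch numerics.
    y, s = run_dim1_kernel(kernel, x, block_size=32)
    torch.testing.assert_close(s, s_ref, rtol=0, atol=0)
    torch.testing.assert_close(y, y_ref, rtol=0, atol=0)
```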
torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

# check quantized values
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
I think we should also test the memory layout of all the tensors vs reference
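For instance, alongside the assert_close calls above, the test could also compare layout metadata against the reference tensors (a sketch using the tensor names from the diff):

```python
# Matching values is not enough: also check that the kernel output has the same
# shape, strides, and contiguity as the reference, so downstream ops see the
# expected memory layout for the dim1 (column-major) cast.
assert y_d1.shape == y_d1_ref.shape and y_d1.stride() == y_d1_ref.stride()
assert y_d1.is_contiguous() == y_d1_ref.is_contiguous()
assert s_d1.shape == s_d1_ref.shape and s_d1.stride() == s_d1_ref.stride()
```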
Stacked PRs:
Add CUDA kernel for MXFP8 dim1 casting
Co-authored-by: Less Wright lessw@etrillium.com
Summary
Numerics changes
I made the following changes to get matching numerics for the columnwise/dim1 scaling path:
Performance improvements
I also made the following changes to improve perf for the columnwise/dim1 scaling path:
torchao integration, torch.compile support, and other changes
I also made the changes needed for torchao integration and torch.compile support, in order to run e2e training benchmarks in torchtitan. These changes will be migrated in subsequent PRs, to make this easier to review. Custom op integration for CUDA mxfp8 casting kernel: danielvegamyhre/private-torchao#17
Test build
cd ~/ao/torchao/prototype/mxfp8_cuda/
python setup.py install
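After installing, a quick smoke check can confirm the extension loads and that the GPU meets the architecture requirement (a sketch; it assumes the extension is importable as mxfp8_cuda, as in the benchmark diff above):

```python
import torch
import mxfp8_cuda  # noqa: F401  # raises ImportError if the extension did not build

# The kernels target NVIDIA Blackwell (SM100+); verify the current device qualifies.
major, _minor = torch.cuda.get_device_capability()
assert major >= 10, "MXFP8 CUDA extension requires SM100+ (Blackwell) or newer"
print(f"mxfp8_cuda loaded on {torch.cuda.get_device_name()}")
```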
Test numerics
pytest test/prototype/mx_formats/test_kernels.py -k cuda
Kernel microbenchmarks
CUDA mx dim1 with floor scaling:
CUDA mx dim1 with rceil scaling:
Triton mx dim1 (uses floor scaling):
E2E training benchmarks
Llama 3.1 8b on 4 B200s with FSDP=4, torch.compile, per op SAC
Note: I got a larger improvement of ~2.6% with FSDP=8 last night, but am going to rerun that benchmark a couple times later to confirm.
BF16: https://www.internalfb.com/phabricator/paste/view/P1864747221
FP8 tensorwise: https://www.internalfb.com/phabricator/paste/view/P1864745686
MXFP8 (triton): https://www.internalfb.com/phabricator/paste/view/P1864746140
MXFP8 (cuda): https://www.internalfb.com/phabricator/paste/view/P1864746496
Additional e2e training benchmarks on 2nd machine to confirm results
Additional e2e training benchmarks were run on @lessw2020's machine, confirming the speedup is reproducible.
The peak memory reduction of 4-5% when no AC is used is interesting to note.
Llama 3.1 8b, FSDP=4, torch.compile, no AC
Llama 3.1 8b, FSDP=8, torch.compile, no AC