
Add CUDA kernel for MXFP8 dim1 casting #2513


Open · danielvegamyhre wants to merge 1 commit into main from danielvegamyhre/stack/3

Conversation

danielvegamyhre (Contributor) commented Jul 9, 2025

Stacked PRs:

  • Add CUDA kernel for MXFP8 dim1 casting (this PR)

Co-authored-by: Less Wright <lessw@etrillium.com>

Summary

  • Add a CUDA kernel for mxfp8 dim1 casting, which benchmarks show is ~1.4x faster than the existing Triton kernel; benchmarking Llama3 8b training with torchtitan shows this translating to a 1.5% - 2.5% e2e training speedup. The prototype was developed/explored in a personal repo, and this PR begins the migration into torchao.
  • We used this TE (Transformer Engine) kernel as a starting point (big thanks to @lessw2020 for setting up this C++ extension with the parts we needed to iterate on: working mxfp8 quantization cpp extension, danielvegamyhre/private-torchao#1).
  • Subsequent PRs will migrate the integration, torch.compile support, etc., to make reviewing easier. This PR only adds (1) the CUDA kernel, (2) the C++ extension to make it usable in Python, (3) numerical tests, and (4) kernel benchmarking.

Numerics changes

I made the following changes to get matching numerics for the columnwise/dim1 scaling path:

Performance improvements

I also made the following changes to improve perf for the columnwise/dim1 scaling path:

torchao integration, torch.compile support, and other changes

Test build

  • cd ~/ao/torchao/prototype/mxfp8_cuda/
  • python setup.py install
  • Note: this will be included in the main torchao setup.py after the subsequent integration PRs land (see above).
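As a quick post-install smoke test, a hedged sketch: the `mxfp8_cuda` import name matches what the benchmark script uses, and the exposed entry points are whatever the extension bindings define.

```python
import torch
import mxfp8_cuda  # installed by the setup.py above; requires a Blackwell (SM100+) GPU

print(torch.cuda.get_device_name(0))
print([name for name in dir(mxfp8_cuda) if not name.startswith("_")])  # exposed entry points
```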

Test numerics

  • pytest test/prototype/mx_formats/test_kernels.py -k cuda

Kernel microbenchmarks

  • The CUDA mx dim1 floor-scaling kernel is ~1.4x faster than the Triton kernel and achieves correspondingly higher peak memory bandwidth utilization.
  • RCEIL scaling is even faster (it uses hardware-native scaling instead of software scaling), but it is not a 1-to-1 comparison since the Triton kernel does floor scaling; see the sketch below for what floor scaling computes.
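For context, here is a minimal pure-PyTorch sketch of floor-based dim1 (columnwise) scaling. It is illustrative only: it assumes one power-of-two e8m0-style scale per 32 consecutive elements down a column, fp8 e4m3 output, and saturated casting, and it may differ from the actual kernel in edge-case handling and output layout.

```python
import torch

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def mxfp8_dim1_floor_sketch(x: torch.Tensor, block_size: int = 32):
    """Columnwise (dim1) MXFP8 cast with floor scaling, for illustration only."""
    M, K = x.shape
    assert M % block_size == 0
    # Group `block_size` consecutive elements down each column of x into one block.
    blocks = x.t().contiguous().reshape(K * M // block_size, block_size).float()
    amax = blocks.abs().amax(dim=1, keepdim=True)
    # Floor scaling: scale exponent = floor(log2(amax)) - emax(e4m3), with emax = 8.
    exp = torch.floor(torch.log2(amax.clamp(min=torch.finfo(torch.float32).tiny))) - 8
    scale = torch.exp2(exp)
    y = (blocks / scale).clamp(-F8E4M3_MAX, F8E4M3_MAX).to(torch.float8_e4m3fn)
    # Output layout here is illustrative; the actual kernel defines its own layout.
    return y.reshape(K, M), scale.reshape(K, M // block_size)
```

RCEIL differs only in how the per-block scale is chosen (hardware-native rounding on Blackwell rather than the software floor(log2(amax)) above), which is why it is faster but not bit-identical.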

CUDA mx dim1 with floor scaling:

(ao) [danvm@devgpu031.atn1 ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_cuda_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_cuda_floor
time_us 155.45600652694702
mem_bw_gbps 5234.245972084408

CUDA mx dim1 with rceil scaling:

(ao) [danvm@devgpu031.atn1 ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_cuda_rceil
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_cuda_rceil
time_us 145.31199634075165
mem_bw_gbps 5599.640748805854

Triton mx dim1 (uses floor scaling):

(ao) [danvm@devgpu031.atn1 ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_triton
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_triton
time_us 217.28000044822693
mem_bw_gbps 3744.914277988901
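For reference, the mem_bw_gbps figure is just total bytes moved divided by elapsed time. A hedged sketch of that arithmetic, assuming a 2-byte (bf16) input and 1-byte fp8 outputs plus one 1-byte scale per 32 elements; these assumptions reproduce the numbers above:

```python
def mem_bw_gbps(M: int, K: int, time_us: float, block_size: int = 32) -> float:
    bytes_r = M * K * 2                      # read bf16 input (2 bytes/element)
    bytes_w = M * K + (M * K // block_size)  # write fp8 data + e8m0 scales (1 byte each)
    return (bytes_r + bytes_w) / (time_us / 1e6) / 1e9

print(mem_bw_gbps(16384, 16384, 155.456))  # ~5234 GB/s, matching the CUDA floor run
print(mem_bw_gbps(16384, 16384, 217.280))  # ~3745 GB/s, matching the Triton run
```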

E2E training benchmarks

Llama3.1 8b on 4 B200s with FSDP=4, torch.compile, per op SAC

Speedup over bf16 baseline with FSDP=4:

fp8 tensorwise   1.185
mxfp8 triton     1.185
mxfp8 cuda       1.202

Note: I got a larger improvement of ~2.6% with FSDP=8 last night, but am going to rerun that benchmark a couple times later to confirm.

BF16: https://www.internalfb.com/phabricator/paste/view/P1864747221

(ao) [danvm@devgpu031.atn1 ~/torchtitan (main)]$ python parse.py --log-file=bf16-log.txt 

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 11401.0
Max Memory Usage: 50.22 GiB

FP8 tensorwise: https://www.internalfb.com/phabricator/paste/view/P1864745686

(ao) [danvm@devgpu031.atn1 ~/torchtitan (main)]$ python parse.py --log-file=fp8-tensorwise-log.txt 

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13506.0
Max Memory Usage: 50.14 GiB

MXFP8 (triton): https://www.internalfb.com/phabricator/paste/view/P1864746140

(ao) [danvm@devgpu031.atn1 ~/torchtitan (main)]$ python parse.py --log-file=triton-log.txt 

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13506.0
Max Memory Usage: 50.49 GiB

MXFP8 (cuda): https://www.internalfb.com/phabricator/paste/view/P1864746496

(ao) [danvm@devgpu031.atn1 ~/torchtitan (main)]$ python parse.py --log-file=cuda-log.txt 

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13708.0
Max Memory Usage: 50.26 GiB
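As a sanity check, the speedup table at the top of this section follows directly from the median tokens/second that parse.py reports:

```python
baseline_tps = 11401.0  # bf16
for name, tps in [
    ("fp8 tensorwise", 13506.0),
    ("mxfp8 triton", 13506.0),
    ("mxfp8 cuda", 13708.0),
]:
    print(f"{name}: {tps / baseline_tps:.3f}x")  # 1.185x, 1.185x, 1.202x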

Additional e2e training benchmarks on 2nd machine to confirm results

Additional e2e training benchmarks were run on @lessw2020's machine, confirming the speedup is reproducible.

Also notable: a 4-5% peak memory reduction when no AC (activation checkpointing) is used.

Llama 3.1 8b, FSDP=4, torch.compile, no AC

                          TPS      Peak memory (GiB)
cuda                      12910    108.47
triton                    12622    113.02
cuda % change vs triton   +2.28%   -4.03%

Llama 3.1 8b, FSDP=8, torch.compile, no AC

                          TPS      Peak memory (GiB)
cuda                      12992    129.45
triton                    12761    136.61
cuda % change vs triton   +1.81%   -5.24%


pytorch-bot bot commented Jul 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2513

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 97c77a7 with merge base ddd4021:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label Jul 9, 2025
danielvegamyhre marked this pull request as draft July 9, 2025 23:08
danielvegamyhre added the topic: not user facing label Jul 9, 2025
danielvegamyhre marked this pull request as ready for review July 9, 2025 23:55

danielvegamyhre (Contributor, Author) commented:

cc @vkuzo @drisspg for review

"dim0_mx",
"dim1_mx",
"dim1_mx_triton",
"dim1_mx_cuda_floor",
Contributor:

nit: remove "floor" to match all the others, or add "floor" to all the others

Contributor Author:

Added floor to others, I like the more explicit naming.

@@ -194,6 +208,42 @@ def run(
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_cuda_floor":
bench_fn = partial(
Contributor:

nit: either refactor the other branches to use partial, or remove partial here, to keep the file code style consistent


from torchao.prototype.mx_formats.kernels import (
    triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

try:
    import mxfp8_cuda
Contributor:

what would it take to just have this available in torchao instead of requiring a separate import?

Contributor Author:

Migrated to torchao/csrc/cuda/mx_kernels and updated setup.py to make it available as torchao.prototype.mxfp8_cuda. cc @drisspg as well

-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor,
+    block_size,
+    scaling_mode=None,
Contributor:

default to floor?

-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

scale_mode = (
Contributor:

remove after default to floor above?
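For concreteness, a hedged sketch of the suggested default (the helper name is illustrative, following the dim1 reference function touched in this diff):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode

def to_mxfp8_dim1_reference(
    x_hp: torch.Tensor,
    block_size,
    scaling_mode=ScaleCalculationMode.FLOOR,  # floor by default; the None-handling branch goes away
):
    ...
```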

"""
setup.py - Build configuration for MXFP8 PyTorch extension

This extension requires NVIDIA BlACKWELL architecture (SM100+) or newer.
Contributor:

BLACKWELL instead of BlACKWELL

* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights
*reserved.
*
* See LICENSE for license information.
Contributor:

should there be a LICENSE file?

Contributor Author:

I updated license headers per conversation with Supriya, see update. There's no standalone license file though.

.contiguous()
)

y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(
Contributor:

I think the reference should always be the native PyTorch code, and custom kernels should each match against the reference. This way, if there is a mismatch, it's very clear which kernels have a mismatch vs the native PyTorch code.

torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

# check quantized values
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
Contributor:

I think we should also test the memory layout of all the tensors vs reference
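A hedged sketch of such a check, building on the tensor names from the snippet above (the expected strides/contiguity are whatever the reference produces):

```python
# Compare memory layout, not just values, against the reference tensors.
assert y_d1.shape == y_d1_ref.shape and y_d1.stride() == y_d1_ref.stride()
assert s_d1.shape == s_d1_ref.shape and s_d1.stride() == s_d1_ref.stride()
assert y_d1.is_contiguous() == y_d1_ref.is_contiguous()
```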
