
make mxfp8 dim1 cast kernel configurable #1427

Merged

merged 1 commit into main from danielvegamyhre/stack/1 on Jul 25, 2025

Conversation

@danielvegamyhre (Contributor) commented on Jul 21, 2025

Stacked PRs:


make mxfp8 dim1 cast kernel configurable

Summary

  • We recently added a new CUDA kernel for the mxfp8 dim1 cast that is ~1.4x faster than the existing Triton kernel or torch.compile; using it yields an e2e training speedup of +1.5-2.5% TPS on Llama3 8b with FSDP=4/8 (Add CUDA kernel for MXFP8 dim1 casting ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: integration of new mxfp8 casting cuda kernel ao#2564.
  • This PR updates the mxfp8 user-facing API, replacing the boolean flag --mx.use_fp8_dim1_cast_triton_kernel=[true|false] with --mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch].

Test plan

  • Triton: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"
  • Cuda: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"
  • Torch: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"

Limitations

  • TP is not supported yet, as both the Triton kernel and the CUDA kernel hit the same error: RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue we have been discussing with Brian and will continue to follow up on.

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 21, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 50e3ade to 372b083 on July 21, 2025 17:17
@danielvegamyhre mentioned this pull request on Jul 21, 2025
danielvegamyhre added a commit that referenced this pull request Jul 21, 2025
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 372b083 to 24c3d3a on July 21, 2025 17:19
@@ -556,7 +556,7 @@ class Float8:

 @dataclass
 class MX:
-    use_fp8_dim1_cast_triton_kernel: bool = True
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
Contributor

what's the benefit of letting the user choose?

Contributor Author

  1. It makes it easy for torchao developers to benchmark the torch.compile, Triton, and CUDA implementations against each other as we work on perf improvements, especially on improving torch.compile performance for casts (see the sketch below).
  2. Users may be using a Python-only torchao installation that doesn't include the CUDA kernel. This is probably not common, but still worth considering.
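
For context, a minimal sketch of how a string-valued kernel choice like this can be validated and mapped to an enum on the consuming side. The names below (MXFP8Dim1CastKernelChoice, resolve_dim1_cast_kernel) are illustrative assumptions for this discussion, not necessarily the actual torchao or torchtitan symbols.

from enum import Enum

# Illustrative stand-in for a kernel-choice enum on the torchao side (hypothetical name).
class MXFP8Dim1CastKernelChoice(Enum):
    TRITON = "triton"
    CUDA = "cuda"
    TORCH = "torch"

def resolve_dim1_cast_kernel(choice: str) -> MXFP8Dim1CastKernelChoice:
    """Map the torchtitan config string to the kernel-choice enum, failing loudly on typos."""
    try:
        return MXFP8Dim1CastKernelChoice(choice.lower())
    except ValueError:
        valid = [c.value for c in MXFP8Dim1CastKernelChoice]
        raise ValueError(
            f"Invalid mxfp8_dim1_cast_kernel_choice {choice!r}; expected one of {valid}"
        )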

@danielvegamyhre marked this pull request as draft on July 21, 2025 19:31
danielvegamyhre added a commit that referenced this pull request Jul 22, 2025
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 24c3d3a to b4b53cb on July 22, 2025 03:33
@danielvegamyhre marked this pull request as ready for review on July 22, 2025 03:33
@danielvegamyhre (Contributor Author)

@vkuzo @tianyu-l this is ready for review

CI error is unrelated to this change:

Exception: Integration test failed, flavor : 2D eager, command : TORCH_TRACE="artifacts-to-be-uploaded/2d_eager/compile_trace" CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml NGPU=4 LOG_RANK=0,1,2,3 ./run_train.sh --job.dump_folder artifacts-to-be-uploaded/2d_eager --parallelism.tensor_parallel_degree 2

@vkuzo (Contributor) commented on Jul 24, 2025

lgtm, I'll let someone from titan team accept

@tianyu-l (Contributor) left a comment

TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue we were talking to Brian about, will continue following up on it.

Do you think we can error out in mx.py, since you do have JobConfig on tp info?

Please rebase before merge.

danielvegamyhre added a commit that referenced this pull request on Jul 24, 2025
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from b4b53cb to a6466e7 on July 24, 2025 23:42
danielvegamyhre added a commit that referenced this pull request on Jul 24, 2025
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from a6466e7 to f79f833 on July 24, 2025 23:44
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from f79f833 to 4806fdb on July 24, 2025 23:49
@danielvegamyhre (Contributor Author)

Confirmed the error is a CUDA driver error in async TP, which is not related to this change. As an aside, it looks like an error at the C++ symmetric memory level, or perhaps the containerized env running the test had its dependencies updated?

@danielvegamyhre (Contributor Author)

Do you think we can error out in mx.py, since you do have JobConfig on tp info?

Sure, done
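
For reference, a minimal sketch of what such an early check could look like. The config path job_config.parallelism.tensor_parallel_degree comes from the CI command above; the function name and error message are assumptions for illustration and may not match the exact code that landed.

def _validate_mx_tp_support(job_config) -> None:
    # Fail fast: the mxfp8 dim1 cast kernels (Triton and CUDA) currently hit the
    # FunctionalTensor error under tensor parallelism, so reject TP > 1 up front.
    tp_degree = job_config.parallelism.tensor_parallel_degree
    if tp_degree > 1:
        raise ValueError(
            "mxfp8 is not supported with tensor parallelism yet "
            f"(tensor_parallel_degree={tp_degree}); please set it to 1."
        )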

@danielvegamyhre merged commit f3e2a75 into main on Jul 25, 2025
6 of 7 checks passed
Labels: CLA Signed (managed by the Meta Open Source bot)

3 participants