make mxfp8 dim1 cast kernel configurable #1427
Conversation
Force-pushed 50e3ade to 372b083.
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
Force-pushed 372b083 to 24c3d3a.
torchtitan/config_manager.py (outdated)
@@ -556,7 +556,7 @@ class Float8:
 @dataclass
 class MX:
-    use_fp8_dim1_cast_triton_kernel: bool = True
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
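Note that a `Literal` annotation documents the valid values but is not enforced at runtime on its own. A minimal sketch of how a field like this could be validated eagerly (the `__post_init__` check is illustrative, not part of the actual diff):

```python
from dataclasses import dataclass
from typing import Literal, get_args

KernelChoice = Literal["triton", "cuda", "torch"]

@dataclass
class MX:
    mxfp8_dim1_cast_kernel_choice: KernelChoice = "triton"

    def __post_init__(self):
        # Literal is only a type hint; reject bad CLI/TOML values up front
        # instead of failing deep inside the training loop.
        if self.mxfp8_dim1_cast_kernel_choice not in get_args(KernelChoice):
            raise ValueError(
                f"invalid mxfp8_dim1_cast_kernel_choice: "
                f"{self.mxfp8_dim1_cast_kernel_choice!r}"
            )
```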
What's the benefit of letting the user choose?
- It makes it easy for torchao developers to benchmark torch.compile vs triton vs cuda implementations as we work on perf improvements, especially on improving torch.compile performance for casts.
- Users may be using a Python-only torchao installation that doesn't include the CUDA kernel. This is probably not common, but still worth considering.
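The second point suggests resolving the chosen backend lazily and failing with a clear message when a dependency is missing. A hypothetical sketch (the extension module name `torchao_cuda_ext` and the function are assumptions for illustration, not torchao's real layout):

```python
import importlib.util

def resolve_dim1_cast_backend(choice: str) -> str:
    # Hypothetical resolution logic: "cuda" needs a compiled extension,
    # "triton" needs the triton package, "torch" always works.
    if choice == "cuda" and importlib.util.find_spec("torchao_cuda_ext") is None:
        raise RuntimeError(
            "mxfp8_dim1_cast_kernel_choice='cuda' requires a torchao build "
            "with CUDA kernels; use 'triton' or 'torch' instead"
        )
    if choice == "triton" and importlib.util.find_spec("triton") is None:
        raise RuntimeError(
            "mxfp8_dim1_cast_kernel_choice='triton' requires the triton "
            "package; use 'torch' instead"
        )
    return choice
```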
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
Force-pushed 24c3d3a to b4b53cb.
@vkuzo @tianyu-l this is ready for review. The CI error is unrelated to this change.

lgtm, I'll let someone from the titan team accept.
TP is not supported yet: both the Triton kernel and the CUDA kernel are affected by an issue, RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue we've been discussing with Brian; we will continue following up on it.

Do you think we can error out in mx.py, since you do have TP info in the JobConfig?
Please rebase before merge.
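The suggested guard could read the TP degree from the job config and fail fast. A hypothetical sketch (the function and parameter names are illustrative; the actual check in mx.py would pull these values from torchtitan's JobConfig):

```python
def validate_mx_job_config(tp_degree: int, kernel_choice: str) -> None:
    # TP currently trips the FunctionalTensor issue for both the triton
    # and cuda dim1 cast kernels, so reject those combinations up front.
    if tp_degree > 1 and kernel_choice in ("triton", "cuda"):
        raise NotImplementedError(
            f"mxfp8_dim1_cast_kernel_choice={kernel_choice!r} is not yet "
            "supported with tensor parallelism (TP); use 'torch' or "
            "disable TP (known FunctionalTensor issue)"
        )
```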
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
Force-pushed b4b53cb to a6466e7.
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
Force-pushed a6466e7 to f79f833.
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
Force-pushed f79f833 to 4806fdb.
Confirmed the error is a CUDA driver error in async TP, which is not related to this change. As an aside, it looks like an error at the C++ symm mem level; or perhaps the containerized env running the test had its dependencies updated?

Sure, done.
Stacked PRs:
make mxfp8 dim1 cast kernel configurable
Summary
Changes --mx.use_triton_for_dim1_cast=[true|false] to --mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch].
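Presumably the same choice can also be set in the training TOML rather than on the CLI (section and key names inferred from the --mx.* flags above, not verified against the config schema):

```toml
[mx]
recipe_name = "mxfp8"
mxfp8_dim1_cast_kernel_choice = "cuda"  # one of: triton, cuda, torch
```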
Test plan
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"
Limitations
TP is not supported yet: both kernels hit RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue we've been discussing with Brian; we will continue following up on it.