[CPU] Add support for dynamic float8 act float8 weight on CPU #2505
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2505
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) as of commit 953ac13 with merge base 64c1ce3.
BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @chunyuan-w @mingfeima, could you please review this PR? Thanks.
Should we move the conversion vec code to this file: https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/aten/src/ATen/cpu/vec/vec512/vec512_float8.h? Similar to this PR: pytorch/pytorch#152417
Thanks for the comment. If we move it to PyTorch, one problem is that we would need to check at compile time whether the function is available. We can do this step by step; for now it might be better to keep it here.
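For illustration, here is a minimal sketch of the kind of compile-time guard this would require. The macro `TORCHAO_HAS_FP8_VEC` is hypothetical; `CPU_CAPABILITY_AVX512` is the capability macro ATen's vec code is compiled under. The actual names would depend on how the conversion code lands on the PyTorch side.

```cpp
// Hypothetical sketch: gate the FP8 conversion vec code on the CPU
// capability the translation unit is compiled for. TORCHAO_HAS_FP8_VEC is
// a made-up macro; CPU_CAPABILITY_AVX512 is defined when ATen vec code
// targets AVX-512.
#if defined(CPU_CAPABILITY_AVX512)
#define TORCHAO_HAS_FP8_VEC 1
#else
#define TORCHAO_HAS_FP8_VEC 0
#endif

#if TORCHAO_HAS_FP8_VEC
// ... use the vectorized FP8 <-> BF16 conversion path ...
#else
// ... fall back to scalar element-wise conversion ...
#endif
```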
```cpp
// scales shape = [Nc, G, block_n]
int64_t num_groups = weight_scales.size(1);
int64_t group_size = K / num_groups;
```
Do we support the case where `K % num_groups != 0`?
We don't support it. It is guarded by the quantization utility in Torchao, e.g. `torchao/quantization/quant_primitives.py`, line 293 at commit aee0795:

```python
assert input_size[i] % block_size[i] == 0, (
```

I have also added a `TORCH_CHECK` here. Thanks.
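For concreteness, a minimal sketch of such a guard, reusing the variable names from the snippet under review (the exact error message in the PR may differ):

```cpp
// Reject shapes where K does not divide evenly into scale groups,
// mirroring the Python-side assert quoted above.
int64_t num_groups = weight_scales.size(1);
TORCH_CHECK(
    K % num_groups == 0,
    "Expected K to be divisible by the number of scale groups, got K = ", K,
    " and num_groups = ", num_groups);
int64_t group_size = K / num_groups;
```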
Summary
This PR adds support for dynamic float8 activation, float8 weight quantization on X86 CPU. It adds:
- `Float8DynamicActFloat8WeightCPULayout`
- `float8_linear_prepack_cpu` and `float8_linear_cpu`

The kernel computes FP8 GEMM with BF16 as the compute dtype; a reference sketch of this scheme is shown below.
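To make the compute scheme concrete, here is a hedged scalar reference of the FP8-GEMM-in-BF16 idea: fp8 operands are upconverted to bf16, products are accumulated in fp32, and scales are folded into the result. All names are illustrative, the scale granularity is simplified to per-row activation and per-column weight scales, and the PR's actual kernel is a vectorized implementation with per-group weight scales along K.

```cpp
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cstdint>

// Illustrative scalar reference only; the real kernel is blocked/vectorized.
void fp8_gemm_ref(
    const c10::Float8_e4m3fn* A, const float* a_scale, // [M, K], per-row scales [M]
    const c10::Float8_e4m3fn* B, const float* b_scale, // [K, N], per-col scales [N]
    float* C, int64_t M, int64_t N, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        // FP8 -> BF16 upconversion; multiply in BF16, accumulate in FP32.
        c10::BFloat16 a(static_cast<float>(A[m * K + k]));
        c10::BFloat16 b(static_cast<float>(B[k * N + n]));
        acc += static_cast<float>(a) * static_cast<float>(b);
      }
      // Dynamic activation scale and weight scale are applied at the end.
      C[m * N + n] = acc * a_scale[m] * b_scale[n];
    }
  }
}
```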
Test plan