add more generic kernel for fp8 blockwise scaling #2592


Open
wants to merge 1 commit into main from danielvegamyhre/stack/15
Conversation

@danielvegamyhre (Contributor) commented Jul 24, 2025

Stacked PRs:


add more generic kernel for fp8 blockwise scaling

  • Add generic FP8 blockwise quantization kernel from fbgemm_gpu
  • Add tests to verify numerics against the torch reference implementation (a minimal sketch of such a reference appears below).
  • Add a benchmarking script to bench the 3 blockwise quantization options (torch.compile, fbgemm_gpu kernels, and deepgemm kernels).
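
For reference, a minimal pure-PyTorch sketch of the kind of (block_m, block_k) blockwise fp8 quantization the numerics tests compare against is shown below. The function name, padding behavior, and returned scale layout are illustrative assumptions, not the exact kernel or test code in this PR.

```python
import torch
import torch.nn.functional as F

def ref_quantize_fp8_block(x: torch.Tensor, block_m: int = 128, block_k: int = 128):
    """Quantize a 2D tensor to fp8 (e4m3) with one scale per (block_m, block_k) block."""
    M, K = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Pad so M and K are multiples of the block sizes, then view as blocks.
    pad_m = (block_m - M % block_m) % block_m
    pad_k = (block_k - K % block_k) % block_k
    x_pad = F.pad(x, (0, pad_k, 0, pad_m))
    Mp, Kp = x_pad.shape
    blocks = x_pad.reshape(Mp // block_m, block_m, Kp // block_k, block_k)
    # One absolute max, and hence one scale, per block.
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax.to(torch.float32)
    x_q = (blocks * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    x_q = x_q.reshape(Mp, Kp)[:M, :K]
    # Reciprocal per-block scales for dequantization, shape (ceil(M/block_m), ceil(K/block_k)).
    return x_q, scale.reciprocal().squeeze(1).squeeze(-1)
```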

Performance

TL;DR: fbgemm is the best overall right now. deepgemm and torch.compile are both fast for (1, 128) scaling but slow for (128, 128) scaling.

A_shape         block_shape      torch_us    fbgemm_us    deepgemm_us
--------------  -------------  ----------  -----------  -------------
(1024, 1024)    (1,128)            12.096       18.144         17.408
(1024, 1024)    (128,128)          22.432       12.288         17.408
(2048, 2048)    (1,128)            51.328       41.984         40.096
(2048, 2048)    (128,128)          84.192       15.264         40.096
(4096, 4096)    (1,128)           118.784      134.336        132.128
(4096, 4096)    (128,128)         241.68        25.728        132.096
(8192, 8192)    (1,128)           389.344      509.04         498.72
(8192, 8192)    (128,128)         874.56        62.464        498.656
(16384, 16384)  (1,128)          1456.16      1998.66        1964.96
(16384, 16384)  (128,128)        3377.31       183.296       1965.02
(32768, 32768)  (1,128)          5732.42      7960.21        7830.56
(32768, 32768)  (128,128)       13692.4        669.664       7831.14
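
As a note on methodology (an assumption, since the benchmarking script itself is not shown in this thread), per-kernel microsecond timings like the ones above can be collected with torch.utils.benchmark, which takes care of CUDA synchronization:

```python
import torch
from torch.utils import benchmark

def bench_us(fn, *args) -> float:
    # Median runtime in microseconds over a ~1 second measurement window.
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return timer.blocked_autorange(min_run_time=1.0).median * 1e6

if torch.cuda.is_available():
    # Reuses ref_quantize_fp8_block from the sketch above (hypothetical name).
    A = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    print(f"torch reference: {bench_us(ref_quantize_fp8_block, A, 128, 128):.2f} us")
```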

pytorch-bot bot commented Jul 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2592

❌ 8 New Failures

As of commit fa64d54 with merge base 0e00df3:

danielvegamyhre added a commit that referenced this pull request Jul 24, 2025
stack-info: PR: #2592, branch: danielvegamyhre/stack/15
@danielvegamyhre force-pushed the danielvegamyhre/stack/15 branch from d0cd3be to 3b36022 on July 24, 2025 03:25
@facebook-github-bot added the CLA Signed label Jul 24, 2025
@danielvegamyhre added the topic: not user facing label Jul 24, 2025
@danielvegamyhre requested review from vkuzo and drisspg July 24, 2025 04:01
@danielvegamyhre (Contributor Author)

cc @vkuzo @drisspg for review

# Relative Frobenius-norm error between the reference gemm output C and the
# quantized blockwise-fp8 gemm output C_q.
error = torch.norm(C - C_q) / torch.norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.1, "Quantize gemm error is too high"
Contributor

Can you use SQNR everywhere to match the existing numerics testing?

Contributor Author

Updated to use SQNR
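
For context on the check, SQNR here is the usual signal-to-quantization-noise ratio in dB. A minimal sketch of how it can be computed is below; the helper name is illustrative, and the tests should use whatever SQNR utility torchao's existing numerics tests already rely on.

```python
import torch

def sqnr_db(ref: torch.Tensor, quant: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to the reference.
    signal_power = torch.norm(ref) ** 2
    noise_power = torch.norm(ref - quant) ** 2
    return 10 * torch.log10(signal_power / noise_power)

# Example usage (threshold is illustrative, not the value used in this PR):
# assert sqnr_db(C, C_q) > 25.0, "SQNR too low for blockwise fp8 gemm"
```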


# original implementation from fbgemm_gpu:
# https://github.com/pytorch/FBGEMM/blob/b19401e913fcdff536dc097fa3013a0a9d66256e/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L3091
def triton_quantize_fp8_block(
Contributor

@drisspg Jul 24, 2025

Since we have an optional runtime dependency on fbgemm, can we just call their kernel directly?

Contributor Author

@danielvegamyhre Jul 25, 2025

Yes, that is the desired end state. For now I have tried and had repeated problems getting it to work (fbgemm-gpu-genai), e.g. undefined symbols. I tried on both H100 and B200 and got different undefined symbol errors.
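
One possible shape for that end state, sketched under the assumption that the fbgemm import path matches the file linked above; the call signature is a guess, and the in-tree triton_quantize_fp8_block from this PR is the fallback:

```python
# Illustrative optional-dependency fallback, not a confirmed fbgemm-gpu-genai API.
try:
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        triton_quantize_fp8_block as fbgemm_quantize_fp8_block,
    )
    HAS_FBGEMM = True
except (ImportError, OSError):  # e.g. missing wheel or undefined-symbol errors
    HAS_FBGEMM = False

def quantize_fp8_block(x, block_m: int = 128, block_k: int = 128):
    # Prefer the fbgemm kernel when it imports cleanly; otherwise fall back to the
    # in-tree triton_quantize_fp8_block copied into torchao by this PR (assumed to
    # take the same arguments).
    if HAS_FBGEMM:
        return fbgemm_quantize_fp8_block(x, block_m, block_k)
    return triton_quantize_fp8_block(x, block_m, block_k)  # in-tree kernel from this PR
```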

danielvegamyhre added a commit that referenced this pull request Jul 24, 2025
stack-info: PR: #2592, branch: danielvegamyhre/stack/15
@danielvegamyhre force-pushed the danielvegamyhre/stack/15 branch from 3b36022 to 9821453 on July 24, 2025 04:13
@drisspg (Contributor) commented Jul 24, 2025

(32768, 32768)  (1,128)          5732.42      7960.21        7830.56
(32768, 32768)  (128,128)       13692.4        669.664       7831.14

This number is kinda weird to me; do you have memory bandwidth calcs? I don't immediately get why there is a 10x delta in groupwise vs blockwise.

stack-info: PR: #2592, branch: danielvegamyhre/stack/15
@danielvegamyhre force-pushed the danielvegamyhre/stack/15 branch from ee6ce03 to fa64d54 on July 25, 2025 03:11
@danielvegamyhre (Contributor Author)

this number is kinda weird to me, do you have memory bandwidth calcs? I dont immediately get why there is a 10x delta in group wise vs blockwise

Yeah, I agree it's odd. I'll try adding some memory bandwidth calcs, and I was also thinking about checking with Josh / the fbgemm team on whether perhaps there is a different kernel they use for activation quant.
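
For the memory bandwidth calcs, a back-of-the-envelope estimate can be read off the table already: bytes moved are roughly the bf16 input plus the fp8 output plus one fp32 scale per block, divided by the measured time. A hedged sketch with those dtype assumptions spelled out:

```python
# Rough achieved-bandwidth estimate for a quantization kernel, assuming a bf16
# input (2 B/elem), fp8 output (1 B/elem), and one fp32 scale per block. Compare
# the result against the device's peak HBM bandwidth.
def quantize_bw_gbps(M: int, K: int, block_m: int, block_k: int, time_us: float,
                     in_bytes: int = 2, out_bytes: int = 1, scale_bytes: int = 4) -> float:
    n_scales = -(-M // block_m) * -(-K // block_k)  # ceil-div for non-divisible shapes
    total_bytes = M * K * (in_bytes + out_bytes) + n_scales * scale_bytes
    return total_bytes / (time_us * 1e-6) / 1e9

# e.g. the (32768, 32768) / (128, 128) fbgemm row from the table above:
print(f"{quantize_bw_gbps(32768, 32768, 128, 128, 669.664):.1f} GB/s")
```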

Labels: CLA Signed, topic: not user facing
3 participants