Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary:
Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in
torchao inference. In detail:
1. bring the existing 128x128 gemm Triton kernel out of prototype and
wrap it with a custom op for `torch.compile` compatibility (see the
sketch after this list)
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add an e2e test that verifies numerical correctness via SQNR
comparison against a high-precision baseline
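Item 1 above refers to registering the Triton gemm as a custom op so that `torch.compile` can trace through it without graph breaks. Below is a minimal sketch of that pattern using `torch.library.custom_op`; the op name, argument layout, and the dequantize-then-matmul placeholder body are assumptions made for illustration only, and the real op dispatches to the 128x128 Triton kernel instead.

```python
import torch
from torch.library import custom_op

BLOCK = 128

# Hypothetical op name and signature, for illustration only.
@custom_op("torchao_sketch::fp8_blockwise_gemm", mutates_args=())
def fp8_blockwise_gemm(
    a: torch.Tensor,        # (M, K) float8 activation, scaled per 1x128 block
    a_scale: torch.Tensor,  # (M, K // 128) float32 scales
    b: torch.Tensor,        # (N, K) float8 weight, scaled per 128x128 block
    b_scale: torch.Tensor,  # (N // 128, K // 128) float32 scales
) -> torch.Tensor:
    # Placeholder body: dequantize and matmul in high precision.
    # The real op calls the Triton 128x128 blockwise gemm kernel here.
    a_hp = a.to(torch.float32) * a_scale.repeat_interleave(BLOCK, dim=1)
    b_hp = b.to(torch.float32) * (
        b_scale.repeat_interleave(BLOCK, dim=0).repeat_interleave(BLOCK, dim=1)
    )
    return (a_hp @ b_hp.t()).to(torch.bfloat16)

# The fake (meta) impl lets torch.compile infer output shape/dtype without
# running the kernel.
@fp8_blockwise_gemm.register_fake
def _(a, a_scale, b, b_scale):
    return a.new_empty(a.shape[0], b.shape[0], dtype=torch.bfloat16)
```

Registering a fake implementation is what makes the op traceable under `torch.compile`; without it, compilation would fail when it reaches the opaque Triton call.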
For now I added a fallback which only requires Triton and is numerically
correct but may not reach optimal performance. Performance optimization is
left for future PRs:
1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in the user's
environment
3. we should map the weight quantization to a Triton kernel, as
`torch.compile` is currently known to be slow for 128x128 block
quantization (see the reference sketch after this list)
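For context on item 3, here is a pure-PyTorch reference of the blockwise quantization the recipe name encodes: 1x128 blocks for activations and 128x128 blocks for weights. This is a sketch under the assumptions of `float8_e4m3fn` and dimensions divisible by 128; the function names are illustrative, not torchao APIs, and a fused Triton kernel would replace this eager math.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def quantize_weight_128x128(w_hp: torch.Tensor):
    # Reference 128x128 blockwise float8 quantization of a (N, K) weight.
    n, k = w_hp.shape
    blocks = w_hp.reshape(n // 128, 128, k // 128, 128)
    # one scale per 128x128 tile -> (N // 128, K // 128)
    amax = blocks.abs().amax(dim=(1, 3)).clamp(min=1e-12)
    scale = amax / FP8_MAX
    w_fp8 = (blocks / scale[:, None, :, None]).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return w_fp8.reshape(n, k), scale.to(torch.float32)

def quantize_activation_1x128(x_hp: torch.Tensor):
    # Reference 1x128 blockwise float8 quantization of a (M, K) activation.
    m, k = x_hp.shape
    blocks = x_hp.reshape(m, k // 128, 128)
    # one scale per 1x128 block -> (M, K // 128)
    amax = blocks.abs().amax(dim=-1).clamp(min=1e-12)
    scale = amax / FP8_MAX
    x_fp8 = (blocks / scale[:, :, None]).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8.reshape(m, k), scale.to(torch.float32)
```

The outputs of these reference functions pair directly with the gemm sketch above; compiling this eager version is where the reported slowness comes from.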
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: db464e1
ghstack-comment-id: 3460951962
Pull-Request: #3257