Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 on Ampere #2328


Open: 8 commits to merge into base: main

solrex commented May 24, 2025

Inspired by #1932 and #2037, this PR implements blockwise scaling kernels for architectures older than SM90.

  • FP8 blockwise/groupwise scaling kernel for Ada (L20, L40S, RTX 4090); requires the accumulator type to be float
  • INT8 blockwise/groupwise scaling kernel for Ampere (A100/A800, A10, A30); requires the accumulator type to be int
  • Built on the CUTLASS 3.x API
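
To make the scaling scheme concrete, here is a minimal pure-Python reference sketch of what a blockwise/groupwise scaled GEMM computes. It is not the kernel from this PR. It assumes the common semantics in which each K-block is accumulated unscaled (int for INT8, float for FP8) and then multiplied by one scale factor from A and one from B before being added to a float result. The function name, argument names, and scale-tensor layout are all hypothetical.

```python
def blockwise_scaled_gemm(a, b, scale_a, scale_b, gran_m, gran_n, gran_k):
    """Slow reference GEMM with per-block scale factors (illustrative only).

    a is M x K and b is K x N, both holding already-quantized values.
    Assumed layout: scale_a[m // gran_m][k // gran_k] scales a block of A,
    and scale_b[n // gran_n][k // gran_k] scales a block of B.
    """
    m_size, k_size, n_size = len(a), len(b), len(b[0])
    out = [[0.0] * n_size for _ in range(m_size)]
    for m in range(m_size):
        for n in range(n_size):
            for k0 in range(0, k_size, gran_k):
                k1 = min(k0 + gran_k, k_size)
                # Accumulate one K-block in the unscaled domain.
                partial = sum(a[m][k] * b[k][n] for k in range(k0, k1))
                # Apply both scale factors once per K-block.
                scale = (scale_a[m // gran_m][k0 // gran_k]
                         * scale_b[n // gran_n][k0 // gran_k])
                out[m][n] += scale * partial
    return out


# Groupwise in M (granularity 1), blockwise in N and K (granularity 2).
a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 1], [2, 0]]
scale_a = [[0.5, 2.0], [1.0, 1.0]]   # one row of scales per row of A
scale_b = [[2.0, 0.5]]               # one row of scales for the single N block
result = blockwise_scaled_gemm(a, b, scale_a, scale_b, 1, 2, 2)
print(result)
```

With gran_m = 1 every row of A carries its own scale per K-block (groupwise); with gran_m equal to the tile extent a whole tile shares one scale (blockwise).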

* FP8 blockwise/groupwise kernel for Ada(L20,L40S,4090)
* INT8 blockwise/groupwise kernel for Ampere(A100/800)
@solrex changed the title from "Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 block/groupwise on Ampere" to "Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 on Ampere" on May 24, 2025
@solrex force-pushed the sm80-blockscale branch from 2b2a88b to 5c58e77 on May 26, 2025 18:03
hwu36 (Collaborator) commented May 28, 2025

@jackkosaian

solrex (Author) commented May 28, 2025

The following are the example benchmark results on L40S with CUDA 12.4 and CUTLASS main:

FP8:

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute
Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _64, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 64)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 2.79446e-06, MRE: 12.0697, greatest error: 0.0196838
  Disposition: Passed
  Avg runtime: 0.00905421 ms
  GFLOPS: 237181

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85b_ada_fp8_gemm_with_blockwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 2.61817e-06, MRE: 11.7382, greatest error: 0.0210075
  Disposition: Passed
  Avg runtime: 0.0233175 ms
  GFLOPS: 92097.5

INT8: 

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85c_ampere_int8_gemm_with_groupwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _64, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 64)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 0, MRE: 81.7363, greatest error: 0
  Disposition: Passed
  Avg runtime: 0.00911155 ms
  GFLOPS: 235688

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85d_ampere_int8_gemm_with_blockwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 0, MRE: 77.9124, greatest error: 0
  Disposition: Passed
  Avg runtime: 0.0239155 ms
  GFLOPS: 89794.6
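
For readers cross-checking the output above, the reported figures are consistent with two simple relations: ScaleMsPerTile (or ScaleNsPerTile) equals the tile extent divided by the scale granularity, and GFLOPS equals 2*M*N*K*L divided by the runtime. The sketch below reproduces them in Python. The helper names are hypothetical and the formulas are inferred from the printed numbers, not taken from the example source.

```python
def scales_per_tile(tile_extent, scale_granularity):
    """Number of scale factors needed along one tile dimension."""
    if tile_extent % scale_granularity != 0:
        raise ValueError("tile extent must be a multiple of the granularity")
    return tile_extent // scale_granularity


def gemm_gflops(m, n, k, batch, runtime_ms):
    """GEMM throughput, counting one multiply and one add per MAC."""
    flops = 2 * m * n * k * batch
    return flops / (runtime_ms * 1e-3) / 1e9


# Groupwise FP8 run: tile M = 64, ScaleGranularityM = 1.
print(scales_per_tile(64, 1))
# Blockwise FP8 run: tile M = 128, ScaleGranularityM = 128.
print(scales_per_tile(128, 128))
# Throughput for the 1024x1024x1024x1 problem at the reported runtimes.
print(gemm_gflops(1024, 1024, 1024, 1, 0.00905421))
print(gemm_gflops(1024, 1024, 1024, 1, 0.0233175))
```

The computed throughputs agree with the printed GFLOPS values to within rounding of the six-digit runtimes.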


3 participants