ROCm mx-fp8 Gemm #2066
```diff
@@ -32,11 +32,15 @@ class MXGemmKernelChoice(Enum):
     # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
     CUBLAS = "cublas"
 
+    # available only on ROCm with HIPBLASLT support, requires gfx950 and ROCm 7.0
+    HIPBLASLT = "hipblaslt"
```
Review discussion on the `HIPBLASLT` addition:

> We should change this to
> cc @vkuzo

> Could you clarify your approach here? I believe

> Ohh, I just mean that if I want to quantize a model to mxfp8 I need to know whether I am running on ROCm or CUDA, and the only place where one needs to know this is here. But in reality the "CUBLAS" enum really means "call into scaled_mm", and that would handle all the dispatch logic. It feels weird and an anti-pattern to core PyTorch to have device-specific (CUDA/ROCm) APIs when we don't need to.

> yes, IMO for what this PR is trying to do we should rename
```diff
 
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
+    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
```
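A minimal usage sketch of the new kernel choice, assuming the enums and config live at `torchao.prototype.mx_formats.config` (the import path and the explicit device branch are illustrative, not part of this diff). The branch on `torch.version.hip` is exactly the caller-side dispatch the reviewers above would prefer to hide behind a single scaled_mm-backed enum value:

```python
import torch

# Assumed import path for the module this diff edits.
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig

# With this PR, a caller who wants a hardware-backed mxfp8 gemm picks the
# backend per build: hipBLASLt on ROCm, cuBLAS on CUDA.
kernel_choice = (
    MXGemmKernelChoice.HIPBLASLT
    if torch.version.hip is not None  # non-None only on ROCm builds
    else MXGemmKernelChoice.CUBLAS
)
config = MXLinearConfig(gemm_kernel_choice=kernel_choice)
print(config.gemm_kernel_choice.value)  # "hipblaslt" on ROCm, "cublas" on CUDA
```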
```diff
@@ -64,6 +68,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
+    elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
+        assert block_size == 32, (
+            f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
+        )
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert elem_dtype in valid_dtypes, (
+            f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
+        )
+        assert torch.version.hip is not None, "HIPBLASLT requires ROCm"
 
 
 @dataclass
```
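For illustration, the conditions the new HIPBLASLT branch enforces, restated as a standalone predicate; the helper name is hypothetical, and in the PR itself these checks are asserts raised from `_validate_gemm_kernel_choice`:

```python
import torch

def hipblaslt_mx_gemm_supported(block_size: int, elem_dtype: torch.dtype) -> bool:
    # Hypothetical helper mirroring the asserts added above: the hipBLASLt
    # mx gemm path needs a ROCm build, block_size == 32, and fp8 e4m3 elements.
    return (
        torch.version.hip is not None
        and block_size == 32
        and elem_dtype == torch.float8_e4m3fn
    )

# Prints False on a CUDA-only build, since torch.version.hip is None there.
print(hipblaslt_mx_gemm_supported(32, torch.float8_e4m3fn))
```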
```diff
@@ -124,6 +137,8 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
```
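And a short sketch of the new recipe path, again assuming the `torchao.prototype.mx_formats.config` location and that `from_recipe_name` is exposed on `MXLinearConfig` as the diff context suggests:

```python
from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# The new recipe name resolves to a config that selects the hipBLASLt kernel.
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
assert config.gemm_kernel_choice is MXGemmKernelChoice.HIPBLASLT
```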