@@ -1,5 +1,6 @@
 #include <torch/all.h>
 #include "cuda_utils.h"
+#include "cutlass_extensions/common.hpp"

 template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
 void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
@@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       }
     }
   } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
+    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    int32_t version_num = get_sm_version_num();
+    if (version_num >= 100) {
+      TORCH_CHECK(
+          a.size(0) == a_scales.size(0) &&
+              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
+          "a_scale_group_shape must be [1, 128].");
+      TORCH_CHECK(
+          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
+              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
+          "b_scale_group_shape must be [128, 128].");
+    } else {
+      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
+      // kernel, or introducing ceil_div to the load_init() of mainloop.
+      using GroupShape = std::array<int64_t, 2>;
+      auto make_group_shape = [](torch::Tensor const& x,
+                                 torch::Tensor const& s) -> GroupShape {
+        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+                cuda_utils::ceil_div(x.size(1), s.size(1))};
+      };
+
+      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+      // 1x128 per-token group scales for activations
+      // 128x128 blockwise scales for weights
+      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                   b_scale_group_shape == GroupShape{128, 128} &&
+                   a.dtype() == torch::kFloat8_e4m3fn &&
+                   b.dtype() == torch::kFloat8_e4m3fn),
+                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                  "a_scale_group_shape must be [1, 128]. Got: [",
+                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                  "]\n"
+                  "b_scale_group_shape must be [128, 128]. Got: [",
+                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    }

-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
     TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
     blockwise_func(c, a, b, a_scales, b_scales);
   }
|
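The dispatch now branches on `get_sm_version_num()`, pulled in through the new `cutlass_extensions/common.hpp` include: SM100 (Blackwell) and newer validate the scale shapes directly against `ceil_div`, while the SM90 path keeps the `GroupShape` round-trip until the cutlass sm90 blockwise-scaling kernel handles non-divisible shapes, per the TODO. A plausible sketch of such a helper, assuming it encodes the current device's compute capability as `major * 10 + minor` (so SM90 maps to 90 and SM100 to 100); the real implementation lives in `cutlass_extensions/common.hpp` and may differ:

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Sketch only (assumed behavior, not copied from cutlass_extensions/common.hpp):
// encode the current device's compute capability as major * 10 + minor,
// e.g. SM90 (Hopper) -> 90, SM100 (Blackwell) -> 100.
inline int32_t get_sm_version_num() {
  int device = 0, major = 0, minor = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;
}
```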
|
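Both branches lean on `cuda_utils::ceil_div` to count scale groups, which also covers a partial trailing group when the dimension is not a multiple of 128. A minimal sketch of the assumed semantics (not copied from `cuda_utils.h`):

```cpp
#include <cstdint>

// Assumed semantics of cuda_utils::ceil_div (sketch, not the repo's code):
// round the quotient up so a partial trailing group still gets a scale entry.
constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

static_assert(ceil_div(7168, 128) == 56);  // K divisible by 128: exact
static_assert(ceil_div(7169, 128) == 57);  // one extra group for the remainder
```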
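Concretely, for an fp8 GEMM with `a` of shape `[M, K]` and `b` of shape `[K, N]`, the [1, 128] per-token activation groups imply `a_scales` of shape `[M, ceil(K / 128)]`, and the [128, 128] weight blocks imply `b_scales` of shape `[ceil(K / 128), ceil(N / 128)]`, which is exactly what the SM100 branch checks. A small self-contained example with made-up dimensions (not taken from the PR):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Made-up example dimensions, chosen so K is divisible by 128 but N is not.
  int64_t M = 16, K = 7168, N = 2000;
  auto ceil_div = [](int64_t a, int64_t b) { return (a + b - 1) / b; };

  // [1, 128] per-token group scales for a: one scale per row and per
  // 128-wide group along K.
  std::cout << "a_scales: [" << M << ", " << ceil_div(K, 128) << "]\n";
  // -> a_scales: [16, 56]

  // [128, 128] blockwise scales for b: one scale per 128x128 tile.
  std::cout << "b_scales: [" << ceil_div(K, 128) << ", " << ceil_div(N, 128)
            << "]\n";
  // -> b_scales: [56, 16]
}
```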