
Commit e23564c

use ceil_div in cutlass block scaling shape check (#17918)
1 parent 390ec88

File tree

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
vllm/model_executor/layers/quantization/utils/fp8_utils.py

3 files changed: 62 additions, 25 deletions

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 10 additions & 2 deletions
@@ -115,8 +115,16 @@ def bench_fp8(
     a_cont = a.contiguous()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
-    block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
-    block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
+
+    def ceil_div(x: int, y: int) -> int:
+        return (x + y - 1) // y
+
+    block_scale_a = torch.rand(
+        (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
+    )
+    block_scale_b = torch.rand(
+        ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+    )
     block_scale_a_M_major = block_scale_a.t().contiguous().t()
     block_scale_b_K_major = block_scale_b.t().contiguous().t()
     bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
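
A note on why the benchmark now uses ceil_div: when k or n is not a multiple of 128, floor division under-sizes the block-scale tensors, while ceil_div rounds up so the trailing partial block still gets a scale. A minimal standalone sketch (hypothetical shapes, not part of the commit):

    import torch

    def ceil_div(x: int, y: int) -> int:
        # Round-up integer division: ceil(x / y) using only ints.
        return (x + y - 1) // y

    m, k = 16, 7232  # hypothetical: k is not a multiple of 128

    k_blocks_floor = k // 128         # 56 -> misses the last 64 columns of k
    k_blocks_ceil = ceil_div(k, 128)  # 57 -> covers the ragged tail

    block_scale_a = torch.rand((m, k_blocks_ceil), dtype=torch.float32)
    assert block_scale_a.shape[1] * 128 >= k  # every column of a has a scale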

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp

Lines changed: 39 additions & 21 deletions
@@ -1,5 +1,6 @@
 #include <torch/all.h>
 #include "cuda_utils.h"
+#include "cutlass_extensions/common.hpp"
 
 template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
 void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
@@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       }
     }
   } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
+    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    int32_t version_num = get_sm_version_num();
+    if (version_num >= 100) {
+      TORCH_CHECK(
+          a.size(0) == a_scales.size(0) &&
+              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
+          "a_scale_group_shape must be [1, 128].");
+      TORCH_CHECK(
+          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
+              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
+          "b_scale_group_shape must be [128, 128].");
+    } else {
+      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
+      // kernel, or introducing ceil_div to the load_init() of mainloop.
+      using GroupShape = std::array<int64_t, 2>;
+      auto make_group_shape = [](torch::Tensor const& x,
+                                 torch::Tensor const& s) -> GroupShape {
+        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+                cuda_utils::ceil_div(x.size(1), s.size(1))};
+      };
+
+      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
 
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+      // 1x128 per-token group scales for activations
+      // 128x128 blockwise scales for weights
+      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                   b_scale_group_shape == GroupShape{128, 128} &&
+                   a.dtype() == torch::kFloat8_e4m3fn &&
+                   b.dtype() == torch::kFloat8_e4m3fn),
+                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                  "a_scale_group_shape must be [1, 128]. Got: [",
+                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                  "]\n"
+                  "b_scale_group_shape must be [128, 128]. Got: [",
+                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    }
 
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
     TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
     blockwise_func(c, a, b, a_scales, b_scales);
   }
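
Read as plain logic, the new SM >= 100 branch accepts a 1x128 activation group shape and a 128x128 weight block shape even when K or N has a ragged tail, because each scale-tensor dimension is compared against ceil_div of the matching GEMM dimension rather than requiring an exact multiple. A hedged Python rendering of that check (the function name and shape tuples are illustrative, not the C++ API):

    def ceil_div(x: int, y: int) -> int:
        return (x + y - 1) // y

    def sm100_block_scale_shapes_ok(a_shape, b_shape,
                                    a_scales_shape, b_scales_shape) -> bool:
        # a: one scale per row for every 128-wide group along a's second dim.
        ok_a = (a_scales_shape[0] == a_shape[0]
                and a_scales_shape[1] == ceil_div(a_shape[1], 128))
        # b: one scale per 128x128 block of the weight matrix.
        ok_b = (b_scales_shape[0] == ceil_div(b_shape[0], 128)
                and b_scales_shape[1] == ceil_div(b_shape[1], 128))
        return ok_a and ok_b

    # K = 7232 is not a multiple of 128, yet the shapes still validate.
    assert sm100_block_scale_shapes_ok((16, 7232), (7232, 4096), (16, 57), (57, 32))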

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 13 additions & 2 deletions
@@ -115,8 +115,19 @@ def apply_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     if current_platform.is_cuda():
-        use_cutlass = cutlass_block_fp8_supported and (
-            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+        if current_platform.has_device_capability(100):
+
+            def ceil_div(x: int, y: int) -> int:
+                return (x + y - 1) // y
+
+            use_cutlass = cutlass_block_fp8_supported and (
+                ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
+                and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
+        else:
+            # TODO: update this after switching to public sm90 block scale gemm
+            # as it also supports weight.shape % 128 != 0
+            use_cutlass = cutlass_block_fp8_supported and (
+                weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
     else:
         use_cutlass = False
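
The dispatch change in fp8_utils.py mirrors the same rule: on devices reporting capability 100 the CUTLASS path is taken whenever the weight-scale tensor is already ceil_div-sized, while earlier architectures keep the stricter multiple-of-128 requirement. A simplified sketch of that decision (is_sm100 and the shape tuples stand in for the platform query and real tensors):

    def ceil_div(x: int, y: int) -> int:
        return (x + y - 1) // y

    def should_use_cutlass(weight_shape, weight_scale_shape,
                           is_sm100: bool, block_fp8_supported: bool = True) -> bool:
        if not block_fp8_supported:
            return False
        if is_sm100:
            # SM100: ragged K/N is fine as long as the scales are ceil_div-sized.
            return (ceil_div(weight_shape[0], 128) == weight_scale_shape[0]
                    and ceil_div(weight_shape[1], 128) == weight_scale_shape[1])
        # Pre-SM100 kernels still require both dims to be multiples of 128.
        return weight_shape[0] % 128 == 0 and weight_shape[1] % 128 == 0

    print(should_use_cutlass((7232, 4096), (57, 32), is_sm100=True))   # True
    print(should_use_cutlass((7232, 4096), (57, 32), is_sm100=False))  # False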
