
Commit 85b229b

cthi authored and facebook-github-bot committed
Typo and small fixes to CK fp8 rowwise grouped (#4550)
Summary:
X-link: facebookresearch/FBGEMM#1593

- Typo and grammar fixes found by LLM
- Validation bug found by LLM
- Small logic simplification I missed in my prior PR

Differential Revision: D78827450
1 parent 86a031b commit 85b229b

File tree

2 files changed: +10 −12 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 4 additions & 4 deletions
```diff
@@ -109,7 +109,7 @@ void set_static_kernel_args(
   int64_t output_offset = 0;
   // When group count is large, we can more efficiently initialize
   // by doing host setup and a memcpy. This is only viable if cuda
-  // graphs arent being used.
+  // graphs aren't being used.
   // Iterate over inputs and get group information.
   for (int i = 0; i < group_count; i++) {
     int64_t M = XQ[i].size(0);
@@ -163,7 +163,7 @@ __global__ void set_kernel_args(
     int64_t K,
     int64_t group_count,
     std::optional<GroupedGemmInputType> input_type = std::nullopt) {
-  // The "message" part seems not working on AMD currently :(
+  // The "message" part is not working on AMD currently :(
   CUDA_KERNEL_ASSERT_MSG((M_sizes == nullptr && offsets == nullptr) || (M_sizes == nullptr ^ offsets == nullptr), "Cannot set both M_sizes and offsets");
   CUDA_KERNEL_ASSERT_MSG(input_type.has_value() || M_sizes != nullptr, "M_sizes should not be used with input_type");

@@ -513,7 +513,7 @@ OutputType _f8f8bf16_rowwise_grouped(
   for (at::Tensor xs : x_scale) {
     TORCH_CHECK(xs.dtype() == at::kFloat, "Scales must be float32.");
   }
-  for (at::Tensor ws : x_scale) {
+  for (at::Tensor ws : w_scale) {
     TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32.");
   }
@@ -774,7 +774,7 @@ at::Tensor f8f8bf16_rowwise_grouped_mm(
     TORCH_CHECK(w_scale.size(0) == G && w_scale.size(1) == N, "w_scale shape must be (G, N).");
     TORCH_CHECK(out.dim() == 3 && out.size(0) == G && out.size(1) == M && out.size(2) == N, "out shape must be (G, M, N).");
   } else if (XQ.dim() == 2 && WQ.dim() == 2) {
-    TORCH_CHECK(offsets.has_value(), "Must pass offsets for 2D inputs XQ nd WQ.");
+    TORCH_CHECK(offsets.has_value(), "Must pass offsets for 2D inputs XQ and WQ.");
     TORCH_CHECK(offsets->dtype() == at::kInt, "offsets must be int32.");

     G = offsets->size(0);
```
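The third hunk above is the validation bug: the second loop iterated `x_scale` again, so a `w_scale` tensor with the wrong dtype was never caught. Below is a minimal standalone sketch of the corrected pattern, not the FBGEMM source; `Tensor` and `DType` are hypothetical stand-ins for `at::Tensor` and `at::kFloat`.

```cpp
// Sketch of the fixed scale validation: each scale list is checked against
// its own tensors, so a bad w_scale dtype now fails instead of passing silently.
#include <cassert>
#include <vector>

enum class DType { Float32, BFloat16 };
struct Tensor { DType dtype; };

void validate_scales(const std::vector<Tensor>& x_scale,
                     const std::vector<Tensor>& w_scale) {
  for (const Tensor& xs : x_scale) {
    assert(xs.dtype == DType::Float32 && "x_scale must be float32.");
  }
  // Before the fix this loop ranged over x_scale, leaving w_scale unchecked.
  for (const Tensor& ws : w_scale) {
    assert(ws.dtype == DType::Float32 && "w_scale must be float32.");
  }
}

int main() {
  std::vector<Tensor> x_scale = {{DType::Float32}};
  std::vector<Tensor> w_scale = {{DType::Float32}};
  validate_scales(x_scale, w_scale);  // a BFloat16 w_scale would now assert
  return 0;
}
```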

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h

Lines changed: 6 additions & 8 deletions
```diff
@@ -118,13 +118,11 @@ struct DeviceGemmHelper {
   // Get input information.
   int group_count;
   if constexpr (std::is_same_v<InputType, at::Tensor>) {
-    if (WQ.dim() == 3) {
-      // If WQ is 3D the group count is the min of G and total_M (if XQ is
-      // 2D).
-      group_count = std::min(WQ.size(0), XQ.size(0));
-    } else if (XQ.dim() == 3) {
-      // If XQ is 3D the group count is the min of G and total_N (if WQ is
-      // 2D).
+    if (XQ.dim() == 3 || WQ.dim() == 3) {
+      // If WQ and XQ are 3D, the group count is G.
+      // If WQ is 3D and XQ is 2D (and the reverse by symmetry), the group
+      // count is the minimum of G and total_M/total_N. In all cases we just
+      // compare the first dimension of XQ and WQ.
       group_count = std::min(XQ.size(0), WQ.size(0));
     } else {
       // XQ and WQ are 2D. The group count is G.
@@ -163,7 +161,7 @@ struct DeviceGemmHelper {
   // pointers below are unused, as the device memory contains the correct
   // data.
   if constexpr (std::is_same_v<InputType, at::Tensor>) {
-    // Set these to 0 as placeholders, they are unsused.
+    // Set these to 0 as placeholders, they are unused.
     M = 0;
     N = 0;
     K = 0;
```
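The first hunk collapses the two 3D branches into one: whichever operand is 3D has G as its leading dimension, so taking the minimum of the two leading dimensions covers every case. Below is a small standalone sketch of that rule with a hypothetical helper and plain sizes in place of `at::Tensor`; it is an illustration of the logic, not the FBGEMM code.

```cpp
// Sketch of the simplified group-count rule using dimension counts and
// first-dimension sizes instead of tensors.
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t group_count_for(int xq_dim, int64_t xq_size0,
                        int wq_dim, int64_t wq_size0,
                        int64_t G) {
  if (xq_dim == 3 || wq_dim == 3) {
    // If both are 3D their leading dimensions are both G, so min() returns G.
    // If one is 2D its leading dimension is total_M (or total_N), so min()
    // clamps the group count to min(G, total_M / total_N).
    return std::min(xq_size0, wq_size0);
  }
  // Both 2D: the group count is G, taken from offsets/M_sizes elsewhere.
  return G;
}

int main() {
  // 3D XQ (G = 4) with 3D WQ (G = 4): group count is 4.
  std::cout << group_count_for(3, 4, 3, 4, 4) << "\n";
  // 2D XQ with total_M = 3 and 3D WQ with G = 4: group count is min(4, 3) = 3.
  std::cout << group_count_for(2, 3, 3, 4, 4) << "\n";
  return 0;
}
```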
