Skip to content

Commit e8cb0d0

Browse files
authored
[Bug] Fix Compressed Tensor NVFP4 cutlass_fp4_group_mm illegal memory access (#21465)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 6841741 commit e8cb0d0

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,27 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
4747

4848
__global__ void compute_expert_offsets(
4949
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
50-
int32_t* atomic_buffer, const int num_experts, const int topk_length) {
50+
int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
5151
int32_t tot_offset = 0;
5252
expert_offsets[0] = 0;
5353
for (int i = 0; i < num_experts; ++i) {
5454
atomic_buffer[i] = tot_offset;
55-
tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
56-
: problem_sizes1[i * 3 + 1];
55+
tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
5756
expert_offsets[i + 1] = tot_offset;
5857
}
5958
}
6059

6160
__global__ void compute_expert_blockscale_offsets(
6261
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
6362
int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
64-
const int topk_length) {
63+
const bool swap_ab) {
6564
int32_t tot_offset = 0;
6665
int32_t tot_offset_round = 0;
6766
expert_offsets[0] = 0;
6867
blockscale_offsets[0] = 0;
6968
for (int i = 0; i < num_experts; ++i) {
70-
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
71-
? problem_sizes1[i * 3]
72-
: problem_sizes1[i * 3 + 1];
69+
int32_t cur_offset =
70+
swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
7371
atomic_buffer[i] = tot_offset;
7472
tot_offset += cur_offset;
7573
expert_offsets[i + 1] = tot_offset;
@@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller(
119117

120118
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
121119

122-
if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
123-
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
120+
// Swap-AB should be disabled for FP4 path
121+
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
122+
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
123+
124+
if (may_swap_ab) {
125+
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
124126
static_cast<const int32_t*>(topk_ids.data_ptr()),
125127
static_cast<int32_t*>(problem_sizes1.data_ptr()),
126128
static_cast<int32_t*>(problem_sizes2.data_ptr()),
127129
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
128130
k);
129131
} else {
130-
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
132+
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
131133
static_cast<const int32_t*>(topk_ids.data_ptr()),
132134
static_cast<int32_t*>(problem_sizes1.data_ptr()),
133135
static_cast<int32_t*>(problem_sizes2.data_ptr()),
@@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller(
136138
}
137139

138140
if (blockscale_offsets.has_value()) {
141+
// fp4 path
139142
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
140143
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
141144
static_cast<int32_t*>(expert_offsets.data_ptr()),
142145
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
143146
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
144-
topk_ids.numel());
147+
may_swap_ab);
145148
} else {
146149
compute_expert_offsets<<<1, 1, 0, stream>>>(
147150
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
148151
static_cast<int32_t*>(expert_offsets.data_ptr()),
149152
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
150-
topk_ids.numel());
153+
may_swap_ab);
151154
}
152155
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
153156
static_cast<const int32_t*>(topk_ids.data_ptr()),

0 commit comments

Comments
 (0)