[Bug] Fix Compressed Tensor NVFP4 cutlass_fp4_group_mm illegal memory access #21465

Merged
27 changes: 15 additions & 12 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
@@ -47,29 +47,27 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,

 __global__ void compute_expert_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
-    int32_t* atomic_buffer, const int num_experts, const int topk_length) {
+    int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
   int32_t tot_offset = 0;
   expert_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
     atomic_buffer[i] = tot_offset;
-    tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
-                                                  : problem_sizes1[i * 3 + 1];
+    tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     expert_offsets[i + 1] = tot_offset;
   }
 }

 __global__ void compute_expert_blockscale_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
     int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
-    const int topk_length) {
+    const bool swap_ab) {
   int32_t tot_offset = 0;
   int32_t tot_offset_round = 0;
   expert_offsets[0] = 0;
   blockscale_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
-    int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
-                             ? problem_sizes1[i * 3]
-                             : problem_sizes1[i * 3 + 1];
+    int32_t cur_offset =
+        swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     atomic_buffer[i] = tot_offset;
     tot_offset += cur_offset;
     expert_offsets[i + 1] = tot_offset;
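
Both kernels read problem_sizes1 as three int32 values per expert; judging from the indexing, the per-expert row count sits at slot 0 in the regular layout and at slot 1 when the problem sizes were built for swapped A/B operands, so the new swap_ab flag selects the slot directly instead of re-deriving the decision from topk_length. A minimal host-side sketch of the prefix sum computed by compute_expert_offsets, under that assumed layout:

```cpp
#include <cstdint>
#include <vector>

// Host-side reference for compute_expert_offsets. Assumes problem_sizes1
// packs three int32 values per expert, with the row count at slot 0 in the
// regular layout and at slot 1 in the swapped layout (inferred from the
// kernel's indexing; the layout is not documented in this diff).
std::vector<int32_t> expert_offsets_reference(
    const std::vector<int32_t>& problem_sizes1, int num_experts,
    bool swap_ab) {
  std::vector<int32_t> offsets(num_experts + 1, 0);
  int32_t tot = 0;
  for (int i = 0; i < num_experts; ++i) {
    // Same slot selection as the kernel after this PR.
    tot += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
    offsets[i + 1] = tot;  // offsets[0] stays 0: exclusive prefix sum
  }
  return offsets;
}
```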
@@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller(

   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

-  if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
Comment on lines +121 to +122

Contributor:

Would it make sense to instead add a boolean argument to get_cutlass_moe_mm_data() that forces no swap? It looks like disabling swap will also be needed for fp8 blockwise CUTLASS, and that path doesn't pass blockscale_offsets to this function.

Contributor (Author):

run_cutlass_moe_fp8 or run_cutlass_block_scaled_fused_experts: which path are you talking about? I don't have enough context, so I'm thinking we can do that in a follow-up PR.

ElizaWszola (Contributor), Jul 24, 2025:

I mean the get_cutlass_moe_mm_data() call in run_cutlass_block_scaled_fused_experts() :) But I can add that change to a separate PR.
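
The refactor suggested in this thread would look roughly like the sketch below; the helper and the force_no_swap flag are hypothetical, since the actual get_cutlass_moe_mm_data() signature is not shown in this diff:

```cpp
#include <cstdint>

// Hypothetical sketch of the reviewer's suggestion, not code from this PR:
// let callers force the non-swapped path explicitly instead of inferring it
// from whether blockscale_offsets was passed.
bool decide_swap_ab(bool force_no_swap, int64_t topk_numel,
                    int64_t swap_ab_threshold) {
  return !force_no_swap && topk_numel <= swap_ab_threshold;
}
```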


+  if (may_swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
         k);
   } else {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
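
For context on why the <true>/<false> instantiations exist: when the routed token count (the M dimension) is small, a grouped GEMM C = A x B leaves the kernel underutilized along M, and computing C^T = B^T x A^T instead exchanges the roles of M and N. That transpose identity is all swap-AB relies on; a small self-contained check (illustrative only, not taken from this PR):

```cpp
#include <cassert>
#include <vector>

// C = A * B (M x K by K x N) equals the transpose of B^T * A^T
// (N x K by K x M): swapping the operands exchanges the M and N roles.
int main() {
  const int M = 2, K = 3, N = 4;
  std::vector<float> A(M * K), B(K * N), C(M * N, 0.f), Ct(N * M, 0.f);
  for (int i = 0; i < M * K; ++i) A[i] = float(i + 1);
  for (int i = 0; i < K * N; ++i) B[i] = float(2 * i - 3);
  for (int m = 0; m < M; ++m)  // C[m][n] = sum_k A[m][k] * B[k][n]
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k) C[m * N + n] += A[m * K + k] * B[k * N + n];
  for (int n = 0; n < N; ++n)  // Ct[n][m] = sum_k B[k][n] * A[m][k]
    for (int m = 0; m < M; ++m)
      for (int k = 0; k < K; ++k) Ct[n * M + m] += B[k * N + n] * A[m * K + k];
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) assert(C[m * N + n] == Ct[n * M + m]);
  return 0;
}
```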
@@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller(
   }

   if (blockscale_offsets.has_value()) {
+    // fp4 path
     compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        topk_ids.numel());
+        may_swap_ab);
   } else {
     compute_expert_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        topk_ids.numel());
+        may_swap_ab);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
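
Net effect of the fix: swap-AB is now considered only when blockscale_offsets is absent (i.e., not the FP4 path) and the token count is at or below the threshold, and the same decision is passed to the offset kernels, so the problem-size slot they read always matches the layout compute_problem_sizes produced. Before this change, small batches always produced swap-AB problem sizes, even on the FP4 path, which cutlass_fp4_group_mm does not handle; presumably that mismatch is the source of the illegal memory access in the title. A compact check of the new dispatch; the threshold value below is a stand-in, since SWAP_AB_THRESHOLD is defined outside this diff:

```cpp
#include <cassert>

// Mirrors the dispatch in get_cutlass_moe_mm_data_caller after this PR.
bool may_swap_ab(bool has_blockscale, long numel, long threshold) {
  return !has_blockscale && numel <= threshold;
}

int main() {
  const long T = 64;                            // stand-in for SWAP_AB_THRESHOLD
  assert(may_swap_ab(false, 32, T) == true);    // small batch: swapped kernels
  assert(may_swap_ab(false, 128, T) == false);  // large batch: regular kernels
  assert(may_swap_ab(true, 32, T) == false);    // fp4 path: never swap
  assert(may_swap_ab(true, 128, T) == false);   // fp4 path: never swap
  return 0;
}
```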