Skip to content

Commit e5ed822

Browse files
committed
Address comment: change moe_data topk_ids type back to int32_t
Signed-off-by: Ming Yang <yming@meta.com>
1 parent 1567180 commit e5ed822

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
constexpr uint64_t THREADS_PER_EXPERT = 512;
99

10-
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
10+
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
1111
int32_t* problem_sizes1,
1212
int32_t* problem_sizes2,
1313
int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
6262
}
6363
}
6464

65-
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
65+
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
6666
const int32_t* __restrict__ expert_offsets,
6767
int32_t* input_permutation,
6868
int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
103103

104104
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
105105
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
106-
static_cast<const uint32_t*>(topk_ids.data_ptr()),
106+
static_cast<const int32_t*>(topk_ids.data_ptr()),
107107
static_cast<int32_t*>(problem_sizes1.data_ptr()),
108108
static_cast<int32_t*>(problem_sizes2.data_ptr()),
109109
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
120120
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
121121
}
122122
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
123-
static_cast<const uint32_t*>(topk_ids.data_ptr()),
123+
static_cast<const int32_t*>(topk_ids.data_ptr()),
124124
static_cast<const int32_t*>(expert_offsets.data_ptr()),
125125
static_cast<int32_t*>(input_permutation.data_ptr()),
126126
static_cast<int32_t*>(output_permutation.data_ptr()),

0 commit comments

Comments
 (0)