7
7
8
8
constexpr uint64_t THREADS_PER_EXPERT = 512 ;
9
9
10
- __global__ void compute_problem_sizes (const uint32_t * __restrict__ topk_ids,
10
+ __global__ void compute_problem_sizes (const int32_t * __restrict__ topk_ids,
11
11
int32_t * problem_sizes1,
12
12
int32_t * problem_sizes2,
13
13
int32_t * atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
62
62
}
63
63
}
64
64
65
- __global__ void compute_arg_sorts (const uint32_t * __restrict__ topk_ids,
65
+ __global__ void compute_arg_sorts (const int32_t * __restrict__ topk_ids,
66
66
const int32_t * __restrict__ expert_offsets,
67
67
int32_t * input_permutation,
68
68
int32_t * output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
103
103
104
104
int num_threads = min (THREADS_PER_EXPERT, topk_ids.numel ());
105
105
compute_problem_sizes<<<num_experts, num_threads, 0 , stream>>> (
106
- static_cast <const uint32_t *>(topk_ids.data_ptr ()),
106
+ static_cast <const int32_t *>(topk_ids.data_ptr ()),
107
107
static_cast <int32_t *>(problem_sizes1.data_ptr ()),
108
108
static_cast <int32_t *>(problem_sizes2.data_ptr ()),
109
109
static_cast <int32_t *>(atomic_buffer.data_ptr ()), topk_ids.numel (), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
120
120
static_cast <int32_t *>(atomic_buffer.data_ptr ()), num_experts);
121
121
}
122
122
compute_arg_sorts<<<num_experts, num_threads, 0 , stream>>> (
123
- static_cast <const uint32_t *>(topk_ids.data_ptr ()),
123
+ static_cast <const int32_t *>(topk_ids.data_ptr ()),
124
124
static_cast <const int32_t *>(expert_offsets.data_ptr ()),
125
125
static_cast <int32_t *>(input_permutation.data_ptr ()),
126
126
static_cast <int32_t *>(output_permutation.data_ptr ()),
0 commit comments