@@ -47,29 +47,27 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,

 __global__ void compute_expert_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
-    int32_t* atomic_buffer, const int num_experts, const int topk_length) {
+    int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
   int32_t tot_offset = 0;
   expert_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
     atomic_buffer[i] = tot_offset;
-    tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
-                                                  : problem_sizes1[i * 3 + 1];
+    tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     expert_offsets[i + 1] = tot_offset;
   }
 }

 __global__ void compute_expert_blockscale_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
     int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
-    const int topk_length) {
+    const bool swap_ab) {
   int32_t tot_offset = 0;
   int32_t tot_offset_round = 0;
   expert_offsets[0] = 0;
   blockscale_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
-    int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
-                             ? problem_sizes1[i * 3]
-                             : problem_sizes1[i * 3 + 1];
+    int32_t cur_offset =
+        swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     atomic_buffer[i] = tot_offset;
     tot_offset += cur_offset;
     expert_offsets[i + 1] = tot_offset;
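For reference, both kernels above reduce to a single-thread prefix sum over per-expert problem sizes. The sketch below is illustrative (not part of the diff), assuming `problem_sizes1` stores one `(M, N, K)` triple per expert; under swap-AB the M and N entries trade places, which is why the token count moves from index `i * 3` to `i * 3 + 1`.

```cpp
#include <cstdint>
#include <vector>

// Host-side reference of the prefix sum the single-thread kernels perform.
// A minimal sketch, assuming problem_sizes1 holds (M, N, K) per expert.
std::vector<int32_t> expert_offsets_ref(
    const std::vector<int32_t>& problem_sizes1, int num_experts,
    bool swap_ab) {
  std::vector<int32_t> offsets(num_experts + 1, 0);
  int32_t tot_offset = 0;
  for (int i = 0; i < num_experts; ++i) {
    // With swap-AB the M and N slots of each triple are exchanged, so the
    // per-expert token count is read from the N slot (i * 3 + 1).
    tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
    offsets[i + 1] = tot_offset;
  }
  return offsets;
}
```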
@@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller(

   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

-  if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  if (may_swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
         k);
   } else {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
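The caller now makes the swap-AB decision once, up front, instead of letting each kernel re-derive it from `topk_length`. A hypothetical helper (names illustrative, not part of the diff) isolating the new predicate:

```cpp
#include <cstdint>

// Swap-AB is gated on two conditions, matching the caller's may_swap_ab:
// it is forced off whenever blockscale offsets are in use (the FP4 path,
// per the diff's comment), and otherwise enabled only for small flattened
// top-k counts.
inline bool should_swap_ab(bool has_blockscale_offsets, int64_t topk_numel,
                           int64_t swap_ab_threshold) {
  return !has_blockscale_offsets && topk_numel <= swap_ab_threshold;
}
```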
@@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller(
   }

   if (blockscale_offsets.has_value()) {
+    // fp4 path
     compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        topk_ids.numel());
+        may_swap_ab);
   } else {
     compute_expert_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        topk_ids.numel());
+        may_swap_ab);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
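Passing the single `may_swap_ab` flag guarantees the offset kernels read the same slot of `problem_sizes1` that the problem-size kernels wrote. An optional host-side sanity check, sketched below under the assumption that it runs after the launches above and reuses `expert_offsets_ref` from the earlier sketch; `check_expert_offsets` and its raw-pointer arguments are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

// Copies problem sizes and offsets back to the host and verifies the offset
// kernel used the same swap_ab decision as the problem-size kernels.
void check_expert_offsets(const int32_t* d_problem_sizes1,
                          const int32_t* d_expert_offsets, int num_experts,
                          bool may_swap_ab) {
  std::vector<int32_t> h_sizes(num_experts * 3);
  std::vector<int32_t> h_offsets(num_experts + 1);
  cudaMemcpy(h_sizes.data(), d_problem_sizes1,
             h_sizes.size() * sizeof(int32_t), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_offsets.data(), d_expert_offsets,
             h_offsets.size() * sizeof(int32_t), cudaMemcpyDeviceToHost);
  // The device result must match the host prefix sum for the same flag.
  assert(h_offsets == expert_offsets_ref(h_sizes, num_experts, may_swap_ab));
}
```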