Skip to content

Commit c9affaa

Browse files
xw285cornellfacebook-github-bot
authored andcommitted
fused moe fix (#4534)
Summary: Pull Request resolved: #4534 X-link: facebookresearch/FBGEMM#1579 A number of changes on CK header that we need downstream fixes: * local token was added to MoeSortingProblemEx: https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp#L48 * MeshType_ was added to MoeSortingProblemMp: https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp#L56 * MOE_SORTING_FMOE_2D_BUF was added which I just turn off for now ``` In file included from buck-out/v2/gen/fbcode/de5ab6968d300cd2/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/__ck_fused_moe_hipify_gen__/out/fused_moe/instances/fused_moesorting_api.hip:4: In file included from fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp:8: In file included from third-party/rocm_composable_kernel/include/ck_tile/ops/fused_moe.hpp:9: third-party/rocm_composable_kernel/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp:2100:11: error: no member named 'moe_buf_elem_bytes' in 'ck_tile::MoeSortingMultiPhaseKernel_P2<ck_tile::MoeSortingProblemMp<int, float, int, 1, false, false>>::Kargs' 2100 | k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; | ~ ^ ``` Reviewed By: zjing14 Differential Revision: D78679653
1 parent cd2881e commit c9affaa

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ at::Tensor fused_moe_impl(
7777
auto prec_o = get_prec_str(output);
7878
auto prec_tkw = get_prec_str(topk_weights);
7979

80-
int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
80+
int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0);
8181
void *ws_ptr = nullptr;
8282
if (workspace_size > 0)
8383
{
8484
auto ws = at::zeros({workspace_size}, at::TensorOptions().dtype(topk_ids.dtype()).device(device_of(topk_ids)));
8585
ws_ptr = ws.data_ptr();
8686
}
87-
87+
8888

8989
// Set up traits structure
9090
fused_moe_traits traits{

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ float fused_moesorting(
2020
fused_moesorting_args a,
2121
ck_tile::stream_config s);
2222

23-
int moe_sorting_get_workspace_size(int tokens, int num_experts);
23+
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
2424
float moe_sorting_mp(
2525
fused_moesorting_trait t,
2626
fused_moesorting_args a,

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
2222
a.topk_ids_ptr, // const void* p_topk_ids;
2323
a.topk_weight_ptr, // const void* p_weights;
2424
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
25+
nullptr, // const void* p_local_tokens;
2526
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
2627
a.sorted_weight_ptr, // void* p_sorted_weights;
2728
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moesorting_api.hip

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
ms_weight_type, \
3434
sub_token_tile, \
3535
sub_token_onshot, \
36-
local_expert_masking>; \
36+
local_expert_masking, \
37+
false>; \
3738
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
3839
auto kargs = kernel::MakeKargs(a); \
3940
const dim3 grids = kernel::GridSize(a); \
@@ -153,7 +154,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
153154
}
154155
}
155156
#else
156-
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
157+
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
157158
{
158159
return moe_sorting_mp(t, a, s);
159160
}
@@ -176,7 +177,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
176177
constexpr ck_tile::index_t unroll_num = unroll_num_; \
177178
constexpr bool expert_masking = expert_masking_; \
178179
using ms_problem = \
179-
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
180+
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, ck_tile::index_t, unroll_num, expert_masking, false>; \
180181
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
181182
auto kargs = kernel::MakeKargs(a); \
182183
const dim3 grids = kernel::GridSize(a); \
@@ -189,7 +190,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
189190
constexpr ck_tile::index_t unroll_num = unroll_num_; \
190191
constexpr bool expert_masking = expert_masking_; \
191192
using ms_problem = \
192-
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
193+
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, ck_tile::index_t, unroll_num, expert_masking, false>; \
193194
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
194195
auto kargs = kernel::MakeKargs(a); \
195196
const dim3 grids = kernel::GridSize(a); \
@@ -202,7 +203,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
202203
constexpr ck_tile::index_t unroll_num = unroll_num_; \
203204
constexpr bool expert_masking = expert_masking_; \
204205
using ms_problem = \
205-
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
206+
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, ck_tile::index_t, unroll_num, expert_masking, false>; \
206207
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
207208
auto kargs = kernel::MakeKargs(a); \
208209
const dim3 grids = kernel::GridSize(a); \
@@ -215,7 +216,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
215216
constexpr ck_tile::index_t unroll_num = unroll_num_; \
216217
constexpr bool expert_masking = expert_masking_; \
217218
using ms_problem = \
218-
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, unroll_num, expert_masking>; \
219+
ck_tile::MoeSortingProblemMp<ms_index_t, ms_weight_type, ck_tile::index_t, unroll_num, expert_masking, false>; \
219220
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
220221
auto kargs = kernel::MakeKargs(a); \
221222
const dim3 grids = kernel::GridSize(a); \
@@ -252,7 +253,7 @@ float moe_sorting_mp(fused_moesorting_trait t, fused_moesorting_args a, ck_tile:
252253
return -1;
253254
}
254255

255-
int moe_sorting_get_workspace_size(int tokens, int num_experts)
256+
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
256257
{
257-
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts);
258-
}
258+
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, 0);
259+
}

0 commit comments

Comments
 (0)