Commit 1eb8ea7

[Bug fix] fix compile bug when sm < 89 (#2738)
1 parent ef6649a commit 1eb8ea7

2 files changed: +44 / -46 lines


custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 0 deletions
@@ -468,6 +468,7 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);
 
+#ifdef ENABLE_FP8
 paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
@@ -489,6 +490,7 @@ paddle::Tensor MoeFusedHadamardQuantFp8Func(
 paddle::Tensor FusedHadamardQuantFp8Func(
     const paddle::Tensor &input,
     const float scale);
+#endif
 
 PYBIND11_MODULE(fastdeploy_ops, m) {
 
@@ -769,6 +771,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+#ifdef ENABLE_FP8
   m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
         py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
@@ -780,4 +783,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
+#endif
 }
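
The change above follows a standard conditional-compilation pattern: the FP8 function prototypes and their pybind11 bindings only exist when ENABLE_FP8 is defined, so a build for GPUs below SM 89, which would otherwise fail on the FP8 kernels, compiles both of them out. Below is a minimal, self-contained sketch of that pattern; the module name sketch_ops and the stub bodies are hypothetical and only the #ifdef structure mirrors the commit.

// sketch_ops.cc -- illustrative only; stub bodies and names are not FastDeploy code.
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Op that is always available, regardless of FP8 support.
int noaux_tc_stub(int topk) { return topk; }

#ifdef ENABLE_FP8
// Declared only when the build enables FP8 (e.g. SM 89+ targets).
int fp8_gemm_stub(int m, int n) { return m * n; }
#endif

PYBIND11_MODULE(sketch_ops, m) {
  m.def("noaux_tc", &noaux_tc_stub, "always-on op");
#ifdef ENABLE_FP8
  // The binding is guarded too, so a non-FP8 build never references
  // a symbol that was compiled out.
  m.def("fp8_gemm", &fp8_gemm_stub, py::arg("m"), py::arg("n"), "FP8-only op");
#endif
}

Compiled with -DENABLE_FP8 the module exposes both ops; without it, the FP8 symbol and its binding simply never exist, which is what lets sm < 89 builds succeed.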

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 40 additions & 46 deletions
@@ -129,18 +129,10 @@ def apply(
             True, # apply_norm_weight,
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config = {
             "BLOCK_SIZE_M": 32,
@@ -158,7 +150,7 @@ def apply(
         fused_moe_kernel_paddle[grid](
             x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             None,
             layer.moe_ffn1_weight_scale,
             None,
@@ -174,8 +166,8 @@ def apply(
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
@@ -197,16 +189,21 @@ def apply(
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
+
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
 
         grid = (
             ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             None,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -217,13 +214,13 @@ def apply(
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=layer.moe_ffn2_weight_scale.strides[0],
@@ -244,8 +241,8 @@ def apply(
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
         return out
 
 
@@ -343,18 +340,10 @@ def apply(
             False,
         )
 
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config_ffn1 = {
             "BLOCK_SIZE_M": 32,
@@ -381,7 +370,7 @@ def apply(
         fused_moe_kernel_paddle[grid](
             permute_x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
             None,
@@ -397,8 +386,8 @@ def apply(
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1, # only used in blockwise fp8
             stride_ask=-1, # only used in blockwise fp8
@@ -420,11 +409,11 @@ def apply(
             even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
 
-        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
-            intermediate_cache2,
+        ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            ffn2_input,
             scale=layer.moe_ffn2_in_scale,
             topk_ids=topk_ids,
             top_k=top_k,
@@ -438,14 +427,19 @@ def apply(
             "GROUP_SIZE_M": 1,
         }
 
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
         grid = (
             ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
 
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -456,13 +450,13 @@ def apply(
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=-1,
@@ -483,8 +477,8 @@ def apply(
             even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
 
         if layer.tp_size > 1:
             tensor_model_parallel_all_reduce(out)
