@@ -129,18 +129,10 @@ def apply(
129
129
True , # apply_norm_weight,
130
130
False ,
131
131
)
132
- intermediate_cache1 = paddle .empty (
132
+ ffn1_out = paddle .empty (
133
133
[token_num * top_k , moe_intermediate_size * 2 ],
134
134
dtype = x .dtype ,
135
135
)
136
- intermediate_cache2 = paddle .empty (
137
- (token_num * top_k , moe_intermediate_size ),
138
- dtype = x .dtype ,
139
- )
140
- intermediate_cache3 = paddle .empty (
141
- (token_num * top_k , hidden_size ),
142
- dtype = x .dtype ,
143
- )
144
136
145
137
config = {
146
138
"BLOCK_SIZE_M" : 32 ,
@@ -158,7 +150,7 @@ def apply(
158
150
fused_moe_kernel_paddle [grid ](
159
151
x ,
160
152
layer .moe_ffn1_weight ,
161
- intermediate_cache1 ,
153
+ ffn1_out ,
162
154
None ,
163
155
layer .moe_ffn1_weight_scale ,
164
156
None ,
@@ -174,8 +166,8 @@ def apply(
174
166
stride_be = layer .moe_ffn1_weight .strides [0 ],
175
167
stride_bk = layer .moe_ffn1_weight .strides [1 ],
176
168
stride_bn = layer .moe_ffn1_weight .strides [2 ],
177
- stride_cm = intermediate_cache1 .strides [0 ],
178
- stride_cn = intermediate_cache1 .strides [1 ],
169
+ stride_cm = ffn1_out .strides [0 ],
170
+ stride_cn = ffn1_out .strides [1 ],
179
171
#
180
172
stride_asm = - 1 ,
181
173
stride_ask = - 1 ,
@@ -197,16 +189,21 @@ def apply(
197
189
even_Ks = hidden_size % config ["BLOCK_SIZE_K" ] == 0 ,
198
190
)
199
191
200
- intermediate_cache2 = paddle .incubate .nn .functional .swiglu (
201
- intermediate_cache1 )
192
+ ffn2_input = paddle .incubate .nn .functional .swiglu (
193
+ ffn1_out )
194
+
195
+ ffn2_out = paddle .empty (
196
+ (token_num * top_k , hidden_size ),
197
+ dtype = x .dtype ,
198
+ )
202
199
203
200
grid = (
204
201
ceil_div (max_possible_num_post_padded , config ["BLOCK_SIZE_M" ]) *
205
202
ceil_div (hidden_size , config ["BLOCK_SIZE_N" ]), )
206
203
fused_moe_kernel_paddle [grid ](
207
- intermediate_cache2 ,
204
+ ffn2_input ,
208
205
layer .moe_ffn2_weight ,
209
- intermediate_cache3 ,
206
+ ffn2_out ,
210
207
None ,
211
208
layer .moe_ffn2_weight_scale ,
212
209
topk_weights ,
@@ -217,13 +214,13 @@ def apply(
217
214
token_num * top_k ,
218
215
N = hidden_size ,
219
216
K = moe_intermediate_size ,
220
- stride_am = intermediate_cache2 .strides [0 ],
221
- stride_ak = intermediate_cache2 .strides [1 ],
217
+ stride_am = ffn2_input .strides [0 ],
218
+ stride_ak = ffn2_input .strides [1 ],
222
219
stride_be = layer .moe_ffn2_weight .strides [0 ],
223
220
stride_bk = layer .moe_ffn2_weight .strides [1 ],
224
221
stride_bn = layer .moe_ffn2_weight .strides [2 ],
225
- stride_cm = intermediate_cache3 .strides [0 ],
226
- stride_cn = intermediate_cache3 .strides [1 ],
222
+ stride_cm = ffn2_out .strides [0 ],
223
+ stride_cn = ffn2_out .strides [1 ],
227
224
stride_asm = - 1 ,
228
225
stride_ask = - 1 ,
229
226
stride_bse = layer .moe_ffn2_weight_scale .strides [0 ],
@@ -244,8 +241,8 @@ def apply(
244
241
even_Ks = moe_intermediate_size % config ["BLOCK_SIZE_K" ] == 0 ,
245
242
)
246
243
247
- intermediate_cache3 .reshape_ ([token_num , top_k , hidden_size ])
248
- out = intermediate_cache3 .sum (axis = 1 )
244
+ ffn2_out .reshape_ ([token_num , top_k , hidden_size ])
245
+ out = ffn2_out .sum (axis = 1 )
249
246
return out
250
247
251
248
@@ -343,18 +340,10 @@ def apply(
343
340
False ,
344
341
)
345
342
346
- intermediate_cache1 = paddle .empty (
343
+ ffn1_out = paddle .empty (
347
344
[token_num * top_k , moe_intermediate_size * 2 ],
348
345
dtype = x .dtype ,
349
346
)
350
- intermediate_cache2 = paddle .empty (
351
- (token_num * top_k , moe_intermediate_size ),
352
- dtype = x .dtype ,
353
- )
354
- intermediate_cache3 = paddle .empty (
355
- (token_num * top_k , hidden_size ),
356
- dtype = x .dtype ,
357
- )
358
347
359
348
config_ffn1 = {
360
349
"BLOCK_SIZE_M" : 32 ,
@@ -381,7 +370,7 @@ def apply(
381
370
fused_moe_kernel_paddle [grid ](
382
371
permute_x ,
383
372
layer .moe_ffn1_weight ,
384
- intermediate_cache1 ,
373
+ ffn1_out ,
385
374
layer .moe_ffn1_in_scale ,
386
375
layer .moe_ffn1_weight_scale ,
387
376
None ,
@@ -397,8 +386,8 @@ def apply(
397
386
stride_be = layer .moe_ffn1_weight .strides [0 ],
398
387
stride_bk = layer .moe_ffn1_weight .strides [1 ],
399
388
stride_bn = layer .moe_ffn1_weight .strides [2 ],
400
- stride_cm = intermediate_cache1 .strides [0 ],
401
- stride_cn = intermediate_cache1 .strides [1 ],
389
+ stride_cm = ffn1_out .strides [0 ],
390
+ stride_cn = ffn1_out .strides [1 ],
402
391
#
403
392
stride_asm = - 1 , # only used in blockwise fp8
404
393
stride_ask = - 1 , # only used in blockwise fp8
@@ -420,11 +409,11 @@ def apply(
420
409
even_Ks = hidden_size % config_ffn1 ["BLOCK_SIZE_K" ] == 0 ,
421
410
)
422
411
423
- intermediate_cache2 = paddle .incubate .nn .functional .swiglu (
424
- intermediate_cache1 )
412
+ ffn2_input = paddle .incubate .nn .functional .swiglu (
413
+ ffn1_out )
425
414
426
- intermediate_cache2 = fastdeploy .model_executor .ops .gpu .moe_fused_hadamard_quant_fp8 (
427
- intermediate_cache2 ,
415
+ ffn2_input = fastdeploy .model_executor .ops .gpu .moe_fused_hadamard_quant_fp8 (
416
+ ffn2_input ,
428
417
scale = layer .moe_ffn2_in_scale ,
429
418
topk_ids = topk_ids ,
430
419
top_k = top_k ,
@@ -438,14 +427,19 @@ def apply(
438
427
"GROUP_SIZE_M" : 1 ,
439
428
}
440
429
430
+ ffn2_out = paddle .empty (
431
+ (token_num * top_k , hidden_size ),
432
+ dtype = x .dtype ,
433
+ )
434
+
441
435
grid = (
442
436
ceil_div (max_possible_num_post_padded , config_ffn2 ["BLOCK_SIZE_M" ]) *
443
437
ceil_div (hidden_size , config_ffn2 ["BLOCK_SIZE_N" ]), )
444
438
445
439
fused_moe_kernel_paddle [grid ](
446
- intermediate_cache2 ,
440
+ ffn2_input ,
447
441
layer .moe_ffn2_weight ,
448
- intermediate_cache3 ,
442
+ ffn2_out ,
449
443
layer .moe_ffn2_in_scale ,
450
444
layer .moe_ffn2_weight_scale ,
451
445
topk_weights ,
@@ -456,13 +450,13 @@ def apply(
456
450
token_num * top_k ,
457
451
N = hidden_size ,
458
452
K = moe_intermediate_size ,
459
- stride_am = intermediate_cache2 .strides [0 ],
460
- stride_ak = intermediate_cache2 .strides [1 ],
453
+ stride_am = ffn2_input .strides [0 ],
454
+ stride_ak = ffn2_input .strides [1 ],
461
455
stride_be = layer .moe_ffn2_weight .strides [0 ],
462
456
stride_bk = layer .moe_ffn2_weight .strides [1 ],
463
457
stride_bn = layer .moe_ffn2_weight .strides [2 ],
464
- stride_cm = intermediate_cache3 .strides [0 ],
465
- stride_cn = intermediate_cache3 .strides [1 ],
458
+ stride_cm = ffn2_out .strides [0 ],
459
+ stride_cn = ffn2_out .strides [1 ],
466
460
stride_asm = - 1 ,
467
461
stride_ask = - 1 ,
468
462
stride_bse = - 1 ,
@@ -483,8 +477,8 @@ def apply(
483
477
even_Ks = moe_intermediate_size % config_ffn2 ["BLOCK_SIZE_K" ] == 0 ,
484
478
)
485
479
486
- intermediate_cache3 .reshape_ ([token_num , top_k , hidden_size ])
487
- out = intermediate_cache3 .sum (axis = 1 )
480
+ ffn2_out .reshape_ ([token_num , top_k , hidden_size ])
481
+ out = ffn2_out .sum (axis = 1 )
488
482
489
483
if layer .tp_size > 1 :
490
484
tensor_model_parallel_all_reduce (out )
0 commit comments