Skip to content

Commit 61b3997

Browse files
authored
refactor rl get_name_mappings_to_training (#2847)
* refactor rl get_name_mappings_to_training * fix tp>1 * change variable name(ffn1->up_gate_proj/ffn2->down_proj) * change variable name(linear_weight->weight/linear_bias->bias) * add rl names mapping for vl * fix ernie 0.3B error * fix develop code * fix
1 parent e7bcbba commit 61b3997

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1599
-1637
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
116116

117117
paddle::Tensor FusedExpertMoeFunc(
118118
const paddle::Tensor &input, const paddle::Tensor &gate_weight,
119-
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
120-
const paddle::optional<paddle::Tensor> &ffn1_bias,
121-
const paddle::optional<paddle::Tensor> &ffn1_scale,
122-
const paddle::optional<paddle::Tensor> &ffn2_bias,
123-
const paddle::optional<paddle::Tensor> &ffn2_scale,
119+
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
120+
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
121+
const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
122+
const paddle::optional<paddle::Tensor> &down_proj_bias,
123+
const paddle::optional<paddle::Tensor> &down_proj_scale,
124124
const std::string &quant_method, const int moe_topk,
125125
const bool norm_topk_prob, const bool group_moe);
126126

@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
149149
std::vector<paddle::Tensor>
150150
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
151151
const paddle::Tensor &topk_weights,
152-
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
152+
const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
153153
const std::vector<int> &token_nums_per_expert,
154154
const int token_nums_this_rank,
155155
const std::string &moe_quant_type);
@@ -173,7 +173,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
173173
const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
174174
const paddle::Tensor &permute_indices_per_token,
175175
const paddle::Tensor &top_k_indices,
176-
const paddle::optional<paddle::Tensor> &ffn2_bias,
176+
const paddle::optional<paddle::Tensor> &down_proj_bias,
177177
const bool norm_topk_prob, const float routed_scaling_factor);
178178

179179
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
@@ -182,35 +182,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
182182
paddle::Tensor MoeExpertFFNFunc(
183183
const paddle::Tensor& permute_input,
184184
const paddle::Tensor& tokens_expert_prefix_sum,
185-
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
186-
const paddle::optional<paddle::Tensor>& ffn1_bias,
187-
const paddle::optional<paddle::Tensor>& ffn1_scale,
188-
const paddle::optional<paddle::Tensor>& ffn2_scale,
189-
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
185+
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
186+
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
187+
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
188+
const paddle::optional<paddle::Tensor>& down_proj_scale,
189+
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
190190
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
191191
const std::string& quant_method, const bool used_in_ep_low_latency);
192192

193193
paddle::Tensor MoeExpertFFNWint2Func(
194194
const paddle::Tensor& permute_input,
195195
const paddle::Tensor& tokens_expert_prefix_sum,
196-
const paddle::Tensor& ffn1_weight,
197-
const paddle::Tensor& ffn2_weight,
198-
const paddle::optional<paddle::Tensor>& ffn1_bias,
199-
const paddle::optional<paddle::Tensor>& ffn1_scale,
200-
const paddle::optional<paddle::Tensor>& ffn2_scale,
201-
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
202-
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
203-
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
204-
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
205-
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
206-
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
196+
const paddle::Tensor& up_gate_proj_weight,
197+
const paddle::Tensor& down_proj_weight,
198+
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
199+
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
200+
const paddle::optional<paddle::Tensor>& down_proj_scale,
201+
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
202+
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
203+
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
204+
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
205+
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
206+
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
207207
const bool used_in_ep_low_latency);
208208

209209
paddle::Tensor MoeExpertReduceFunc(
210210
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
211211
const paddle::Tensor &permute_indices_per_token,
212212
const paddle::Tensor &top_k_indices,
213-
const paddle::optional<paddle::Tensor> &ffn2_bias,
213+
const paddle::optional<paddle::Tensor> &down_proj_bias,
214214
const bool norm_topk_prob, const float routed_scaling_factor);
215215

216216
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
@@ -816,15 +816,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
816816
* ep_moe_dispatch
817817
*/
818818
m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
819-
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
819+
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
820820
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
821821
py::arg("moe_quant_type"), "ep moe export dispatch function");
822822

823823
m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8);
824824

825825
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
826826
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
827-
py::arg("top_k_indices"), py::arg("ffn2_bias"),
827+
py::arg("top_k_indices"), py::arg("down_proj_bias"),
828828
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
829829
"ep moe export combine function");
830830

@@ -866,7 +866,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
866866
*/
867867
m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
868868
py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
869-
py::arg("top_k_indices"), py::arg("ffn2_bias"),
869+
py::arg("top_k_indices"), py::arg("down_proj_bias"),
870870
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
871871
"moe export reduce function");
872872

custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
ffn1_n=7168
16-
ffn1_k=8192
15+
up_gate_proj_n=7168
16+
up_gate_proj_k=8192
1717

18-
ffn2_n=8192
19-
ffn2_k=3584
20-
rm -rf ffn1_7168_8192.log
21-
rm -rf ffn2_8192_3584.log
18+
down_proj_n=8192
19+
down_proj_k=3584
20+
rm -rf up_gate_proj_7168_8192.log
21+
rm -rf down_proj_8192_3584.log
2222
num_experts=8
2323

2424
for tokens_per_expert in 12
2525

2626
do
2727
wait
28-
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
29-
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
28+
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
29+
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
3030
done
3131
wait
3232
echo "#### finish ####"

custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel(
161161
expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值
162162
Load<T, VEC_SIZE>(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec);
163163
const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家
164-
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的ffn2的bias
164+
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias
165165
if (bias_ptr) {
166166
Load<T, VEC_SIZE>(bias_ptr + tid * VEC_SIZE, &bias_vec);
167167
#pragma unroll
@@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
188188
const paddle::Tensor& expert_scales_float,
189189
const paddle::Tensor& permute_indices_per_token,
190190
const paddle::Tensor& top_k_indices,
191-
const paddle::optional<paddle::Tensor>& ffn2_bias,
191+
const paddle::optional<paddle::Tensor>& down_proj_bias,
192192
const bool norm_topk_prob,
193193
const float routed_scaling_factor,
194194
const int num_rows,
@@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
206206
combine_prmt_back_kernel<<<gridx, threads, 0, stream>>>(
207207
ffn_out.data<data_t>(),
208208
output->data<data_t>(),
209-
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
209+
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
210210
expert_scales_float.data<float>(),
211211
permute_indices_per_token.data<int32_t>(),
212212
top_k_indices.data<int>(),
@@ -223,7 +223,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
223223
const paddle::Tensor& expert_scales_float, // dst_weights
224224
const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token
225225
const paddle::Tensor& top_k_indices, // dst_indices
226-
const paddle::optional<paddle::Tensor>& ffn2_bias,
226+
const paddle::optional<paddle::Tensor>& down_proj_bias,
227227
const bool norm_topk_prob,
228228
const float routed_scaling_factor) {
229229

@@ -242,7 +242,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
242242
expert_scales_float,
243243
permute_indices_per_token,
244244
top_k_indices,
245-
ffn2_bias,
245+
down_proj_bias,
246246
norm_topk_prob,
247247
routed_scaling_factor,
248248
num_rows,
@@ -255,7 +255,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
255255
expert_scales_float,
256256
permute_indices_per_token,
257257
top_k_indices,
258-
ffn2_bias,
258+
down_proj_bias,
259259
norm_topk_prob,
260260
routed_scaling_factor,
261261
num_rows,
@@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x,
274274
const int64_t *topk_idx,
275275
const float *topk_weights,
276276
const int *token_nums_per_expert,
277-
const float *ffn1_in_scale,
277+
const float *up_gate_proj_in_scale,
278278
const int moe_topk,
279279
const int num_rows,
280280
const int token_nums_this_rank,
@@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x,
327327
// cp x
328328
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
329329
Load<T, vec_size>(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec);
330-
if (ffn1_in_scale) {
330+
if (up_gate_proj_in_scale) {
331331
for (int i = 0; i < vec_size; i++) {
332-
float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast<float>(src_vec[i]);
332+
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
333333
if (RoundType == 0) {
334334
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
335335
} else {
@@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
353353
const paddle::Tensor& topk_ids,
354354
const paddle::Tensor& topk_weights,
355355
const paddle::Tensor& token_nums_per_expert,
356-
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
356+
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
357357
const std::string& moe_quant_type,
358358
const int moe_topk,
359359
const int num_rows,
@@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
383383
topk_ids.data<int64_t>(),
384384
topk_weights.data<float>(),
385385
token_nums_per_expert.data<int>(),
386-
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
386+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
387387
moe_topk,
388388
num_rows,
389389
token_nums_this_rank,
@@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
404404
topk_ids.data<int64_t>(),
405405
topk_weights.data<float>(),
406406
token_nums_per_expert.data<int>(),
407-
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
407+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
408408
moe_topk,
409409
num_rows,
410410
token_nums_this_rank,
@@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
427427
topk_ids.data<int64_t>(),
428428
topk_weights.data<float>(),
429429
token_nums_per_expert.data<int>(),
430-
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
430+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
431431
moe_topk,
432432
num_rows,
433433
token_nums_this_rank,
@@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
448448
topk_ids.data<int64_t>(),
449449
topk_weights.data<float>(),
450450
token_nums_per_expert.data<int>(),
451-
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
451+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
452452
moe_topk,
453453
num_rows,
454454
token_nums_this_rank,
@@ -472,7 +472,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
472472
const paddle::Tensor& input,
473473
const paddle::Tensor& topk_ids,
474474
const paddle::Tensor& topk_weights,
475-
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
475+
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
476476
const std::vector<int>& token_nums_per_expert,
477477
const int token_nums_this_rank,
478478
const std::string& moe_quant_type) {
@@ -516,7 +516,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
516516
topk_ids,
517517
topk_weights,
518518
num_experts_per_rank_tensor,
519-
ffn1_in_scale,
519+
up_gate_proj_in_scale,
520520
moe_quant_type,
521521
moe_topk,
522522
num_rows,
@@ -536,7 +536,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
536536
topk_ids,
537537
topk_weights,
538538
num_experts_per_rank_tensor,
539-
ffn1_in_scale,
539+
up_gate_proj_in_scale,
540540
moe_quant_type,
541541
moe_topk,
542542
num_rows,
@@ -568,7 +568,7 @@ std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
568568
const std::vector<int64_t>& input_shape,
569569
const std::vector<int64_t>& topk_ids_shape,
570570
const std::vector<int64_t>& topk_weights_shape,
571-
const paddle::optional<std::vector<int64_t>>& ffn1_in_scale_dtype,
571+
const paddle::optional<std::vector<int64_t>>& up_gate_proj_in_scale_dtype,
572572
const std::vector<int>& token_nums_per_expert,
573573
const int token_nums_this_rank) {
574574
int token_rows = -1;
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
610610

611611
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
612612
.Inputs({"input", "topk_ids", "topk_weights",
613-
paddle::Optional("ffn1_in_scale")})
613+
paddle::Optional("up_gate_proj_in_scale")})
614614
.Outputs({"permute_input",
615615
"permute_indices_per_token",
616616
"token_nums_per_expert_cumsum",

0 commit comments

Comments
 (0)