diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 35e02e0149..4f4bb29564 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -480,14 +480,6 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     std::string output_dtype,
     std::string activation_type);
 
-paddle::Tensor MoeFusedHadamardQuantFp8Func(
-    const paddle::Tensor &input,
-    const paddle::Tensor &scale,
-    const paddle::Tensor &topk_ids,
-    const int top_k,
-    const int intermediate_size,
-    const bool tiled);
-
 paddle::Tensor FusedHadamardQuantFp8Func(
     const paddle::Tensor &input,
     const float scale);
@@ -518,6 +510,213 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
 
 void free_shared_buffer(int64_t buffer);
 
+// speculative decoding kernels
+std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& draft_tokens,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& token_num,
+    const paddle::Tensor& seq_len,
+    const paddle::Tensor& seq_lens_encoder);
+
+std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder);
+
+std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
+    const paddle::Tensor& output_cum_offsets_tmp,
+    const paddle::Tensor& out_token_num,
+    const paddle::Tensor& seq_lens_output,
+    const int max_seq_len);
+
+void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
+                                 const paddle::Tensor &logits,
+                                 const paddle::Tensor &penalty_scores,
+                                 const paddle::Tensor &frequency_scores,
+                                 const paddle::Tensor &presence_scores,
+                                 const paddle::Tensor &temperatures,
+                                 const paddle::Tensor &bad_tokens,
+                                 const paddle::Tensor &cur_len,
+                                 const paddle::Tensor &min_len,
+                                 const paddle::Tensor &eos_token_id,
+                                 const paddle::Tensor &seq_lens_this_time,
+                                 const paddle::Tensor &output_padding_offset,
+                                 const paddle::Tensor &output_cum_offsets,
+                                 const int max_seq_len);
+
+void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
+                               const paddle::Tensor &accept_num,
+                               const paddle::Tensor &pre_ids,
+                               const paddle::Tensor &step_idx,
+                               const paddle::Tensor &stop_flags,
+                               const paddle::Tensor &seq_lens,
+                               const paddle::Tensor &stop_seqs,
+                               const paddle::Tensor &stop_seqs_len,
+                               const paddle::Tensor &end_ids);
+
+void SpeculateVerify(
+    const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
+    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
+    const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &output_cum_offsets,
+    const paddle::Tensor &actual_candidate_len,
+    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
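+
+// Note: these ops are exposed to Python via PYBIND11_MODULE(fastdeploy_ops)
+// at the bottom of this file (e.g. SpeculateVerify -> "speculate_verify").
+// As of this patch, SpeculateVerify also reads the SPECULATE_VERIFY_USE_TOPK
+// environment variable (see speculate_verify.cu) to opt into the top-k
+// verification path.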
+
+void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
+                       const paddle::Tensor &seq_lens_decoder,
+                       const paddle::Tensor &not_need_stop,
+                       const paddle::Tensor &draft_tokens,
+                       const paddle::Tensor &actual_draft_token_nums,
+                       const paddle::Tensor &accept_tokens,
+                       const paddle::Tensor &accept_num,
+                       const paddle::Tensor &stop_flags,
+                       const paddle::Tensor &seq_lens_this_time,
+                       const paddle::Tensor &is_block_step,
+                       const paddle::Tensor &stop_nums);
+
+void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
+                                    const paddle::Tensor &accept_tokens,
+                                    const paddle::Tensor &accept_num,
+                                    const paddle::Tensor &stop_flags,
+                                    const paddle::Tensor &seq_lens_this_time,
+                                    const paddle::Tensor &seq_lens_encoder,
+                                    const paddle::Tensor &seq_lens_decoder,
+                                    const paddle::Tensor &step_idx);
+
+void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
+                                      const paddle::Tensor& accept_num,
+                                      const paddle::Tensor& not_need_stop,
+                                      int64_t rank_id,
+                                      bool save_each_rank);
+
+void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
+                              const paddle::Tensor& seq_lens_decoder);
+
+void NgramMatch(const paddle::Tensor &input_ids,
+                const paddle::Tensor &input_ids_len,
+                const paddle::Tensor &pre_ids,
+                const paddle::Tensor &step_idx,
+                const paddle::Tensor &draft_token_num,
+                const paddle::Tensor &draft_tokens,
+                const paddle::Tensor &seq_lens_this_time,
+                const paddle::Tensor &seq_lens_encoder,
+                const paddle::Tensor &seq_lens_decoder,
+                const paddle::Tensor &max_dec_len,
+                const int max_ngram_size,
+                const int max_draft_tokens);
+
+// MTP
+void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
+                           const paddle::Tensor& base_model_seq_lens_this_time,
+                           const paddle::Tensor& base_model_seq_lens_encoder,
+                           const paddle::Tensor& base_model_stop_flags);
+
+void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
+                          const paddle::Tensor& input_ids,
+                          const paddle::Tensor& stop_flags,
+                          const paddle::Tensor& seq_lens_this_time,
+                          const paddle::Tensor& seq_lens_encoder,
+                          const paddle::Tensor& seq_lens_decoder,
+                          const paddle::Tensor& step_idx,
+                          const paddle::Tensor& not_need_stop,
+                          const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& accept_tokens,
+                          const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_encoder,
+                          const paddle::Tensor& base_model_seq_lens_decoder,
+                          const paddle::Tensor& base_model_step_idx,
+                          const paddle::Tensor& base_model_stop_flags,
+                          const paddle::Tensor& base_model_is_block_step,
+                          const paddle::Tensor& base_model_draft_tokens,
+                          const int max_draft_token,
+                          const bool truncate_first_token,
+                          const bool splitwise_prefill);
+
+void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
+                      const paddle::Tensor& draft_tokens,
+                      const paddle::Tensor& pre_ids,
+                      const paddle::Tensor& seq_lens_this_time,
+                      const paddle::Tensor& seq_lens_encoder,
+                      const paddle::Tensor& seq_lens_decoder,
+                      const paddle::Tensor& step_idx,
+                      const paddle::Tensor& output_cum_offsets,
+                      const paddle::Tensor& stop_flags,
+                      const paddle::Tensor& not_need_stop,
+                      const paddle::Tensor& max_dec_len,
+                      const paddle::Tensor& end_ids,
+                      const paddle::Tensor& base_model_draft_tokens,
+                      const int max_seq_len,
+                      const int substep);
+
+std::vector<paddle::Tensor> EagleGetHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& accept_nums,
+    const paddle::Tensor& base_model_seq_lens_this_time,
+    const paddle::Tensor& base_model_seq_lens_encoder,
+    const int actual_draft_token_num);
+
+void MTPStepPaddle(
+    const paddle::Tensor &base_model_stop_flags,
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &batch_drop,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const int block_size,
+    const int max_draft_tokens);
+
+void SpeculateStepPaddle(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -687,9 +886,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * append_attn/get_block_shape_and_split_kv_block.cu
    * get_block_shape_and_split_kv_block
    */
-  // m.def("f_get_block_shape_and_split_kv_block",
-  // &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
-  // function");
+  m.def("get_block_shape_and_split_kv_block", &GetBlockShapeAndSplitKVBlock,
+        "get_block_shape_and_split_kv_block function");
 
   /**
    * get_padding_offset.cu
@@ -747,7 +945,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         "text_image_gather_scatter function");
 
   m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
-  m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
 
   m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
@@ -786,7 +983,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
         "dynamic_per_token_scaled_fp8_quant function", py::arg("out"),
         py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
-
   m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel,
         "decode_mla_write_cache function");
   m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel,
         "prefill_mla_write_cache function");
@@ -802,17 +998,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
         py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
         py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
-
-  m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
-        py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
-        py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
-
   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
 #endif
 
   m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
-
+
   m.def("all_reduce", &all_reduce, "all reduce function");
 
   m.def("dispose", &dispose, "del function for python");
@@ -830,4 +1021,39 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("open_mem_handle", &open_mem_handle,
"open_mem_handle"); m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta"); + + // speculative decoding Kernel + m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function"); + + m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function"); + + m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function"); + + m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); + + m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function"); + + m.def("speculate_verify",&SpeculateVerify, "speculate_verify function"); + + m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function"); + + m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function"); + + m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function"); + + m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function"); + + m.def("ngram_match", &NgramMatch, "ngram_match function"); + + m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function"); + + m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function"); + + m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function"); + + m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function"); + + m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); + + m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); } diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_token_penalty_multi_scores.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_token_penalty_multi_scores.cu index 4c8fc7a44b..e2a6405d47 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_token_penalty_multi_scores.cu @@ -246,7 +246,7 @@ void token_penalty_multi_scores_kernel( max_seq_len); } -void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, +void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids, const paddle::Tensor &logits, const paddle::Tensor &penalty_scores, const paddle::Tensor &frequency_scores, @@ -338,4 +338,4 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores) .Outputs({"logits_out"}) .Attrs({"max_seq_len: int"}) .SetInplaceMap({{"logits", "logits_out"}}) - .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores)); + .SetKernelFn(PD_KERNEL(SpecTokenPenaltyMultiScores)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index 7306843742..aa62356873 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -266,18 +266,6 @@ void SpeculateVerify( seed++; offset++; - auto err = cudaDeviceSynchronize(); - if (err != 0) { - printf("err %d\n", err); - } - - err = cudaGetLastError(); - - if (err != 0) { - printf("err %d\n", err); - } - - // printf("inited curand\n"); bool use_topk = false; char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); if (env_var) { diff --git 
diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
index f6b512e0ce..4da6280901 100644
--- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
+++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
@@ -123,7 +123,7 @@ def apply_speculative_penalty_multi_scores(
         from fastdeploy.model_executor.ops.gpu import \
             speculate_get_token_penalty_multi_scores
 
-        logits = speculate_get_token_penalty_multi_scores(
+        speculate_get_token_penalty_multi_scores(
             pre_token_ids,
             logits,
             repetition_penalties,
@@ -141,5 +141,5 @@ def apply_speculative_penalty_multi_scores(
         )
     else:
         raise NotImplementedError()
-
+    # logits was updated in place by the op above
     return logits
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 5ba3485746..825bca1b4b 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -100,7 +100,7 @@ def pre_process(
             seq_lens_this_time,
             seq_lens_encoder,
             seq_lens_decoder,
-        )
+        )[0]
         output_token_num = paddle.sum(seq_lens_output)
         output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output)
         output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index cf24a7e578..a173109a8c 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -497,6 +497,8 @@ def _prepare_inputs(self, full_hidden_states):
             self.main_model_inputs["seq_lens_encoder"],
             self.max_draft_token_num,
         )
+        if isinstance(target_hidden_states, list):
+            target_hidden_states = target_hidden_states[0]
         return target_hidden_states
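Note: after the apply_penalty_multi_scores.py change above,
speculate_get_token_penalty_multi_scores is called purely for its in-place
effect; its op registration declares SetInplaceMap({{"logits", "logits_out"}}),
so the caller keeps its existing logits handle instead of rebinding the return
value. A sketch of the calling convention, with argument names taken from the
C++ declaration (the tensors themselves are assumed to come from the sampler
state):

    from fastdeploy.model_executor.ops.gpu import \
        speculate_get_token_penalty_multi_scores

    # Mutates logits in place; the return value is intentionally ignored.
    speculate_get_token_penalty_multi_scores(
        pre_ids, logits, penalty_scores, frequency_scores, presence_scores,
        temperatures, bad_tokens, cur_len, min_len, eos_token_id,
        seq_lens_this_time, output_padding_offset, output_cum_offsets,
        max_seq_len)
    # logits now reflects the penalties; downstream sampling reads it directly.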