Skip to content

[Perf][MTP] Improve speculative decoding efficiency #2818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 245 additions & 19 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,6 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
std::string output_dtype,
std::string activation_type);

paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled);

paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale);
Expand Down Expand Up @@ -518,6 +510,213 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);

void free_shared_buffer(int64_t buffer);

// speculative decoding Kernel
// Forward declarations for speculative-decoding custom ops; implementations
// live in the corresponding .cu files under custom_ops/gpu_ops/.

// Computes padding offsets for a speculative-decoding batch from input ids,
// draft tokens and cumulative offsets. Returns a vector of result tensors
// (exact outputs defined at the implementation site).
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
    const paddle::Tensor& input_ids,
    const paddle::Tensor& draft_tokens,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& token_num,
    const paddle::Tensor& seq_len,
    const paddle::Tensor& seq_lens_encoder);

// Derives per-sequence output lengths from the this-step / encoder / decoder
// sequence-length tensors.
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder);

// Builds output padding offsets from temporary cumulative offsets, the total
// output token count and per-sequence output lengths; max_seq_len bounds the
// per-sequence extent.
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
    const paddle::Tensor& output_cum_offsets_tmp,
    const paddle::Tensor& out_token_num,
    const paddle::Tensor& seq_lens_output,
    const int max_seq_len);


// Applies token penalties (repetition/frequency/presence), temperature and
// min-length constraints to `logits` in place for speculative decoding
// (bound below as speculate_get_token_penalty_multi_scores; the op declares
// logits as its inplace output).
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                                 const paddle::Tensor &logits,
                                 const paddle::Tensor &penalty_scores,
                                 const paddle::Tensor &frequency_scores,
                                 const paddle::Tensor &presence_scores,
                                 const paddle::Tensor &temperatures,
                                 const paddle::Tensor &bad_tokens,
                                 const paddle::Tensor &cur_len,
                                 const paddle::Tensor &min_len,
                                 const paddle::Tensor &eos_token_id,
                                 const paddle::Tensor &seq_lens_this_time,
                                 const paddle::Tensor &output_padding_offset,
                                 const paddle::Tensor &output_cum_offsets,
                                 const int max_seq_len);

// Sets stop flags when accepted tokens match configured stop sequences
// (bound below as speculate_set_stop_value_multi_seqs).
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
                               const paddle::Tensor &accept_num,
                               const paddle::Tensor &pre_ids,
                               const paddle::Tensor &step_idx,
                               const paddle::Tensor &stop_flags,
                               const paddle::Tensor &seq_lens,
                               const paddle::Tensor &stop_seqs,
                               const paddle::Tensor &stop_seqs_len,
                               const paddle::Tensor &end_ids);

// Verifies draft tokens against target-model scores and records accepted
// tokens/counts. enable_topp toggles top-p sampling during verification;
// benchmark_mode presumably alters behavior for benchmarking -- see the
// implementation in speculate_decoding/speculate_verify.cu.
void SpeculateVerify(
    const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
    const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &output_cum_offsets,
    const paddle::Tensor &actual_candidate_len,
    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);

// Updates sequence lengths, draft-token buffers and stop bookkeeping after a
// verification step (v3 of the speculative update kernel).
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
                       const paddle::Tensor &actual_draft_token_nums,
                       const paddle::Tensor &accept_tokens,
                       const paddle::Tensor &accept_num,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &is_block_step,
                       const paddle::Tensor &stop_nums);

// Writes accepted tokens into pre_ids_all guided by stop flags and step
// indices.
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                    const paddle::Tensor &accept_tokens,
                                    const paddle::Tensor &accept_num,
                                    const paddle::Tensor &stop_flags,
                                    const paddle::Tensor &seq_lens_this_time,
                                    const paddle::Tensor &seq_lens_encoder,
                                    const paddle::Tensor &seq_lens_decoder,
                                    const paddle::Tensor &step_idx);

// Saves accepted output tokens for rank `rank_id`; when save_each_rank is
// true, every rank saves (bound below as speculate_save_output).
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
                                      const paddle::Tensor& accept_num,
                                      const paddle::Tensor& not_need_stop,
                                      int64_t rank_id,
                                      bool save_each_rank);


// Resets accepted-token counters between steps.
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder);

// N-gram matching draft proposer: searches prior context (input_ids/pre_ids)
// for up to max_ngram_size-gram matches and emits up to max_draft_tokens
// draft tokens per sequence.
void NgramMatch(const paddle::Tensor &input_ids,
                const paddle::Tensor &input_ids_len,
                const paddle::Tensor &pre_ids,
                const paddle::Tensor &step_idx,
                const paddle::Tensor &draft_token_num,
                const paddle::Tensor &draft_tokens,
                const paddle::Tensor &seq_lens_this_time,
                const paddle::Tensor &seq_lens_encoder,
                const paddle::Tensor &seq_lens_decoder,
                const paddle::Tensor &max_dec_len,
                const int max_ngram_size,
                const int max_draft_tokens);


// MTP (Multi-Token Prediction) draft-model ops.

// Post-step bookkeeping on the base model's draft-token and sequence-length
// buffers after a draft-model pass.
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_stop_flags);


// Prepares draft-model inputs from base-model state (accepted tokens, stop
// flags, sequence lengths). truncate_first_token / splitwise_prefill select
// prefill variants -- semantics defined at the implementation site.
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const bool truncate_first_token,
                          const bool splitwise_prefill);


// Propagates the draft model's next tokens into draft/base-model buffers for
// substep `substep` of a multi-token prediction round.
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                      const paddle::Tensor& draft_tokens,
                      const paddle::Tensor& pre_ids,
                      const paddle::Tensor& seq_lens_this_time,
                      const paddle::Tensor& seq_lens_encoder,
                      const paddle::Tensor& seq_lens_decoder,
                      const paddle::Tensor& step_idx,
                      const paddle::Tensor& output_cum_offsets,
                      const paddle::Tensor& stop_flags,
                      const paddle::Tensor& not_need_stop,
                      const paddle::Tensor& max_dec_len,
                      const paddle::Tensor& end_ids,
                      const paddle::Tensor& base_model_draft_tokens,
                      const int max_seq_len,
                      const int substep);



// Gathers the hidden states the EAGLE/MTP draft model consumes from the base
// model's output, guided by acceptance counts and sequence lengths.
std::vector<paddle::Tensor> EagleGetHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num);

// Block-table / free-list management step for the MTP draft model.
void MTPStepPaddle(
    const paddle::Tensor &base_model_stop_flags,
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &batch_drop,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const int block_size,
    const int max_draft_tokens);

// Full scheduling step for speculative decoding: stop handling, block
// allocation/recovery and token bookkeeping across the batch.
void SpeculateStepPaddle(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens);

PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
Expand Down Expand Up @@ -687,9 +886,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attn/get_block_shape_and_split_kv_block.cu
* get_block_shape_and_split_kv_block
*/
// m.def("f_get_block_shape_and_split_kv_block",
// &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
// function");
m.def("get_block_shape_and_split_kv_block",
&GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");

/**
* get_padding_offset.cu
Expand Down Expand Up @@ -747,7 +945,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"text_image_gather_scatter function");

m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);

m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);

m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
Expand Down Expand Up @@ -786,7 +983,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));

m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");

m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
Expand All @@ -802,17 +998,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");

m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");

m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
#endif

m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");

m.def("all_reduce", &all_reduce, "all reduce function");

m.def("dispose", &dispose, "del function for python");
Expand All @@ -830,4 +1021,39 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");

m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

// speculative decoding Kernel bindings: expose the speculative-decoding and
// MTP custom ops declared above to Python.
m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset,
      "speculate_get_padding_offset function");

m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput,
      "speculate_get_seq_lens_output function");

m.def("speculate_get_output_padding_offset", &SpeculateGetOutputPaddingOffset,
      "speculate_get_output_padding_offset function");

m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores,
      "speculate_get_token_penalty_multi_scores function");

m.def("speculate_set_stop_value_multi_seqs", &SpecGetStopFlagsMultiSeqs,
      "speculate_set_stop_value_multi_seqs function");

m.def("speculate_verify", &SpeculateVerify, "speculate_verify function");

// NOTE: docstring previously read "noaux_tc for Deepseekv3 MoE compute
// function" -- a copy/paste error from an unrelated binding; corrected to
// describe this op.
m.def("speculate_update_v3", &SpeculateUpdateV3,
      "speculate_update_v3 function");

m.def("speculate_set_value_by_flags_and_idx", &SpeculateSetValueByFlagsAndIdx,
      "speculate_set_value_by_flags_and_idx function");

m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic,
      "speculate_save_output function");

m.def("speculate_clear_accept_nums", &SpeculateClearAcceptNums,
      "speculate_clear_accept_nums function");

m.def("ngram_match", &NgramMatch, "ngram_match function");

m.def("draft_model_postprocess", &DraftModelPostprocess,
      "draft_model_postprocess function");

m.def("draft_model_preprocess", &DraftModelPreprocess,
      "draft_model_preprocess function");

m.def("draft_model_update", &DraftModelUpdate, "draft_model_update function");

m.def("eagle_get_hidden_states", &EagleGetHiddenStates,
      "eagle_get_hidden_states function");

m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");

m.def("speculate_step_paddle", &SpeculateStepPaddle,
      "speculate_step_paddle function");
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void token_penalty_multi_scores_kernel(
max_seq_len);
}

void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
Expand Down Expand Up @@ -338,4 +338,4 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
.SetKernelFn(PD_KERNEL(SpecTokenPenaltyMultiScores));
12 changes: 0 additions & 12 deletions custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,6 @@ void SpeculateVerify(
seed++;
offset++;

auto err = cudaDeviceSynchronize();
if (err != 0) {
printf("err %d\n", err);
}

err = cudaGetLastError();

if (err != 0) {
printf("err %d\n", err);
}

// printf("inited curand\n");
bool use_topk = false;
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def apply_speculative_penalty_multi_scores(
from fastdeploy.model_executor.ops.gpu import \
speculate_get_token_penalty_multi_scores

logits = speculate_get_token_penalty_multi_scores(
speculate_get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
Expand All @@ -141,5 +141,5 @@ def apply_speculative_penalty_multi_scores(
)
else:
raise NotImplementedError()

# inplace
return logits
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/pre_and_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def pre_process(
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
)
)[0]
output_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output)
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/spec_decode/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ def _prepare_inputs(self, full_hidden_states):
self.main_model_inputs["seq_lens_encoder"],
self.max_draft_token_num,
)
if isinstance(target_hidden_states, list):
target_hidden_states = target_hidden_states[0]

return target_hidden_states

Expand Down
Loading