diff --git a/custom_ops/gpu_ops/token_penalty_multi_scores.cu b/custom_ops/gpu_ops/token_penalty_multi_scores.cu index a930791e77..7530b7949b 100644 --- a/custom_ops/gpu_ops/token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/token_penalty_multi_scores.cu @@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits, const int64_t *min_len, const int64_t *eos_token_id, const int64_t bs, - const int64_t length, - const int64_t end_length) { + const int64_t length_logits, + const int64_t length_eos_token_id) { int bi = threadIdx.x; if (bi >= bs) return; if (cur_len[bi] < 0) { return; } if (cur_len[bi] < min_len[bi]) { - for (int i = 0; i < end_length; i++) { - logits[bi * length + eos_token_id[i]] = -1e10; + for (int i = 0; i < length_eos_token_id; i++) { + logits[bi * length_logits + eos_token_id[i]] = -1e10; } } } @@ -41,61 +41,84 @@ __global__ inline void min_length_logits_process( const int64_t *min_len, const int64_t *eos_token_id, const int64_t bs, - const int64_t length, - const int64_t end_length) { + const int64_t length_logits, + const int64_t length_eos_token_id) { int bi = threadIdx.x; if (bi >= bs) return; if (cur_len[bi] < 0) { return; } if (cur_len[bi] < min_len[bi]) { - for (int i = 0; i < end_length; i++) { - logits[bi * length + eos_token_id[i]] = -1e4; + for (int i = 0; i < length_eos_token_id; i++) { + logits[bi * length_logits + eos_token_id[i]] = -1e4; } } } -__global__ void update_repeat_times(const int64_t *pre_ids, +__global__ void update_repeat_times(const int64_t *input_ids, + const int64_t *first_token_ids, + const int64_t *pre_ids, const int64_t *cur_len, int *repeat_times, + int *is_repeated, const int64_t bs, - const int64_t length, - const int64_t length_id) { - int bi = blockIdx.x; + const int64_t length_logits, + const int64_t length_pre_ids, + const int64_t length_input_ids) { + int64_t bi = blockIdx.x; if (cur_len[bi] < 0) { return; } - int tid = threadIdx.x; - const int64_t *pre_ids_now = pre_ids + bi * 
length_id; - int *repeat_times_now = repeat_times + bi * length; - for (int i = tid; i < length_id; i += blockDim.x) { - int64_t id = pre_ids_now[i]; - if (id < 0) break; - atomicAdd(&repeat_times_now[id], 1); + int64_t tid = threadIdx.x; + const int64_t *input_ids_now = input_ids + bi * length_input_ids; + const int64_t *pre_ids_now = pre_ids + bi * length_pre_ids; + int *repeat_times_now = repeat_times + bi * length_logits; + int *is_repeated_now = is_repeated + bi * length_logits; + const int64_t loop_len = length_input_ids > length_pre_ids ? length_input_ids : length_pre_ids; + for (int64_t i = tid; i < loop_len; i += blockDim.x) { + if (i < length_pre_ids) { + int64_t id = pre_ids_now[i]; + if (id >= 0) { + atomicAdd(&repeat_times_now[id], 1); + atomicAdd(&is_repeated_now[id], 1); + } + } + if (i > 0 && i < length_input_ids) { + int64_t id = input_ids_now[i]; + if (id >= 0) { + atomicAdd(&is_repeated_now[id], 1); + } + } + } + if (tid == 0) { + atomicAdd(&is_repeated_now[first_token_ids[bi]], 1); + } } } template <typename T> __global__ void update_value_by_repeat_times(const int *repeat_times, + const int *is_repeated, const T *penalty_scores, const T *frequency_score, const T *presence_score, const float *temperatures, T *logits, const int64_t bs, - const int64_t length) { + const int64_t length_logits) { int bi = blockIdx.x; int tid = threadIdx.x; - T *logits_now = logits + bi * length; - const int *repeat_times_now = repeat_times + bi * length; + T *logits_now = logits + bi * length_logits; + const int *repeat_times_now = repeat_times + bi * length_logits; float alpha = static_cast<float>(penalty_scores[bi]); float beta = static_cast<float>(frequency_score[bi]); float gamma = static_cast<float>(presence_score[bi]); - for (int i = tid; i < length; i += blockDim.x) { + for (int i = tid; i < length_logits; i += blockDim.x) { int times = repeat_times_now[i]; float logit_now = static_cast<float>(logits_now[i]); - if (times != 0) { + if (is_repeated[bi * length_logits + i] != 0) { logit_now = logit_now < 0 ?
logit_now * alpha : logit_now / alpha; + } + if (times != 0) { logit_now = logit_now - times * beta - gamma; } logits_now[i] = static_cast(logit_now / temperatures[bi]); @@ -106,20 +129,22 @@ template __global__ void ban_bad_words(T *logits, const int64_t *bad_words_list, const int64_t bs, - const int64_t length, - const int64_t bad_words_length) { + const int64_t length_logits, + const int64_t length_bad_words) { const int bi = blockIdx.x; int tid = threadIdx.x; - T *logits_now = logits + bi * length; - for (int i = tid; i < bad_words_length; i += blockDim.x) { + T *logits_now = logits + bi * length_logits; + for (int i = tid; i < length_bad_words; i += blockDim.x) { const int64_t bad_words_token_id = bad_words_list[i]; - if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + if (bad_words_token_id >= length_logits || bad_words_token_id < 0) continue; logits_now[bad_words_token_id] = -1e10; } } template -void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, +void token_penalty_multi_scores_kernel(const paddle::Tensor &input_ids, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &pre_ids, const paddle::Tensor &logits, const paddle::Tensor &penalty_scores, const paddle::Tensor &frequency_score, @@ -141,12 +166,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, std::vector shape = logits.shape(); auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); + auto is_repeated = + paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); int64_t bs = shape[0]; - int64_t length = shape[1]; - int64_t length_id = pre_ids.shape()[1]; - int64_t length_bad_words = bad_tokens.shape()[0]; - int64_t end_length = eos_token_id.shape()[0]; + int64_t length_logits = shape[1]; + int64_t length_pre_ids = pre_ids.shape()[1]; + int64_t length_bad_words = bad_tokens.shape()[0]; + int64_t length_eos_token_id = eos_token_id.shape()[0]; + int64_t length_input_ids = input_ids.shape()[1]; 
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; min_length_logits_process<<<1, block_size, 0, cu_stream>>>( @@ -156,24 +184,28 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, min_len.data(), eos_token_id.data(), bs, - length, - end_length); + length_logits, + length_eos_token_id); - block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; + block_size = (length_pre_ids + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; #ifdef PADDLE_WITH_COREX block_size = std::min(block_size, 512); #else block_size = min(block_size, 512); #endif update_repeat_times<<>>( + input_ids.data(), + first_token_ids.data(), pre_ids.data(), cur_len.data(), repeat_times.data(), + is_repeated.data(), bs, - length, - length_id); + length_logits, + length_pre_ids, + length_input_ids); - block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; + block_size = (length_logits + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; #ifdef PADDLE_WITH_COREX block_size = std::min(block_size, 512); #else @@ -181,6 +213,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, #endif update_value_by_repeat_times<<>>( repeat_times.data(), + is_repeated.data(), reinterpret_cast( const_cast(penalty_scores.data())), reinterpret_cast( @@ -191,7 +224,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, reinterpret_cast( const_cast(logits.data())), bs, - length); + length_logits); block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; #ifdef PADDLE_WITH_COREX @@ -204,11 +237,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, const_cast(logits.data())), bad_tokens.data(), bs, - length, + length_logits, length_bad_words); } -void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, +void TokenPenaltyMultiScores(const paddle::Tensor& input_ids, + const paddle::Tensor& first_token_ids, + const paddle::Tensor &pre_ids, const paddle::Tensor &logits, const paddle::Tensor &penalty_scores, const 
paddle::Tensor &frequency_scores, @@ -221,7 +256,9 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, switch (logits.type()) { case paddle::DataType::BFLOAT16: { return token_penalty_multi_scores_kernel< - paddle::DataType::BFLOAT16>(pre_ids, + paddle::DataType::BFLOAT16>(input_ids, + first_token_ids, + pre_ids, logits, penalty_scores, frequency_scores, @@ -234,6 +271,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, } case paddle::DataType::FLOAT16: { return token_penalty_multi_scores_kernel( + input_ids, + first_token_ids, pre_ids, logits, penalty_scores, @@ -247,6 +286,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, } case paddle::DataType::FLOAT32: { return token_penalty_multi_scores_kernel( + input_ids, + first_token_ids, pre_ids, logits, penalty_scores, @@ -268,7 +309,9 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, } PD_BUILD_STATIC_OP(get_token_penalty_multi_scores) - .Inputs({"pre_ids", + .Inputs({"input_ids", + "first_token_ids", + "pre_ids", "logits", "penalty_scores", "frequency_scores", diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 41a96ee1e8..400996f749 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -28,6 +28,8 @@ class SamplingMetadata: temperature: paddle.Tensor + input_ids: paddle.Tensor + first_token_ids: paddle.Tensor pre_token_ids: paddle.Tensor eos_token_ids: paddle.Tensor frequency_penalties: paddle.Tensor diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index f6b512e0ce..b3a854da8d 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -20,6 +20,8 @@ def apply_penalty_multi_scores( + input_ids: 
paddle.Tensor, + first_token_ids: paddle.Tensor, pre_token_ids: paddle.Tensor, logits: paddle.Tensor, repetition_penalties: paddle.Tensor, @@ -38,6 +40,8 @@ def apply_penalty_multi_scores( from fastdeploy.model_executor.ops.gpu import \ get_token_penalty_multi_scores logits = get_token_penalty_multi_scores( + input_ids, + first_token_ids, pre_token_ids, logits, repetition_penalties, diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 162bbc347f..3ca35528f8 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -252,6 +252,8 @@ def forward_cuda( logits = self.processor.apply_token_mask(logits, skip_idx_list) logits = apply_penalty_multi_scores( + sampling_metadata.input_ids, + sampling_metadata.first_token_ids, sampling_metadata.pre_token_ids, logits, sampling_metadata.repetition_penalties, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 71b9547844..5bfd70e2c4 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -608,6 +608,8 @@ def _prepare_inputs(self) -> None: top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], step_idx=self.share_inputs["step_idx"], + input_ids=self.share_inputs["input_ids"], + first_token_ids=self.share_inputs["first_token_ids"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], presence_penalties=self.share_inputs["presence_score"], diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index 6817023927..c115d7b2cd 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -158,6 +158,9 @@ def update_chunked_prefill(self, tasks: list[any]) -> None: token_chunk_size = inputs["input_ids"].shape[1] self.share_inputs["input_ids"][ idx:idx + 1, :token_chunk_size] = 
inputs["input_ids"] + self.share_inputs["first_token_ids"][idx:idx + + 1] = inputs["input_ids"][ 0:1, :1] self.share_inputs["seq_lens_this_time"][idx:idx + 1] = token_chunk_size self.share_inputs['seq_lens_encoder'][idx:idx + @@ -716,6 +719,9 @@ def get_numeric_value(task, key, default_value): task.set("start_idx", token_chunk_size) self.share_inputs["input_ids"][ idx:idx + 1, :token_chunk_size] = inputs["input_ids"] + self.share_inputs["first_token_ids"][idx:idx + + 1] = inputs["input_ids"][ + 0:1, :1] self.share_inputs["seq_lens_this_time"][idx:idx + 1] = token_chunk_size self.share_inputs["seq_lens_encoder"][idx:idx + @@ -736,6 +742,9 @@ def get_numeric_value(task, key, default_value): length = inputs["input_ids"].shape[1] self.share_inputs["input_ids"][ idx:idx + 1, :length] = inputs["input_ids"] + self.share_inputs["first_token_ids"][idx:idx + + 1] = inputs["input_ids"][ + 0:1, :1] self.share_inputs["seq_lens_this_time"][idx:idx + 1] = length self.share_inputs["seq_lens_encoder"][idx:idx + 1] = length self.share_inputs["step_seq_lens_encoder"][idx:idx + @@ -841,6 +850,8 @@ def pre_process(self) -> None: temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], step_idx=self.share_inputs["step_idx"], + input_ids=self.share_inputs["input_ids"], + first_token_ids=self.share_inputs["first_token_ids"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], presence_penalties=self.share_inputs["presence_score"], diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 8be2d9d47c..7194811a47 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -544,6 +544,8 @@ def _prepare_inputs(self) -> None: top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], step_idx=self.share_inputs["step_idx"], + input_ids=self.share_inputs["input_ids"], +
first_token_ids=self.share_inputs["first_token_ids"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], presence_penalties=self.share_inputs["presence_score"], diff --git a/test/layers/test_sampler.py b/test/layers/test_sampler.py index 2887400d06..f380d07102 100644 --- a/test/layers/test_sampler.py +++ b/test/layers/test_sampler.py @@ -57,6 +57,8 @@ def _create_default_sampling_metadata( top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"), + input_ids=_create_tokens_tensor(batch_size, max_seq_len), + first_token_ids=_create_tokens_tensor(batch_size, max_seq_len)[:, :1], step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"),