Skip to content

[Feature] support prompt repetition_penalty #2806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
165 changes: 103 additions & 62 deletions custom_ops/gpu_ops/token_penalty_multi_scores.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits,
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t bs,
const int64_t length,
const int64_t end_length) {
const int64_t length_logits,
const int64_t length_eos_token_id) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[bi * length + eos_token_id[i]] = -1e10;
for (int i = 0; i < length_eos_token_id; i++) {
logits[bi * length_logits + eos_token_id[i]] = -1e10;
}
}
}
Expand All @@ -41,61 +41,82 @@ __global__ inline void min_length_logits_process<half>(
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t bs,
const int64_t length,
const int64_t end_length) {
const int64_t length_logits,
const int64_t length_eos_token_id) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[bi * length + eos_token_id[i]] = -1e4;
for (int i = 0; i < length_eos_token_id; i++) {
logits[bi * length_logits + eos_token_id[i]] = -1e4;
}
}
}

// Count per-token occurrences used by the repetition/frequency/presence
// penalties. Launched with one block per batch row (blockIdx.x == batch
// index); threads stride over max(prompt_len, length_pre_ids).
//
// Effects, per batch row bi:
//   repeat_times[bi][id] += 1  for each occurrence of id in pre_ids (generated
//                              tokens only; -1 entries are padding and skipped)
//   is_repeated[bi][id]  += 1  for each occurrence of id in pre_ids OR in the
//                              first prompt_len[bi] entries of prompt_ids
//
// Params:
//   pre_ids           [bs, length_pre_ids]    generated token ids, -1 padded
//   prompt_ids        [bs, length_prompt_ids] prompt token ids
//   prompt_len        [bs]                    valid prompt length per row
//   cur_len           [bs]                    decode length; < 0 = inactive slot
//   repeat_times      [bs, length_logits]     occurrence counts (generated)
//   is_repeated       [bs, length_logits]     occurrence counts (generated+prompt)
//
// NOTE(review): prompt tokens are recorded only in is_repeated, not in
// repeat_times — i.e. frequency penalty counts generated tokens only while
// the repetition penalty also covers the prompt. This split was questioned
// in review; confirm with the model team that it is intentional.
__global__ void update_repeat_times(const int64_t *pre_ids,
                                    const int64_t *prompt_ids,
                                    const int64_t *prompt_len,
                                    const int64_t *cur_len,
                                    int *repeat_times,
                                    int *is_repeated,
                                    const int64_t bs,
                                    const int64_t length_logits,
                                    const int64_t length_pre_ids,
                                    const int64_t length_prompt_ids) {
  int64_t bi = blockIdx.x;
  if (cur_len[bi] < 0) {  // inactive request slot: nothing to count
    return;
  }
  int64_t tid = threadIdx.x;
  const int64_t prompt_len_now = prompt_len[bi];
  // Per-batch base pointers into the flattened [bs, *] tensors.
  const int64_t *prompt_now = prompt_ids + bi * length_prompt_ids;
  const int64_t *pre_ids_now = pre_ids + bi * length_pre_ids;
  int *repeat_times_now = repeat_times + bi * length_logits;
  int *is_repeated_now = is_repeated + bi * length_logits;
  // One pass covers both spans: iterate to the longer of the two.
  const int64_t loop_len =
      prompt_len_now > length_pre_ids ? prompt_len_now : length_pre_ids;
  for (int64_t i = tid; i < loop_len; i += blockDim.x) {
    if (i < length_pre_ids) {
      int64_t id = pre_ids_now[i];
      if (id >= 0) {  // -1 marks padding in pre_ids
        atomicAdd(&repeat_times_now[id], 1);
        atomicAdd(&is_repeated_now[id], 1);
      }
    }
    if (i < prompt_len_now) {
      // BUGFIX (flagged in review): was `prompt_ids[i]`, which read batch 0's
      // prompt for every batch row; must use the per-batch pointer.
      int64_t id = prompt_now[i];
      if (id >= 0) {
        atomicAdd(&is_repeated_now[id], 1);
      }
    }
  }
}

// Apply repetition / frequency / presence penalties and temperature scaling
// to the logits in place. Launched with one block per batch row; threads
// stride over the vocabulary (length_logits).
//
// Per token i of batch row bi:
//   - if the token appeared in the prompt or generated ids
//     (is_repeated[bi][i] != 0): repetition penalty alpha — negative logits
//     are multiplied by alpha, positive logits divided by alpha;
//   - if it was generated (repeat_times[bi][i] != 0): subtract
//     times * beta (frequency) + gamma (presence);
//   - finally divide by the per-row temperature.
//
// Params:
//   repeat_times    [bs, length_logits] generated-token occurrence counts
//   is_repeated     [bs, length_logits] prompt+generated occurrence counts
//   penalty_scores  [bs] repetition penalty alpha per row
//   frequency_score [bs] frequency penalty beta per row
//   presence_score  [bs] presence penalty gamma per row
//   temperatures    [bs] sampling temperature per row
//   logits          [bs, length_logits] modified in place
template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
                                             const int *is_repeated,
                                             const T *penalty_scores,
                                             const T *frequency_score,
                                             const T *presence_score,
                                             const float *temperatures,
                                             T *logits,
                                             const int64_t bs,
                                             const int64_t length_logits) {
  int bi = blockIdx.x;
  int tid = threadIdx.x;
  T *logits_now = logits + bi * length_logits;
  const int *repeat_times_now = repeat_times + bi * length_logits;
  // BUGFIX: is_repeated must be offset per batch row like repeat_times;
  // indexing it with the raw vocab index read batch 0's flags for every row.
  const int *is_repeated_now = is_repeated + bi * length_logits;
  float alpha = static_cast<float>(penalty_scores[bi]);
  float beta = static_cast<float>(frequency_score[bi]);
  float gamma = static_cast<float>(presence_score[bi]);
  for (int i = tid; i < length_logits; i += blockDim.x) {
    int times = repeat_times_now[i];
    float logit_now = static_cast<float>(logits_now[i]);
    if (is_repeated_now[i] != 0) {
      // Standard repetition penalty: shrink toward zero regardless of sign.
      logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
    }
    if (times != 0) {
      logit_now = logit_now - times * beta - gamma;
    }
    logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
  }
}
// Force banned tokens to an effectively -inf logit so they can never be
// sampled. One block per batch row; threads stride over the bad-words list.
// Out-of-range ids (negative or >= vocab size) are ignored.
template <typename T>
__global__ void ban_bad_words(T *logits,
                              const int64_t *bad_words_list,
                              const int64_t bs,
                              const int64_t length_logits,
                              const int64_t length_bad_words) {
  const int batch_idx = blockIdx.x;
  T *row = logits + batch_idx * length_logits;
  for (int idx = threadIdx.x; idx < length_bad_words; idx += blockDim.x) {
    const int64_t token = bad_words_list[idx];
    if (token >= 0 && token < length_logits) {
      row[token] = -1e10;
    }
  }
}

template <paddle::DataType D>
void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const paddle::Tensor &prompt_ids,
const paddle::Tensor &prompt_len,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_score,
Expand All @@ -141,12 +164,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
auto is_repeated =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];

int64_t end_length = eos_token_id.shape()[0];
int64_t length_logits = shape[1];
int64_t length_pre_ids = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t length_eos_token_id = eos_token_id.shape()[0];
int64_t length_prompt_ids = prompt_ids.shape()[1];

int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
Expand All @@ -156,31 +182,36 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
length,
end_length);
length_logits,
length_eos_token_id);

block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (length_pre_ids + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
#endif
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
prompt_ids.data<int64_t>(),
prompt_len.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
is_repeated.data<int>(),
bs,
length,
length_id);
length_logits,
length_pre_ids,
length_prompt_ids);

block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (length_logits + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
#endif
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
is_repeated.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
Expand All @@ -191,7 +222,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bs,
length);
length_logits);

block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
Expand All @@ -204,11 +235,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bs,
length,
length_logits,
length_bad_words);
}

void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &prompt_ids,
const paddle::Tensor &prompt_len,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
Expand All @@ -222,6 +255,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<
paddle::DataType::BFLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
Expand All @@ -233,30 +268,34 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
eos_token_id);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT32>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
default: {
PD_THROW(
Expand All @@ -269,6 +308,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,

PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)
.Inputs({"pre_ids",
"prompt_ids",
"prompt_len",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个是否要作为可选输入?兼容之前的模型

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

计算逻辑上本来就可以兼容,prompt_ids全传-1或者prompt_len全给0即可。
不过现在其实没有暴露上层的兼容接口。

"logits",
"penalty_scores",
"frequency_scores",
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/model_executor/layers/sample/meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class SamplingMetadata:
presence_penalties: paddle.Tensor
repetition_penalties: paddle.Tensor

prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None

min_dec_lens: paddle.Tensor

bad_words_token_ids: paddle.Tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

def apply_penalty_multi_scores(
pre_token_ids: paddle.Tensor,
prompt_ids: paddle.Tensor,
prompt_lens: paddle.Tensor,
logits: paddle.Tensor,
repetition_penalties: paddle.Tensor,
frequency_penalties: paddle.Tensor,
Expand All @@ -39,6 +41,8 @@ def apply_penalty_multi_scores(
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
Expand Down Expand Up @@ -69,6 +73,8 @@ def apply_penalty_multi_scores(
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def forward_cuda(

logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
sampling_metadata.prompt_lens,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
Expand Down
Loading
Loading