-
Notifications
You must be signed in to change notification settings - Fork 564
[Feature] support prompt repetition_penalty #2806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 5 commits
7f3ddc4
f7ee8cb
938b77a
0789316
d0a625c
7c04502
e07e350
ffe7f6c
1aa9ce9
3420694
430b7b8
f9c31b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits, | |
const int64_t *min_len, | ||
const int64_t *eos_token_id, | ||
const int64_t bs, | ||
const int64_t length, | ||
const int64_t end_length) { | ||
const int64_t length_logits, | ||
const int64_t length_eos_token_id) { | ||
int bi = threadIdx.x; | ||
if (bi >= bs) return; | ||
if (cur_len[bi] < 0) { | ||
return; | ||
} | ||
if (cur_len[bi] < min_len[bi]) { | ||
for (int i = 0; i < end_length; i++) { | ||
logits[bi * length + eos_token_id[i]] = -1e10; | ||
for (int i = 0; i < length_eos_token_id; i++) { | ||
logits[bi * length_logits + eos_token_id[i]] = -1e10; | ||
} | ||
} | ||
} | ||
|
@@ -41,61 +41,82 @@ __global__ inline void min_length_logits_process<half>( | |
const int64_t *min_len, | ||
const int64_t *eos_token_id, | ||
const int64_t bs, | ||
const int64_t length, | ||
const int64_t end_length) { | ||
const int64_t length_logits, | ||
const int64_t length_eos_token_id) { | ||
int bi = threadIdx.x; | ||
if (bi >= bs) return; | ||
if (cur_len[bi] < 0) { | ||
return; | ||
} | ||
if (cur_len[bi] < min_len[bi]) { | ||
for (int i = 0; i < end_length; i++) { | ||
logits[bi * length + eos_token_id[i]] = -1e4; | ||
for (int i = 0; i < length_eos_token_id; i++) { | ||
logits[bi * length_logits + eos_token_id[i]] = -1e4; | ||
} | ||
} | ||
} | ||
|
||
__global__ void update_repeat_times(const int64_t *pre_ids, | ||
const int64_t *prompt_ids, | ||
const int64_t *prompt_len, | ||
const int64_t *cur_len, | ||
int *repeat_times, | ||
int *is_repeated, | ||
const int64_t bs, | ||
const int64_t length, | ||
const int64_t length_id) { | ||
int bi = blockIdx.x; | ||
const int64_t length_logits, | ||
const int64_t length_pre_ids, | ||
const int64_t length_prompt_ids) { | ||
int64_t bi = blockIdx.x; | ||
if (cur_len[bi] < 0) { | ||
return; | ||
} | ||
int tid = threadIdx.x; | ||
const int64_t *pre_ids_now = pre_ids + bi * length_id; | ||
int *repeat_times_now = repeat_times + bi * length; | ||
for (int i = tid; i < length_id; i += blockDim.x) { | ||
int64_t id = pre_ids_now[i]; | ||
if (id < 0) break; | ||
atomicAdd(&repeat_times_now[id], 1); | ||
const int64_t prompt_len_now = prompt_len[bi]; | ||
int64_t tid = threadIdx.x; | ||
const int64_t *prompt_now = prompt_ids + bi * length_prompt_ids; | ||
const int64_t *pre_ids_now = pre_ids + bi * length_pre_ids; | ||
int *repeat_times_now = repeat_times + bi * length_logits; | ||
int *is_repeated_now = is_repeated + bi * length_logits; | ||
const int64_t loop_len = prompt_len_now > length_pre_ids ? prompt_len_now : length_pre_ids; | ||
for (int64_t i = tid; i < loop_len; i += blockDim.x) { | ||
if (i < length_pre_ids) { | ||
int64_t id = pre_ids_now[i]; | ||
if (id >= 0) { | ||
atomicAdd(&repeat_times_now[id], 1); | ||
atomicAdd(&is_repeated_now[id], 1); | ||
} | ||
} | ||
if (i < prompt_len_now) { | ||
int64_t id = prompt_ids[i]; | ||
if (id >= 0) { | ||
atomicAdd(&is_repeated_now[id], 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 出现在prompt的token不用记入repeat_times吗,这块逻辑是否与模型组确认 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void update_value_by_repeat_times(const int *repeat_times, | ||
const int *is_repeated, | ||
const T *penalty_scores, | ||
const T *frequency_score, | ||
const T *presence_score, | ||
const float *temperatures, | ||
T *logits, | ||
const int64_t bs, | ||
const int64_t length) { | ||
const int64_t length_logits) { | ||
int bi = blockIdx.x; | ||
int tid = threadIdx.x; | ||
T *logits_now = logits + bi * length; | ||
const int *repeat_times_now = repeat_times + bi * length; | ||
T *logits_now = logits + bi * length_logits; | ||
const int *repeat_times_now = repeat_times + bi * length_logits; | ||
float alpha = static_cast<float>(penalty_scores[bi]); | ||
float beta = static_cast<float>(frequency_score[bi]); | ||
float gamma = static_cast<float>(presence_score[bi]); | ||
for (int i = tid; i < length; i += blockDim.x) { | ||
for (int i = tid; i < length_logits; i += blockDim.x) { | ||
int times = repeat_times_now[i]; | ||
float logit_now = static_cast<float>(logits_now[i]); | ||
if (times != 0) { | ||
if (is_repeated[i] != 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 用is_repeated判断和之前的区别是? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha; | ||
} | ||
if (times != 0) { | ||
logit_now = logit_now - times * beta - gamma; | ||
} | ||
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]); | ||
|
@@ -106,20 +127,22 @@ template <typename T> | |
__global__ void ban_bad_words(T *logits, | ||
const int64_t *bad_words_list, | ||
const int64_t bs, | ||
const int64_t length, | ||
const int64_t bad_words_length) { | ||
const int64_t length_logits, | ||
const int64_t length_bad_words) { | ||
const int bi = blockIdx.x; | ||
int tid = threadIdx.x; | ||
T *logits_now = logits + bi * length; | ||
for (int i = tid; i < bad_words_length; i += blockDim.x) { | ||
T *logits_now = logits + bi * length_logits; | ||
for (int i = tid; i < length_bad_words; i += blockDim.x) { | ||
const int64_t bad_words_token_id = bad_words_list[i]; | ||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue; | ||
if (bad_words_token_id >= length_logits || bad_words_token_id < 0) continue; | ||
logits_now[bad_words_token_id] = -1e10; | ||
} | ||
} | ||
|
||
template <paddle::DataType D> | ||
void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, | ||
const paddle::Tensor &prompt_ids, | ||
const paddle::Tensor &prompt_len, | ||
const paddle::Tensor &logits, | ||
const paddle::Tensor &penalty_scores, | ||
const paddle::Tensor &frequency_score, | ||
|
@@ -141,12 +164,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, | |
std::vector<int64_t> shape = logits.shape(); | ||
auto repeat_times = | ||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); | ||
auto is_repeated = | ||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); | ||
int64_t bs = shape[0]; | ||
int64_t length = shape[1]; | ||
int64_t length_id = pre_ids.shape()[1]; | ||
int64_t length_bad_words = bad_tokens.shape()[0]; | ||
|
||
int64_t end_length = eos_token_id.shape()[0]; | ||
int64_t length_logits = shape[1]; | ||
int64_t length_pre_ids = pre_ids.shape()[1]; | ||
int64_t length_bad_words = bad_tokens.shape()[0]; | ||
int64_t length_eos_token_id = eos_token_id.shape()[0]; | ||
int64_t length_prompt_ids = prompt_ids.shape()[1]; | ||
|
||
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
min_length_logits_process<<<1, block_size, 0, cu_stream>>>( | ||
|
@@ -156,31 +182,36 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, | |
min_len.data<int64_t>(), | ||
eos_token_id.data<int64_t>(), | ||
bs, | ||
length, | ||
end_length); | ||
length_logits, | ||
length_eos_token_id); | ||
|
||
block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
block_size = (length_pre_ids + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
#ifdef PADDLE_WITH_COREX | ||
block_size = std::min(block_size, 512); | ||
#else | ||
block_size = min(block_size, 512); | ||
#endif | ||
update_repeat_times<<<bs, block_size, 0, cu_stream>>>( | ||
pre_ids.data<int64_t>(), | ||
prompt_ids.data<int64_t>(), | ||
prompt_len.data<int64_t>(), | ||
cur_len.data<int64_t>(), | ||
repeat_times.data<int>(), | ||
is_repeated.data<int>(), | ||
bs, | ||
length, | ||
length_id); | ||
length_logits, | ||
length_pre_ids, | ||
length_prompt_ids); | ||
|
||
block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
block_size = (length_logits + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
#ifdef PADDLE_WITH_COREX | ||
block_size = std::min(block_size, 512); | ||
#else | ||
block_size = min(block_size, 512); | ||
#endif | ||
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>( | ||
repeat_times.data<int>(), | ||
is_repeated.data<int>(), | ||
reinterpret_cast<DataType_ *>( | ||
const_cast<data_t *>(penalty_scores.data<data_t>())), | ||
reinterpret_cast<DataType_ *>( | ||
|
@@ -191,7 +222,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, | |
reinterpret_cast<DataType_ *>( | ||
const_cast<data_t *>(logits.data<data_t>())), | ||
bs, | ||
length); | ||
length_logits); | ||
|
||
block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
#ifdef PADDLE_WITH_COREX | ||
|
@@ -204,11 +235,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, | |
const_cast<data_t *>(logits.data<data_t>())), | ||
bad_tokens.data<int64_t>(), | ||
bs, | ||
length, | ||
length_logits, | ||
length_bad_words); | ||
} | ||
|
||
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, | ||
const paddle::Tensor &prompt_ids, | ||
const paddle::Tensor &prompt_len, | ||
const paddle::Tensor &logits, | ||
const paddle::Tensor &penalty_scores, | ||
const paddle::Tensor &frequency_scores, | ||
|
@@ -222,6 +255,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, | |
case paddle::DataType::BFLOAT16: { | ||
return token_penalty_multi_scores_kernel< | ||
paddle::DataType::BFLOAT16>(pre_ids, | ||
prompt_ids, | ||
prompt_len, | ||
logits, | ||
penalty_scores, | ||
frequency_scores, | ||
|
@@ -233,30 +268,34 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, | |
eos_token_id); | ||
} | ||
case paddle::DataType::FLOAT16: { | ||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>( | ||
pre_ids, | ||
logits, | ||
penalty_scores, | ||
frequency_scores, | ||
presence_scores, | ||
temperatures, | ||
bad_tokens, | ||
cur_len, | ||
min_len, | ||
eos_token_id); | ||
return token_penalty_multi_scores_kernel< | ||
paddle::DataType::FLOAT16>(pre_ids, | ||
prompt_ids, | ||
prompt_len, | ||
logits, | ||
penalty_scores, | ||
frequency_scores, | ||
presence_scores, | ||
temperatures, | ||
bad_tokens, | ||
cur_len, | ||
min_len, | ||
eos_token_id); | ||
} | ||
case paddle::DataType::FLOAT32: { | ||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>( | ||
pre_ids, | ||
logits, | ||
penalty_scores, | ||
frequency_scores, | ||
presence_scores, | ||
temperatures, | ||
bad_tokens, | ||
cur_len, | ||
min_len, | ||
eos_token_id); | ||
return token_penalty_multi_scores_kernel< | ||
paddle::DataType::FLOAT32>(pre_ids, | ||
prompt_ids, | ||
prompt_len, | ||
logits, | ||
penalty_scores, | ||
frequency_scores, | ||
presence_scores, | ||
temperatures, | ||
bad_tokens, | ||
cur_len, | ||
min_len, | ||
eos_token_id); | ||
} | ||
default: { | ||
PD_THROW( | ||
|
@@ -269,6 +308,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, | |
|
||
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores) | ||
.Inputs({"pre_ids", | ||
"prompt_ids", | ||
"prompt_len", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这两个是否要作为可选输入?兼容之前的模型 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 计算逻辑上本来就可以兼容,prompt_ids全传-1或者prompt_len全给0即可。 |
||
"logits", | ||
"penalty_scores", | ||
"frequency_scores", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是prompt_now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,我修复下