added implementation of DRY sampler #6839


Closed · wants to merge 21 commits

Changes shown from 3 commits

Commits (21)
f64dea0  added implementation of DRY sampler (l3utterfly, Apr 25, 2024)
aea4ad0  fixed editor config check (l3utterfly, Apr 25, 2024)
4d603e3  added DRY implementation (l3utterfly, Apr 25, 2024)
75beda2  fixed various issues with sampler pointed out by original creator (l3utterfly, Apr 29, 2024)
85dadac  added parameter for DRY penalty range, separate from the original rep… (l3utterfly, Apr 29, 2024)
793e1e2  updated header def for dry sampler to match implementation (l3utterfly, Apr 29, 2024)
3caec6b  removed unused llama_context in dry sampler (l3utterfly, Apr 29, 2024)
49e078f  changed array size parameters to size_t (l3utterfly, Apr 29, 2024)
2f9a36a  Merge branch 'master' into dry-sampler (l3utterfly, Jul 29, 2024)
802ddd7  added sample_dry_impl (l3utterfly, Jul 29, 2024)
12bfa78  added llama_sample_dry_impl in header (l3utterfly, Jul 29, 2024)
0229fc8  added final new line for editor config check (l3utterfly, Jul 29, 2024)
236da59  fixed int/size_t comparison (l3utterfly, Jul 29, 2024)
e862def  use int32_t for dry_penalty_last_n due to negative value needed as co… (l3utterfly, Jul 29, 2024)
9105cf4  Add DRY sampling parameters to gpt_params and server_context (wwoodsTM, Aug 5, 2024)
20dc562  Delete pr-6839.diff (wwoodsTM, Aug 5, 2024)
d1676a1  Merge pull request #29 from wwoodsTM/test-dry-sampler (l3utterfly, Aug 6, 2024)
ed6b909  Merge branch 'master' into dry-sampler (l3utterfly, Aug 6, 2024)
6579e64  Attempt at slightly optimized vector of strings DRY implementation (wwoodsTM, Aug 6, 2024)
a18fb2f  Merge remote-tracking branch 'myfork/test-dry-sampler' into test-dry-… (wwoodsTM, Aug 6, 2024)
190898a  Merge pull request #30 from wwoodsTM/test-dry-sampler (l3utterfly, Aug 8, 2024)
16 changes: 15 additions & 1 deletion common/sampling.cpp
@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

// repetition penalties
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;

const bool penalize_nl = params.penalize_nl;

// DRY sampler parameters
const float dry_multiplier = params.dry_multiplier;
Reviewer: Too much indentation before the assignment operator for this block.

Author: The indentation here is trying to match the ones below from dry_allowed_length. What is the convention here?

Reviewer: Not sure what you mean? Both the blocks above the new parameters (ending with penalize_nl) and below them (starting with prev) have the equals signs left-aligned as closely as possible to the LHS, whereas the new parameters have three extra spaces.

But code style is really the maintainers' business. I don't care that much, just something I noticed.
const float dry_base = params.dry_base;
const int dry_allowed_length = params.dry_allowed_length;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;

@@ -309,10 +314,19 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

// repetition penalties
llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

// DRY penalties (multiplier > 0 means enabled)
if(dry_multiplier > 0.0f) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
}

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
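The calls above hand the samplers only the tail of the penalty window. A minimal sketch of that pointer arithmetic, using a plain `std::vector<int>` in place of llama.cpp's token types (the helper name is illustrative, not part of the library):

```cpp
#include <cstddef>
#include <vector>

// Illustrative helper: returns a pointer to the last `used` tokens,
// mirroring the expression
//   penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size
// passed to llama_sample_repetition_penalties and llama_sample_dry above.
const int * dry_window(const std::vector<int> & penalty_tokens, size_t used) {
    return penalty_tokens.data() + penalty_tokens.size() - used;
}
```

Both samplers therefore see the same window; DRY is simply skipped when `dry_multiplier` is not positive.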
4 changes: 4 additions & 0 deletions common/sampling.h
@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
int dry_allowed_length = 2;

std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

std::vector<llama_token> penalty_prompt_tokens;
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
bool use_penalty_prompt_tokens = false;
} llama_sampling_params;

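The defaults above imply an exponential penalty curve. A small sketch, assuming the formula used on the llama.cpp side of this PR, `multiplier * base^(match_length - allowed_length)` (the function name is illustrative):

```cpp
#include <cmath>

// Penalty implied by a repeated sequence of length `match_length`.
// With the defaults above (dry_base = 1.75f, dry_allowed_length = 2) and
// the recommended dry_multiplier = 0.8f, a length-3 repeat costs
// 0.8 * 1.75^1 = 1.4 logits, and each further repeated token multiplies
// the cost by 1.75. Note the sampler only applies the penalty when
// match_length > dry_allowed_length.
float dry_penalty(float dry_multiplier, float dry_base,
                  int dry_allowed_length, int match_length) {
    return dry_multiplier * std::pow(dry_base, match_length - dry_allowed_length);
}
```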
84 changes: 84 additions & 0 deletions llama.cpp
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
// sanity check
GGML_ASSERT(last_tokens_size > 0);
Reviewer: Do all models use BOS tokens? Because if not, this assertion might fail with an empty context.

Author: I replaced this with an if check instead. I'm not sure if all models use BOS tokens.

// get the last token
auto last_token = last_tokens[last_tokens_size - 1];

// if last token is part of the sequence breakers, skip whole sampler
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
return;
}

// create an unordered map of "next tokens" <-> max match length
std::unordered_map<llama_token, size_t> match_lengths;

// loop through each previous token (exclude the last token)
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
// skip the compare token if it's not the same as the last token
if(last_tokens[i] != last_token) {
continue;
}

// get the next token (i + 1 is always less than last_tokens_size)
auto next_token = last_tokens[i + 1];

// try to extend the match backwards (match length starts at 1 because the last token is already matched)
size_t match_length = 1;

// loop through the previous tokens
for(;; match_length++) {
// if we have reached the start of our last tokens, break
if(i < match_length) break;

// compare token starts at our prev index, going backwards by match length
auto compare_token = last_tokens[i - match_length];

// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
auto head_token = last_tokens[last_tokens_size - 1 - match_length];

// if compare token is part of the sequence breakers, break out of the match
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
break;

// break out of the match if any tokens don't match
if(compare_token != head_token)
break;
}

// Check if the next token exists in the map
auto it = match_lengths.find(next_token);

if (it == match_lengths.end()) {
// Key does not exist, insert the new value
match_lengths[next_token] = match_length;
} else {
// Key exists, update it with the max of the new value or the existing value
it->second = std::max(it->second, match_length);
}
}

// apply penalties
for (const auto& pair : match_lengths) {
auto next_token = pair.first;
auto match_length = pair.second;

// if the match length is greater than our allowed length in config, we apply penalties
if(match_length > dry_allowed_length) {

// find our next token in the candidates->data
Reviewer: Aren't the candidates indices equal to the token ID? In Transformers, this is the case, which is why the original PR doesn't need to search. If this isn't true for llama.cpp, how are the candidates ordered?

Author: I looked through the creation of candidates. It appears to be true in this case (token ID = index), but it may not always be. The candidates structure has a flag bool sorted; if it's true, the candidates are sorted by logits descending. We could check for that condition here, but I cannot determine whether the candidates are guaranteed to have index = token ID when sorted = false.

Reviewer: I see, I guess the purpose of sorting by logit is to simplify truncation samplers. Probably best to keep the current code then. There are of course possible optimizations (such as interchanging the two loops and deleting tokens from match_lengths once they have been found, which should roughly cut the execution time in half), but I'm not sure if they are worth the extra complexity.

size_t i = 0;
for (; i < candidates->size; ++i) {
if (candidates->data[i].id == next_token) {
// calculate the penalty
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);

// apply the dry penalty
candidates->data[i].logit -= penalty;
break;
}
}
}
}
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
12 changes: 12 additions & 0 deletions llama.h
@@ -924,6 +924,18 @@ extern "C" {
float p,
size_t min_keep);

/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
LLAMA_API void llama_sample_dry(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
int last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
const llama_token * seq_breakers,
int seq_breakers_size);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
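The header above only declares the entry point. As a hedged illustration of the final penalty-application step (the linear search discussed in the review comments), here is a mock using a simplified candidate struct in place of llama.cpp's `llama_token_data`; the names are illustrative:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Simplified stand-in for llama_token_data.
struct mock_candidate { int id; float logit; };

// Mock of the penalty-application loop at the end of llama_sample_dry.
// The linear search mirrors the implementation: once candidates are sorted
// by logit, index no longer equals token ID, so the token must be found.
void apply_dry_penalty(std::vector<mock_candidate> & candidates, int token,
                       size_t match_length, float dry_multiplier,
                       float dry_base, int dry_allowed_length) {
    // only penalize matches longer than the allowed length
    if (match_length <= (size_t) dry_allowed_length) return;

    for (auto & c : candidates) {
        if (c.id == token) {
            c.logit -= dry_multiplier * std::pow(dry_base, match_length - dry_allowed_length);
            break;
        }
    }
}
```

With `dry_multiplier = 0.8f`, `dry_base = 1.75f`, and `dry_allowed_length = 2`, a token with a length-3 match loses 1.4 logits, while shorter matches are left untouched.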