diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index 24a8eca4..c35a889f 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -19,6 +19,7 @@ class Settings:
     token_frequency_penalty: float = 0.0
     token_presence_penalty: float = 0.0
+    non_rep_penalized_tokens: list[int] = field(default_factory = list)
     temperature: float = 0.8
     smoothing_factor: float = 0.0
@@ -159,14 +160,18 @@ def sample(logits: torch.tensor,
     if settings.token_repetition_penalty != 1.0 or \
        settings.token_frequency_penalty != 0.0 or \
        settings.token_presence_penalty != 0.0:
-
-        ext_c.apply_rep_penalty(sequence_ids[:, :],
+
+        hold_back_rep_penalty = len(settings.non_rep_penalized_tokens) > 0  # Some tokens are exempt from the penalty
+        if hold_back_rep_penalty: original_logits = torch.clone(logits[:, :])  # Copy the original logits
+        ext_c.apply_rep_penalty(sequence_ids[:, :],  # Apply the repetition penalty on the C++ side
                                 settings.token_repetition_penalty,
                                 settings.token_repetition_range,
                                 settings.token_repetition_decay,
                                 settings.token_frequency_penalty,
                                 settings.token_presence_penalty,
                                 logits)
+        if hold_back_rep_penalty:  # Restore the original logits for non-penalized tokens
+            for i in settings.non_rep_penalized_tokens: logits[:, i] = original_logits[:, i]
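
The sketch below illustrates the hold-back behaviour this patch adds: clone the logits before the penalty kernel runs, then restore the pre-penalty values for the exempted token ids. It is a minimal, self-contained illustration; apply_rep_penalty_with_exemptions and the stand-in penalty loop are hypothetical and only mimic the effect of ext_c.apply_rep_penalty, they are not the library's kernel.

import torch

def apply_rep_penalty_with_exemptions(
    logits: torch.Tensor,                        # (batch, vocab_size)
    seen_token_ids: torch.Tensor,                # (batch, seq_len), previously generated ids
    penalty: float = 1.15,
    exempt_token_ids: list[int] | None = None,   # mirrors Settings.non_rep_penalized_tokens
) -> torch.Tensor:
    exempt_token_ids = exempt_token_ids or []
    hold_back = len(exempt_token_ids) > 0
    if hold_back:
        original_logits = logits.clone()         # save pre-penalty logits, as the patch does

    # Stand-in for the C++ kernel: divide positive / multiply negative logits of seen tokens
    for b in range(logits.shape[0]):
        ids = seen_token_ids[b].unique()
        vals = logits[b, ids]
        logits[b, ids] = torch.where(vals > 0, vals / penalty, vals * penalty)

    if hold_back:                                # restore logits for exempted tokens
        for i in exempt_token_ids:
            logits[:, i] = original_logits[:, i]
    return logits

# Exempt a single token id (13 is just a placeholder, e.g. a newline token) from the penalty
logits = torch.randn(1, 32000)
seen = torch.tensor([[13, 13, 42, 99]])
logits = apply_rep_penalty_with_exemptions(logits, seen, penalty = 1.2, exempt_token_ids = [13])

One cost worth noting: cloning the full logits tensor copies an entire vocab-sized row per batch element. If the exemption list is typically short, saving only the exempted columns before the kernel runs and writing them back afterwards would avoid the full copy.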