Skip to content

Commit 620a2a6

Browse files
committed
ReplaceStrategy: Add docstring and max_prob param
Idea provided by Alex Sebastian: https://x.com/alexbastian_ai/status/1886253431557587419
1 parent 93acab2 commit 620a2a6

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

strategy/replace_strategy.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,21 @@ def __init__(
1414
max_replacements: int = sys.maxsize,
1515
max_new_tokens_for_replace: int = sys.maxsize,
1616
skip_tokens: int = 0,
17+
max_prob: float = 1
1718
):
19+
"""
20+
find: The list of texts to be replaced.
21+
replace: The text to use as a replacement.
22+
max_replacements: Maximum number of replacements allowed.
23+
max_new_tokens_for_replace: Only allow find-and-replace within the first N generated tokens.
24+
skip_tokens: Only allow find-and-replace after N generated tokens.
25+
max_prob: Only allow find-and-replace if the probability of the first token in the found text is lower than N.
26+
"""
1827
super().__init__(provider, find, with_variants=False, skip_tokens=skip_tokens)
1928
self.replace_ids = self.provider.encode(replace, add_special_tokens=False)
2029
self.max_replacements = max_replacements
2130
self.max_new_tokens_for_replace = max_new_tokens_for_replace
31+
self.max_prob = max_prob
2232
self.reset()
2333

2434
def reset(self) -> None:
@@ -29,16 +39,23 @@ def reset(self) -> None:
2939
def on_logits(
3040
self, logits: torch.Tensor, continuation_tokens: List[int]
3141
) -> torch.Tensor:
42+
return logits
43+
44+
def on_probs(
45+
self, probs: torch.FloatTensor, continuation_tokens: List[int]
46+
) -> torch.FloatTensor:
3247
if self.slop_start_pos is not None:
33-
self.replace_index = 0
34-
self.replaced += 1
48+
for slop_token in self.found_slop_tokens[self.slop_start_pos]:
49+
if probs[:, slop_token] <= self.max_prob:
50+
self.replace_index = 0
51+
self.replaced += 1
3552

3653
if self.replace_index is not None:
37-
logits[:, self.replace_ids[self.replace_index]] = 1e9
54+
probs[:, self.replace_ids[self.replace_index]] = 1e9
3855
self.replace_index = self.replace_index + 1
3956
if self.replace_index >= len(self.replace_ids):
4057
self.replace_index = None
41-
return logits
58+
return probs
4259

4360
def backtrack(self, continuation_tokens: List[int]) -> List[int]:
4461
self.slop_start_pos = None

0 commit comments

Comments
 (0)