@@ -14,11 +14,21 @@ def __init__(
14
14
max_replacements : int = sys .maxsize ,
15
15
max_new_tokens_for_replace : int = sys .maxsize ,
16
16
skip_tokens : int = 0 ,
17
+ max_prob : float = 1
17
18
):
19
+ """
20
+ find: The list of texts to be replaced.
21
+ replace: The text to use as a replacement.
22
+ max_replacements: Maximum number of replacements allowed.
23
+ max_new_tokens_for_replace: Only allow find-and-replace within the first N generated tokens.
24
+ skip_tokens: Only allow find-and-replace after N generated tokens.
25
+ max_prob: Only allow find-and-replace if the probability of the first token in the found text is lower than N.
26
+ """
18
27
super ().__init__ (provider , find , with_variants = False , skip_tokens = skip_tokens )
19
28
self .replace_ids = self .provider .encode (replace , add_special_tokens = False )
20
29
self .max_replacements = max_replacements
21
30
self .max_new_tokens_for_replace = max_new_tokens_for_replace
31
+ self .max_prob = max_prob
22
32
self .reset ()
23
33
24
34
def reset (self ) -> None :
@@ -29,16 +39,23 @@ def reset(self) -> None:
29
39
def on_logits(
    self, logits: torch.Tensor, continuation_tokens: List[int]
) -> torch.Tensor:
    """Return ``logits`` unchanged.

    This strategy makes no adjustment at the logits stage; its token
    forcing is done in ``on_probs``.

    Args:
        logits: Logits for the next token.
        continuation_tokens: Tokens generated so far (unused here).

    Returns:
        The same ``logits`` object that was passed in.
    """
    return logits
43
+
44
def on_probs(
    self, probs: torch.FloatTensor, continuation_tokens: List[int]
) -> torch.FloatTensor:
    """Start and/or continue emitting the replacement tokens.

    When a slop match is pending (``self.slop_start_pos`` is set), a
    replacement is started only if at least one of the slop tokens recorded
    for that position has a probability not greater than ``self.max_prob``.
    While a replacement is in progress, the next replacement token is given
    an overwhelming weight on every call until all of ``self.replace_ids``
    have been emitted.

    Args:
        probs: Next-token probabilities, indexed as ``probs[:, token_id]``.
            Modified in place.
        continuation_tokens: Tokens generated so far (unused here).

    Returns:
        The same ``probs`` tensor, possibly with one entry boosted.
    """
    if self.slop_start_pos is not None:
        slop_tokens = self.found_slop_tokens[self.slop_start_pos]
        # Count the replacement once, even if several slop tokens qualify;
        # looping and incrementing per token would inflate `self.replaced`.
        # `.all()` keeps the single-row behaviour and avoids the ambiguous
        # truth-value error for multi-row input.
        # NOTE(review): "all rows" is an assumption for batched input - confirm.
        if any(
            bool((probs[:, slop_token] <= self.max_prob).all())
            for slop_token in slop_tokens
        ):
            self.replace_index = 0
            self.replaced += 1

    if self.replace_index is not None:
        # 1e9 is not a valid probability; it acts as a dominating weight.
        # NOTE(review): presumably the sampler tolerates unnormalized
        # weights - confirm against the caller.
        probs[:, self.replace_ids[self.replace_index]] = 1e9
        self.replace_index += 1
        if self.replace_index >= len(self.replace_ids):
            # All replacement tokens have been emitted.
            self.replace_index = None
    return probs
42
59
43
60
def backtrack (self , continuation_tokens : List [int ]) -> List [int ]:
44
61
self .slop_start_pos = None
0 commit comments