Skip to content

Commit ed0ec4e

Browse files
committed
Fix mute tokens dead byte
1 parent 2a8f6b0 commit ed0ec4e

File tree

7 files changed

+24
-9
lines changed

7 files changed

+24
-9
lines changed

benchmarks/bench_regex_guide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class RegexIndexBenchmark:
2121
params = regex_samples.keys()
2222

2323
def setup(self, pattern_name):
24-
self.vocabulary = Vocabulary.from_pretrained("gpt2")
24+
self.vocabulary = Vocabulary.from_pretrained("unsloth/Llama-3.1-8B-Instruct")
2525
self.pattern = regex_samples[pattern_name]
2626

2727
def time_regex_to_guide(self, pattern_name):

benchmarks/test_index_time.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import timeit
2+
from outlines_core import Index, Vocabulary
3+
4+
regex_samples = {
5+
"email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
6+
# Add more regexes here if needed
7+
}
8+
9+
# Initialize the Vocabulary before the measurement so its load time is not included
10+
vocabulary = Vocabulary.from_pretrained("unsloth/Llama-3.1-8B-Instruct")
11+
pattern = regex_samples["email"]
12+
# Setup code (contains only the import; `pattern` and `vocabulary` are supplied to timeit through `globals`)
13+
setup_code = "from outlines_core import Index"
14+
# Measure only the construction of the Index
15+
stmt = "Index(pattern, vocabulary)"
16+
execution_time = timeit.timeit(stmt, setup=setup_code, globals=locals(), number=1)
17+
print(f"Temps d'exécution pour une construction froide de l'Index (Vocabulary pré-initialisé) : {execution_time} secondes")

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
binding=Binding.PyO3,
1414
features=["python-bindings"],
1515
rustc_flags=["--crate-type=cdylib"],
16+
debug=False,
1617
),
1718
]
1819

src/tokens_dfa/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ impl TokensDFA
5757
//let start_muting = Instant::now();
5858
let (muted_regex, muted_list) = mute_literals(regex, vocabulary, &mut additionnal_tokens);
5959
//let time_muting = start_muting.elapsed();
60+
// println!("> Muted Regex : {}", muted_regex);
6061
// println!("> Muted List {:?}", muted_list);
6162
// println!("> Additionnal : {:?}", additionnal_tokens);
6263
let alphabet_len = vocabulary.len_alphabet()+additionnal_tokens.len(); // Real number of different token_id

src/tokens_dfa/reduce.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub fn init_classes_and_graph_optimized(
3838
if token_ids[0] == 216 {return None;} // BUG IN THE VOCABULARY. token_id 216 ""\u011c"" is interpreted as \x1C, (Byte 28)
3939
let t_class = get_token_class(token, byte_classes);
4040

41-
if t_class.as_bytes().iter().any(|byte| dead_byte_classes.contains(byte)) {
41+
if t_class.as_bytes().iter().any(|byte| dead_byte_classes.contains(byte)) && !token_ids.iter().any(|&id| additionnal_tokens.iter().any(|(_, add_id)| *add_id == id)) {
4242
return None;
4343
}
4444

@@ -67,7 +67,6 @@ pub fn init_classes_and_graph_optimized(
6767
transitions_table.get_token_ids_by_class().add_token_id(class_id, id);
6868
// Avoid overriding the token class of a muted token.
6969
if additionnal_classes.contains(&t_class){
70-
7170
continue;
7271

7372
}
@@ -313,9 +312,7 @@ fn decompose_all_literals_optimized(
313312

314313
token_sequence.reverse();
315314
result.insert(literal.clone(), (token_sequence, positions.clone()));
316-
} else {
317-
println!("Aucune décomposition trouvée pour le littéral: {}", literal);
318-
}
315+
}
319316
}
320317

321318
result

src/tokens_dfa/transitions_table.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ impl MasksTable {
164164
let mut real_class_id = class_id.clone();
165165

166166
if tokens.len() == 1 && muted_list.contains(&tokens[0]) {
167-
168167
real_class_id = self.token_classes[tokens[0]]
169168
}
170169

src/v2_index.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ mod tests {
159159

160160
#[test]
161161
fn test_sample(){
162-
let regex = r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?";
162+
let regex = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?";
163163
//let sch =r#"{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "complexe_phone": {"type": "string", "pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"}}, "required": ["name", "age", "complexe_phone"]}"#;
164164
//let regex = &json_schema::regex_from_str(sch, None).unwrap();
165-
println!("{}", regex);
165+
println!("{}", regex);
166166
let model_name = "unsloth/Llama-3.1-8B-Instruct";
167167
let vocab = Vocabulary::from_pretrained(model_name, None).unwrap();
168168

0 commit comments

Comments
 (0)