@@ -162,10 +162,10 @@ fn hash_current_thread() -> usize {
162
162
// that works great for our use case of avoiding collisions in our array. Unfortunately,
163
163
// it's private. However, there are only so many ways you can lay out a u64, so just transmute
164
164
// https://github.com/rust-lang/rust/issues/67939
165
- const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < std :: thread:: ThreadId > ( ) ] ;
165
+ const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < thread:: ThreadId > ( ) ] ;
166
166
const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < FakeThreadId > ( ) ] ;
167
167
let x = unsafe {
168
- std:: mem:: transmute :: < std :: thread:: ThreadId , FakeThreadId > ( thread:: current ( ) . id ( ) ) . 0
168
+ std:: mem:: transmute :: < thread:: ThreadId , FakeThreadId > ( thread:: current ( ) . id ( ) ) . 0
169
169
} ;
170
170
u64:: from ( x) as usize
171
171
}
@@ -214,11 +214,10 @@ impl CoreBPE {
214
214
let mut ret = vec ! [ ] ;
215
215
for mat in regex. find_iter ( text) {
216
216
let piece = mat. unwrap ( ) . as_str ( ) . as_bytes ( ) ;
217
- if let Some ( token ) = self . encoder . get ( piece) {
218
- ret. push ( * token) ;
219
- continue ;
217
+ match self . encoder . get ( piece) {
218
+ Some ( token ) => ret. push ( * token) ,
219
+ None => ret . extend ( & byte_pair_encode ( piece , & self . encoder ) ) ,
220
220
}
221
- ret. extend ( & byte_pair_encode ( piece, & self . encoder ) ) ;
222
221
}
223
222
ret
224
223
}
@@ -516,7 +515,10 @@ impl CoreBPE {
516
515
unstable_bytes. extend_from_slice ( & bytes[ e. valid_up_to ( ) ..] ) ;
517
516
518
517
tokens. truncate ( tokens. len ( ) - last_piece_token_len) ;
519
- tokens. extend ( byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ;
518
+ match self . encoder . get ( & unstable_bytes) {
519
+ Some ( token) => tokens. push ( * token) ,
520
+ None => tokens. extend ( & byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ,
521
+ }
520
522
}
521
523
tokens
522
524
}
@@ -597,15 +599,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
597
599
mod tests {
598
600
use rustc_hash:: FxHashMap as HashMap ;
599
601
600
- use crate :: byte_pair_split;
602
+ use crate :: { byte_pair_split, Rank } ;
601
603
602
- #[ test]
603
- fn very_simple_test ( ) {
604
- let mut ranks = HashMap :: default ( ) ;
605
- ranks. insert ( b"ab" . to_vec ( ) , 1 ) ;
606
- ranks. insert ( b"cd" . to_vec ( ) , 2 ) ;
604
+ fn setup_ranks ( ) -> HashMap < Vec < u8 > , Rank > {
605
+ HashMap :: from_iter ( [
606
+ ( b"ab" . to_vec ( ) , 0 ) ,
607
+ ( b"cd" . to_vec ( ) , 1 ) ,
608
+ ] )
609
+ }
607
610
611
+ #[ test]
612
+ fn test_simple_characters ( ) {
613
+ let ranks = setup_ranks ( ) ;
608
614
let res = byte_pair_split ( b"abcd" , & ranks) ;
609
615
assert_eq ! ( res, vec![ b"ab" , b"cd" ] ) ;
610
616
}
617
+
618
+ #[ test]
619
+ fn test_repeated_characters ( ) {
620
+ let ranks = setup_ranks ( ) ;
621
+ let res = byte_pair_split ( b"abab" , & ranks) ;
622
+ assert_eq ! ( res, vec![ b"ab" , b"ab" ] ) ;
623
+ }
611
624
}
0 commit comments