@@ -1,23 +1,26 @@
+use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
 use std::sync::Arc;
 
+pub type Rank = u32;
+
 fn _byte_pair_merge<T>(
     piece: &[u8],
-    ranks: &HashMap<Vec<u8>, usize>,
+    ranks: &HashMap<Vec<u8>, Rank>,
     f: impl Fn(std::ops::Range<usize>) -> T,
 ) -> Vec<T> {
     // This is a vector of (start, rank).
     // The rank is of the byte pair starting at position start.
     // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
+    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
+        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
             if (start_idx + skip + 2) < parts.len() {
                 ranks
                     .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
@@ -33,8 +36,8 @@ fn _byte_pair_merge<T>(
     for i in 0..parts.len() - 2 {
         match get_rank(&parts, i, 0) {
             Some(rank) => {
-                // usize::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != usize::MAX);
+                // Rank::MAX is a sentinel value and cannot be a valid rank
+                debug_assert!(rank != Rank::MAX);
                 parts[i].1 = rank;
             }
             None => {
@@ -57,26 +60,26 @@ fn _byte_pair_merge<T>(
             break;
         }
 
-        // usize::MAX is a sentinel rank value allowing us to
+        // Rank::MAX is a sentinel rank value allowing us to
         // take the min more quickly
-        let mut min_rank: (usize, usize) = (usize::MAX, 0);
+        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
 
-        if min_rank.0 != usize::MAX {
+        if min_rank.0 != Rank::MAX {
             let i = min_rank.1;
 
             // NOTE: We are about to remove parts[i + 1]. We do not do it
             // yet because there are cache-locality benefits to updating
             // parts[i] and parts[i-1] before removing, which could thrash
             // the cache. Thus, we update the rank calculation by skipping over
             // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
+            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
             if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
+                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
             }
 
             parts.remove(i + 1);
@@ -91,14 +94,14 @@ fn _byte_pair_merge<T>(
     out
 }
 
-pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
+pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
     if piece.len() == 1 {
         return vec![ranks[piece]];
     }
     _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
 }
 
-pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
+pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
     if piece.len() == 1 {
         return vec![piece];
     }
@@ -146,7 +149,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-use std::num::NonZeroU64;
 pub struct FakeThreadId(NonZeroU64);
 
 fn hash_current_thread() -> usize {
@@ -166,10 +168,10 @@ const MAX_NUM_THREADS: usize = 8;
 
 #[derive(Debug)]
 pub struct CoreBPE {
-    pub encoder: HashMap<Vec<u8>, usize>,
-    special_tokens_encoder: HashMap<String, usize>,
-    decoder: HashMap<usize, &'static [u8]>,
-    special_tokens_decoder: HashMap<usize, Vec<u8>>,
+    pub encoder: HashMap<Vec<u8>, Rank>,
+    special_tokens_encoder: HashMap<String, Rank>,
+    decoder: HashMap<Rank, &'static [u8]>,
+    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex_tls: Arc<[Regex]>,
     special_regex_tls: Arc<[Regex]>,
     sorted_token_bytes: Vec<&'static [u8]>,
@@ -187,7 +189,7 @@ impl CoreBPE {
         &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
     }
 
-    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
+    fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for token in tokens {
             let token_bytes = self
@@ -200,7 +202,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
@@ -216,7 +218,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
+    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
         let regex = self._get_tl_regex();
         let mut ret = vec![];
@@ -274,9 +276,9 @@ impl CoreBPE {
 
     fn _increase_last_piece_token_len(
         &self,
-        tokens: Vec<usize>,
+        tokens: Vec<Rank>,
         mut last_piece_token_len: usize,
-    ) -> (Vec<usize>, usize) {
+    ) -> (Vec<Rank>, usize) {
         // Unfortunately, the locations where our regex splits can be unstable.
         // For the purposes of determining unstable tokens, unstable regex splitting
         // is only a problem if a split that was present disappears, since this can
@@ -315,7 +317,7 @@ impl CoreBPE {
         &self,
         text: &str,
         allowed_special: &HashSet<&str>,
-    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
+    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
         let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
         if last_piece_token_len == 0 {
             // If last_piece_token_len is zero, the last token was a special token and we have
@@ -430,8 +432,8 @@ impl CoreBPE {
 
 impl CoreBPE {
     pub fn new(
-        encoder: HashMap<Vec<u8>, usize>,
-        special_tokens_encoder: HashMap<String, usize>,
+        encoder: HashMap<Vec<u8>, Rank>,
+        special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> Result<Self, fancy_regex::Error> {
         let regex = Regex::new(pattern)?;
@@ -445,7 +447,7 @@ impl CoreBPE {
         };
 
         // Use unsafe to extend the lifetime of references to the encoder's keys
-        let decoder: HashMap<usize, &'static [u8]> = encoder
+        let decoder: HashMap<Rank, &'static [u8]> = encoder
             .iter()
             .map(|(k, v)| {
                 let bytes: &[u8] = k.as_slice();
@@ -459,7 +461,7 @@ impl CoreBPE {
             "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
         );
 
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
+        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
             .iter()
             .map(|(k, v)| (*v, k.as_bytes().to_vec()))
             .collect();
@@ -497,15 +499,15 @@ impl CoreBPE {
     // Encoding
     // ====================
 
-    pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
+    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
         self._encode_ordinary_native(text)
     }
 
-    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<usize> {
+    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
         self._encode_native(text, &allowed_special).0
     }
 
-    pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
+    pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<Rank> {
         match std::str::from_utf8(bytes) {
             Ok(text) => self._encode_ordinary_native(text),
             Err(e) => {
@@ -534,11 +536,11 @@ impl CoreBPE {
         &self,
         text: &str,
         allowed_special: &HashSet<&str>,
-    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
+    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
         self._encode_unstable_native(text, &allowed_special)
     }
 
-    pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize, Vec<u8>> {
+    pub fn encode_single_token(&self, piece: &[u8]) -> Result<Rank, Vec<u8>> {
         if let Some(token) = self.encoder.get(piece).copied() {
             return Ok(token);
         }
@@ -550,7 +552,7 @@ impl CoreBPE {
         Err(piece.to_owned())
     }
 
-    pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
+    pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
         if let Some(token) = self.encoder.get(piece) {
             return vec![*token];
         }
@@ -561,11 +563,11 @@ impl CoreBPE {
     // Decoding
     // ====================
 
-    pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
+    pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
         self._decode_native(&tokens)
     }
 
-    pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>, usize> {
+    pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, Rank> {
         if let Some(bytes) = self.decoder.get(&token) {
             return Ok(bytes.to_vec());
         }
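The diff above narrows token ranks from the pointer-sized usize to a dedicated Rank = u32 alias. BPE vocabularies hold far fewer than 2^32 entries, so u32 loses nothing, and every token buffer (Vec<Rank>) takes half the memory it previously did on 64-bit targets. Below is a minimal sketch of the updated byte_pair_encode signature in use, assuming the functions and the Rank alias above are in scope; the three-entry rank table is a hypothetical toy (real tables are loaded from a BPE vocabulary file):

use rustc_hash::FxHashMap as HashMap;

fn demo() {
    // Hypothetical toy vocabulary: two single-byte tokens plus one merge.
    // Real encoders load on the order of 100k entries; a token's rank
    // doubles as its token id.
    let mut ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
    ranks.insert(b"a".to_vec(), 0);
    ranks.insert(b"b".to_vec(), 1);
    ranks.insert(b"ab".to_vec(), 2);

    // The pair ("a", "b") has an entry, so the two bytes merge into token 2.
    assert_eq!(byte_pair_encode(b"ab", &ranks), vec![2]);

    // Each emitted token is now 4 bytes rather than 8 on 64-bit targets.
    assert_eq!(std::mem::size_of::<Rank>(), 4);
}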