 #![allow(clippy::borrow_deref_ref)]
 
 use std::collections::HashSet;
+use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
 use pyo3::exceptions;
 use pyo3::prelude::*;
-use pyo3::types::{PyBytes, PyList, PyTuple};
 use pyo3::PyResult;
+use pyo3::types::{PyBytes, PyList, PyTuple};
 use rustc_hash::FxHashMap as HashMap;
 
+type Rank = u32;
+
 fn _byte_pair_merge<T>(
     piece: &[u8],
-    ranks: &HashMap<Vec<u8>, usize>,
+    ranks: &HashMap<Vec<u8>, Rank>,
     f: impl Fn(std::ops::Range<usize>) -> T,
 ) -> Vec<T> {
     // This is a vector of (start, rank).
     // The rank is of the byte pair starting at position start.
     // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
+    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
+        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
             if (start_idx + skip + 2) < parts.len() {
                 ranks
                     .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
@@ -39,8 +42,8 @@ fn _byte_pair_merge<T>(
     for i in 0..parts.len() - 2 {
         match get_rank(&parts, i, 0) {
             Some(rank) => {
-                // usize::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != usize::MAX);
+                // Rank::MAX is a sentinel value and cannot be a valid rank
+                debug_assert!(rank != Rank::MAX);
                 parts[i].1 = rank;
             }
             None => {
@@ -63,26 +66,26 @@ fn _byte_pair_merge<T>(
             break;
         }
 
-        // usize::MAX is a sentinel rank value allowing us to
+        // Rank::MAX is a sentinel rank value allowing us to
         // take the min more quickly
-        let mut min_rank: (usize, usize) = (usize::MAX, 0);
+        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
 
-        if min_rank.0 != usize::MAX {
+        if min_rank.0 != Rank::MAX {
             let i = min_rank.1;
 
             // NOTE: We are about to remove parts[i + 1]. We do not do it
             // yet because there are cache-locality benefits to updating
             // parts[i] and parts[i-1] before removing, which could thrash
             // the cache. Thus, we update the rank calculation by skipping over
             // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
+            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
             if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
+                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
             }
 
             parts.remove(i + 1);
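
// A self-contained sketch (not diff text) of the sentinel-min scan in the hunk
// above: every entry starts at Rank::MAX, so a single pass both finds the
// lowest-ranked pair and signals "no merges left" when the minimum is still
// Rank::MAX. The parts state below is hypothetical, chosen for illustration.
fn min_rank_scan_demo() {
    // Pairs starting at indices 0 and 2 are mergeable; the others are sentinels.
    let parts: Vec<(usize, Rank)> = vec![(0, 5), (1, Rank::MAX), (2, 3), (3, Rank::MAX)];
    let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
    for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
    }
    assert_eq!(min_rank, (3, 2)); // the pair starting at parts[2] merges first
}
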
@@ -97,14 +100,14 @@ fn _byte_pair_merge<T>(
     out
 }
 
-pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
+pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
     if piece.len() == 1 {
         return vec![ranks[piece]];
     }
     _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
 }
 
-pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
+pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
     if piece.len() == 1 {
         return vec![piece];
     }
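
// An illustrative use of byte_pair_encode (not diff text). The rank table here
// is made up for demonstration; real tables come from a vocabulary file. A
// lower rank means the pair was merged earlier during BPE training.
fn byte_pair_encode_demo() {
    let mut ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
    ranks.insert(b"a".to_vec(), 0);
    ranks.insert(b"b".to_vec(), 1);
    ranks.insert(b"c".to_vec(), 2);
    ranks.insert(b"ab".to_vec(), 3);
    ranks.insert(b"abc".to_vec(), 4);
    // "a"+"b" merges first (rank 3), then "ab"+"c" (rank 4), so the whole
    // piece collapses to the single token 4.
    assert_eq!(byte_pair_encode(b"abc", &ranks), vec![4]);
}
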
@@ -152,7 +155,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-use std::num::NonZeroU64;
 pub struct FakeThreadId(NonZeroU64);
 
 fn hash_current_thread() -> usize {
@@ -169,12 +171,13 @@ fn hash_current_thread() -> usize {
 }
 
 const MAX_NUM_THREADS: usize = 128;
+
 #[pyclass]
 struct CoreBPE {
-    encoder: HashMap<Vec<u8>, usize>,
-    special_tokens_encoder: HashMap<String, usize>,
-    decoder: HashMap<usize, Vec<u8>>,
-    special_tokens_decoder: HashMap<usize, Vec<u8>>,
+    encoder: HashMap<Vec<u8>, Rank>,
+    special_tokens_encoder: HashMap<String, Rank>,
+    decoder: HashMap<Rank, Vec<u8>>,
+    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex_tls: Vec<Regex>,
     special_regex_tls: Vec<Regex>,
     sorted_token_bytes: Vec<Vec<u8>>,
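
// A sketch of the reasoning behind Rank = u32 in the fields above (an
// editorial note, not diff text): BPE vocabularies are on the order of 10^5
// tokens, far below u32::MAX, and on 64-bit targets the narrower type halves
// the size of every token id held in these maps and in each returned Vec<Rank>.
fn rank_width_demo() {
    let vocab_size: u64 = 100_000; // illustrative order of magnitude only
    assert!(vocab_size < u64::from(u32::MAX));
    assert_eq!(std::mem::size_of::<u32>(), 4);
    // usize is 8 bytes on 64-bit targets, so Vec<u32> stores ids in half the space.
}
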
@@ -192,7 +195,7 @@ impl CoreBPE {
         &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
     }
 
-    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
+    fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for token in tokens {
             let token_bytes = self
@@ -204,7 +207,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
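
// The hunk context cuts off the body of _encode_ordinary_native. Judging from
// the surrounding lines, the elided loop handles each regex-split piece
// roughly as in this free-function sketch (an assumption, not diff text;
// encode_piece_sketch is a hypothetical name):
fn encode_piece_sketch(piece: &[u8], encoder: &HashMap<Vec<u8>, Rank>, ret: &mut Vec<Rank>) {
    // Fast path: the whole regex-split piece is already a known token.
    if let Some(token) = encoder.get(piece) {
        ret.push(*token);
        return;
    }
    // Otherwise run the byte-pair merge over the piece's bytes.
    ret.extend(&byte_pair_encode(piece, encoder));
}
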
@@ -220,7 +223,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
+    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
         let regex = self._get_tl_regex();
         let mut ret = vec![];
@@ -278,9 +281,9 @@ impl CoreBPE {
 
     fn _increase_last_piece_token_len(
         &self,
-        tokens: Vec<usize>,
+        tokens: Vec<Rank>,
         mut last_piece_token_len: usize,
-    ) -> (Vec<usize>, usize) {
+    ) -> (Vec<Rank>, usize) {
         // Unfortunately, the locations where our regex splits can be unstable.
         // For the purposes of determining unstable tokens, unstable regex splitting
         // is only a problem if a split that was present disappears, since this can
@@ -319,7 +322,7 @@ impl CoreBPE {
         &self,
         text: &str,
         allowed_special: &HashSet<&str>,
-    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
+    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
         let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
         if last_piece_token_len == 0 {
             // If last_piece_token_len is zero, the last token was a special token and we have
@@ -436,8 +439,8 @@ impl CoreBPE {
 impl CoreBPE {
     #[new]
     fn new(
-        encoder: HashMap<Vec<u8>, usize>,
-        special_tokens_encoder: HashMap<String, usize>,
+        encoder: HashMap<Vec<u8>, Rank>,
+        special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> PyResult<Self> {
         let regex = Regex::new(pattern)
@@ -452,15 +455,15 @@ impl CoreBPE {
                 .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
         };
 
-        let decoder: HashMap<usize, Vec<u8>> =
+        let decoder: HashMap<Rank, Vec<u8>> =
             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
 
         assert!(
             encoder.len() == decoder.len(),
             "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
         );
 
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
+        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
             .iter()
             .map(|(k, v)| (*v, k.as_bytes().to_vec()))
             .collect();
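
// A sketch (hypothetical values, not diff text) of what the length assert
// above guards against: inverting an encoder that contains a duplicate token
// index silently drops an entry, so the decoder comes out shorter.
fn duplicate_rank_demo() {
    let mut bad: HashMap<Vec<u8>, Rank> = HashMap::default();
    bad.insert(b"a".to_vec(), 7);
    bad.insert(b"b".to_vec(), 7); // duplicate token index
    let inverted: HashMap<Rank, Vec<u8>> = bad.iter().map(|(k, v)| (*v, k.clone())).collect();
    assert_eq!(bad.len(), 2);
    assert_eq!(inverted.len(), 1); // the collision collapsed two entries into one
}
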
@@ -486,15 +489,15 @@ impl CoreBPE {
     // Encoding
     // ====================
 
-    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
+    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
         py.allow_threads(|| self._encode_ordinary_native(text))
     }
 
-    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
+    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<Rank> {
         py.allow_threads(|| self._encode_native(text, &allowed_special).0)
     }
 
-    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
+    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
                 Ok(text) => self._encode_ordinary_native(text),
@@ -534,7 +537,7 @@ impl CoreBPE {
         (tokens, py_completions).into_py(py)
     }
 
-    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
+    fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
         if let Some(token) = self.encoder.get(piece).copied() {
             return Ok(token);
         }
@@ -546,7 +549,7 @@ impl CoreBPE {
         Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
     }
 
-    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
+    fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
         if let Some(token) = self.encoder.get(piece) {
             return vec![*token];
         }
@@ -557,12 +560,12 @@ impl CoreBPE {
     // Decoding
     // ====================
 
-    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
+    fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Py<PyBytes> {
         let bytes = py.allow_threads(|| self._decode_native(&tokens));
         PyBytes::new(py, &bytes).into()
     }
 
-    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
+    fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
         if let Some(bytes) = self.decoder.get(&token) {
             return Ok(PyBytes::new(py, bytes).into());
         }