@@ -168,11 +168,11 @@ const MAX_NUM_THREADS: usize = 8;
168
168
pub struct CoreBPE {
169
169
pub encoder : HashMap < Vec < u8 > , usize > ,
170
170
special_tokens_encoder : HashMap < String , usize > ,
171
- decoder : HashMap < usize , Vec < u8 > > ,
171
+ decoder : HashMap < usize , & ' static [ u8 ] > ,
172
172
special_tokens_decoder : HashMap < usize , Vec < u8 > > ,
173
173
regex_tls : Arc < [ Regex ] > ,
174
174
special_regex_tls : Arc < [ Regex ] > ,
175
- sorted_token_bytes : Vec < Vec < u8 > > ,
175
+ sorted_token_bytes : Vec < & ' static [ u8 ] > ,
176
176
}
177
177
178
178
impl CoreBPE {
@@ -193,7 +193,8 @@ impl CoreBPE {
193
193
let token_bytes = self
194
194
. decoder
195
195
. get ( token)
196
- . unwrap_or_else ( || & self . special_tokens_decoder [ token] ) ;
196
+ . copied ( )
197
+ . unwrap_or_else ( || self . special_tokens_decoder [ token] . as_slice ( ) ) ;
197
198
ret. extend ( token_bytes) ;
198
199
}
199
200
ret
@@ -341,12 +342,12 @@ impl CoreBPE {
341
342
// Separating this from the loop below helps with performance in a common case.
342
343
let mut point = self
343
344
. sorted_token_bytes
344
- . partition_point ( |x| x . as_slice ( ) < unstable_bytes. as_slice ( ) ) ;
345
+ . partition_point ( |x| * x < unstable_bytes. as_slice ( ) ) ;
345
346
while point < self . sorted_token_bytes . len ( )
346
347
&& self . sorted_token_bytes [ point] . starts_with ( & unstable_bytes)
347
348
{
348
349
completions. insert ( vec ! [
349
- self . encoder[ self . sorted_token_bytes[ point] . as_slice ( ) ] ,
350
+ self . encoder[ self . sorted_token_bytes[ point] ] ,
350
351
] ) ;
351
352
point += 1 ;
352
353
}
@@ -359,12 +360,12 @@ impl CoreBPE {
359
360
let suffix = & unstable_bytes[ i..] ;
360
361
let mut point = self
361
362
. sorted_token_bytes
362
- . partition_point ( |x| x . as_slice ( ) < suffix) ;
363
+ . partition_point ( |x| * x < suffix) ;
363
364
// TODO: Perf optimisation if suffix starts with " "?
364
365
while point < self . sorted_token_bytes . len ( )
365
366
&& self . sorted_token_bytes [ point] . starts_with ( suffix)
366
367
{
367
- let possibility = [ prefix, self . sorted_token_bytes [ point] . as_slice ( ) ] . concat ( ) ;
368
+ let possibility = [ prefix, self . sorted_token_bytes [ point] ] . concat ( ) ;
368
369
let encoded = match std:: str:: from_utf8 ( & possibility) {
369
370
// Morally, this is byte_pair_encode(&possibility, &self.encoder)
370
371
// But we might have introduced a regex split which would prevent merges.
@@ -443,8 +444,15 @@ impl CoreBPE {
443
444
Regex :: new ( & parts. join ( "|" ) ) ?
444
445
} ;
445
446
446
- let decoder: HashMap < usize , Vec < u8 > > =
447
- encoder. iter ( ) . map ( |( k, v) | ( * v, k. clone ( ) ) ) . collect ( ) ;
447
+ // Use unsafe to extend the lifetime of references to the encoder's keys
448
+ let decoder: HashMap < usize , & ' static [ u8 ] > = encoder
449
+ . iter ( )
450
+ . map ( |( k, v) | {
451
+ let bytes: & [ u8 ] = k. as_slice ( ) ;
452
+ let static_bytes: & ' static [ u8 ] = unsafe { std:: mem:: transmute ( bytes) } ;
453
+ ( * v, static_bytes)
454
+ } )
455
+ . collect ( ) ;
448
456
449
457
assert ! (
450
458
encoder. len( ) == decoder. len( ) ,
@@ -456,8 +464,14 @@ impl CoreBPE {
456
464
. map ( |( k, v) | ( * v, k. as_bytes ( ) . to_vec ( ) ) )
457
465
. collect ( ) ;
458
466
459
- // Clone because I don't know how to tell Rust I'm not going to change the map
460
- let mut sorted_token_bytes: Vec < Vec < u8 > > = encoder. keys ( ) . cloned ( ) . collect ( ) ;
467
+ let mut sorted_token_bytes: Vec < & ' static [ u8 ] > = encoder
468
+ . keys ( )
469
+ . map ( |k| {
470
+ let bytes: & [ u8 ] = k. as_slice ( ) ;
471
+ let static_bytes: & ' static [ u8 ] = unsafe { std:: mem:: transmute ( bytes) } ;
472
+ static_bytes
473
+ } )
474
+ . collect ( ) ;
461
475
sorted_token_bytes. sort ( ) ;
462
476
463
477
Ok ( CoreBPE {
@@ -553,7 +567,7 @@ impl CoreBPE {
553
567
554
568
pub fn decode_single_token_bytes ( & self , token : usize ) -> Result < Vec < u8 > , usize > {
555
569
if let Some ( bytes) = self . decoder . get ( & token) {
556
- return Ok ( bytes. clone ( ) ) ;
570
+ return Ok ( bytes. to_vec ( ) ) ;
557
571
}
558
572
if let Some ( bytes) = self . special_tokens_decoder . get ( & token) {
559
573
return Ok ( bytes. clone ( ) ) ;
@@ -566,7 +580,7 @@ impl CoreBPE {
566
580
// ====================
567
581
568
582
pub fn token_byte_values ( & self ) -> Vec < Vec < u8 > > {
569
- self . sorted_token_bytes . clone ( )
583
+ self . sorted_token_bytes . iter ( ) . map ( | & bytes| bytes . to_vec ( ) ) . collect ( )
570
584
}
571
585
}
572
586
0 commit comments