Commit 241dee1
make decoder and sorted_token_bytes re-use existing memory (#14)
I tried a bunch of safe ways to do this, but they were all slower than this unsafe approach. By itself this PR provides memory savings with a negligible impact on performance; however, when combined with #13, count_token performance improves by another ~3% (for a total improvement of ~10%).
1 parent 300ec0a commit 241dee1
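The core trick, sketched below in isolation: instead of cloning each token's bytes into the decoder, borrow the byte strings already owned by the encoder's keys and unsafely extend the borrow's lifetime to 'static. This is a minimal sketch of the pattern, not the commit's exact code; the `build_decoder` helper and the toy map are made up for illustration, and the SAFETY comments state the invariant the real struct must uphold.

use std::collections::HashMap;

// Sketch of the lifetime-extension pattern used in this commit (the
// `build_decoder` helper is hypothetical). The &'static [u8] values alias
// the encoder's Vec<u8> keys, so this is only sound while the encoder is
// neither mutated nor dropped before the decoder.
fn build_decoder(encoder: &HashMap<Vec<u8>, usize>) -> HashMap<usize, &'static [u8]> {
    encoder
        .iter()
        .map(|(k, v)| {
            let bytes: &[u8] = k.as_slice();
            // SAFETY (assumed invariant): the owner keeps `encoder` alive and
            // unmodified; a Vec's heap buffer does not move when the map moves.
            let static_bytes: &'static [u8] = unsafe { std::mem::transmute(bytes) };
            (*v, static_bytes)
        })
        .collect()
}

fn main() {
    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::new();
    encoder.insert(b"hello".to_vec(), 0);
    let decoder = build_decoder(&encoder);
    assert_eq!(decoder[&0], b"hello".as_slice());
}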

File tree: 1 file changed (+27, −13 lines)


src/corebpe.rs

Lines changed: 27 additions & 13 deletions
@@ -168,11 +168,11 @@ const MAX_NUM_THREADS: usize = 8;
 pub struct CoreBPE {
     pub encoder: HashMap<Vec<u8>, usize>,
     special_tokens_encoder: HashMap<String, usize>,
-    decoder: HashMap<usize, Vec<u8>>,
+    decoder: HashMap<usize, &'static [u8]>,
     special_tokens_decoder: HashMap<usize, Vec<u8>>,
     regex_tls: Arc<[Regex]>,
     special_regex_tls: Arc<[Regex]>,
-    sorted_token_bytes: Vec<Vec<u8>>,
+    sorted_token_bytes: Vec<&'static [u8]>,
 }

 impl CoreBPE {
@@ -193,7 +193,8 @@ impl CoreBPE {
             let token_bytes = self
                 .decoder
                 .get(token)
-                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
+                .copied()
+                .unwrap_or_else(|| self.special_tokens_decoder[token].as_slice());
             ret.extend(token_bytes);
         }
         ret
@@ -341,12 +342,12 @@ impl CoreBPE {
         // Separating this from the loop below helps with performance in a common case.
         let mut point = self
             .sorted_token_bytes
-            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
+            .partition_point(|x| *x < unstable_bytes.as_slice());
         while point < self.sorted_token_bytes.len()
             && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
         {
             completions.insert(vec![
-                self.encoder[self.sorted_token_bytes[point].as_slice()],
+                self.encoder[self.sorted_token_bytes[point]],
             ]);
             point += 1;
         }
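For context, the scan above pairs partition_point (a binary search for the first candidate not lexicographically less than the query) with a linear walk over every sorted token that starts with the query; this commit only changes the element type, not the algorithm. A standalone sketch of the pattern, with a made-up helper name and toy data:

// Standalone sketch of the prefix-scan pattern above (helper name and data
// are made up for illustration). `sorted` must be sorted lexicographically.
fn tokens_with_prefix<'a>(sorted: &[&'a [u8]], prefix: &[u8]) -> Vec<&'a [u8]> {
    // partition_point returns the first index whose element is >= prefix.
    let mut point = sorted.partition_point(|x| *x < prefix);
    let mut out = Vec::new();
    // Because the slice is sorted, all matches are contiguous from `point`.
    while point < sorted.len() && sorted[point].starts_with(prefix) {
        out.push(sorted[point]);
        point += 1;
    }
    out
}

fn main() {
    let sorted: Vec<&[u8]> =
        vec![&b"he"[..], &b"hell"[..], &b"hello"[..], &b"help"[..], &b"hi"[..]];
    assert_eq!(
        tokens_with_prefix(&sorted, b"hel"),
        vec![&b"hell"[..], &b"hello"[..], &b"help"[..]]
    );
}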
@@ -359,12 +360,12 @@ impl CoreBPE {
             let suffix = &unstable_bytes[i..];
             let mut point = self
                 .sorted_token_bytes
-                .partition_point(|x| x.as_slice() < suffix);
+                .partition_point(|x| *x < suffix);
             // TODO: Perf optimisation if suffix starts with " "?
             while point < self.sorted_token_bytes.len()
                 && self.sorted_token_bytes[point].starts_with(suffix)
             {
-                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
+                let possibility = [prefix, self.sorted_token_bytes[point]].concat();
                 let encoded = match std::str::from_utf8(&possibility) {
                     // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                     // But we might have introduced a regex split which would prevent merges.
@@ -443,8 +444,15 @@ impl CoreBPE {
             Regex::new(&parts.join("|"))?
         };

-        let decoder: HashMap<usize, Vec<u8>> =
-            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
+        // Use unsafe to extend the lifetime of references to the encoder's keys
+        let decoder: HashMap<usize, &'static [u8]> = encoder
+            .iter()
+            .map(|(k, v)| {
+                let bytes: &[u8] = k.as_slice();
+                let static_bytes: &'static [u8] = unsafe { std::mem::transmute(bytes) };
+                (*v, static_bytes)
+            })
+            .collect();

         assert!(
             encoder.len() == decoder.len(),
@@ -456,8 +464,14 @@ impl CoreBPE {
             .map(|(k, v)| (*v, k.as_bytes().to_vec()))
             .collect();

-        // Clone because I don't know how to tell Rust I'm not going to change the map
-        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder
+        let mut sorted_token_bytes: Vec<&'static [u8]> = encoder
+            .keys()
+            .map(|k| {
+                let bytes: &[u8] = k.as_slice();
+                let static_bytes: &'static [u8] = unsafe { std::mem::transmute(bytes) };
+                static_bytes
+            })
+            .collect();
         sorted_token_bytes.sort();

         Ok(CoreBPE {
@@ -553,7 +567,7 @@ impl CoreBPE {

     pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>, usize> {
         if let Some(bytes) = self.decoder.get(&token) {
-            return Ok(bytes.clone());
+            return Ok(bytes.to_vec());
         }
         if let Some(bytes) = self.special_tokens_decoder.get(&token) {
             return Ok(bytes.clone());
@@ -566,7 +580,7 @@ impl CoreBPE {
     // ====================

     pub fn token_byte_values(&self) -> Vec<Vec<u8>> {
-        self.sorted_token_bytes.clone()
+        self.sorted_token_bytes.iter().map(|&bytes| bytes.to_vec()).collect()
     }
 }
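At the public API boundary the borrowed slices are copied back into owned vectors (the to_vec() calls above), so callers see the same types as before the change. A hypothetical caller, sketched under the assumption that a CoreBPE value has been constructed elsewhere:

// Hypothetical caller (`token_info` is made up for illustration): after this
// commit the public API still hands out owned Vec<u8> data, so the internal
// switch to &'static [u8] is invisible here.
fn token_info(bpe: &CoreBPE, token: usize) -> (Vec<u8>, usize) {
    // decode_single_token_bytes copies the borrowed decoder slice via to_vec().
    let bytes = bpe.decode_single_token_bytes(token).unwrap_or_default();
    // token_byte_values copies every sorted slice into a fresh Vec<u8>.
    let vocab_size = bpe.token_byte_values().len();
    (bytes, vocab_size)
}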
