Skip to content

Commit b4c687e

Browse files
Lőrinc
authored and hauntsaninja committed
Avoid calling byte_pair_encode for existing tokens
This way byte_pair_encode can be optimized further, assuming we'll always have at least 2 tokens
1 parent 6e4851a commit b4c687e

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

src/lib.rs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ fn hash_current_thread() -> usize {
162162
// that works great for our use case of avoiding collisions in our array. Unfortunately,
163163
// it's private. However, there are only so many ways you can layout a u64, so just transmute
164164
// https://github.com/rust-lang/rust/issues/67939
165-
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
165+
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
166166
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
167167
let x = unsafe {
168-
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
168+
std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0
169169
};
170170
u64::from(x) as usize
171171
}
@@ -214,11 +214,10 @@ impl CoreBPE {
214214
let mut ret = vec![];
215215
for mat in regex.find_iter(text) {
216216
let piece = mat.unwrap().as_str().as_bytes();
217-
if let Some(token) = self.encoder.get(piece) {
218-
ret.push(*token);
219-
continue;
217+
match self.encoder.get(piece) {
218+
Some(token) => ret.push(*token),
219+
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
220220
}
221-
ret.extend(&byte_pair_encode(piece, &self.encoder));
222221
}
223222
ret
224223
}
@@ -516,7 +515,10 @@ impl CoreBPE {
516515
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
517516

518517
tokens.truncate(tokens.len() - last_piece_token_len);
519-
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
518+
match self.encoder.get(&unstable_bytes) {
519+
Some(token) => tokens.push(*token),
520+
None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)),
521+
}
520522
}
521523
tokens
522524
}
@@ -597,15 +599,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
597599
mod tests {
598600
use rustc_hash::FxHashMap as HashMap;
599601

600-
use crate::byte_pair_split;
602+
use crate::{byte_pair_split, Rank};
601603

602-
#[test]
603-
fn very_simple_test() {
604-
let mut ranks = HashMap::default();
605-
ranks.insert(b"ab".to_vec(), 1);
606-
ranks.insert(b"cd".to_vec(), 2);
604+
fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
605+
HashMap::from_iter([
606+
(b"ab".to_vec(), 0),
607+
(b"cd".to_vec(), 1),
608+
])
609+
}
607610

611+
#[test]
612+
fn test_simple_characters() {
613+
let ranks = setup_ranks();
608614
let res = byte_pair_split(b"abcd", &ranks);
609615
assert_eq!(res, vec![b"ab", b"cd"]);
610616
}
617+
618+
#[test]
619+
fn test_repeated_characters() {
620+
let ranks = setup_ranks();
621+
let res = byte_pair_split(b"abab", &ranks);
622+
assert_eq!(res, vec![b"ab", b"ab"]);
623+
}
611624
}

0 commit comments

Comments
 (0)