Skip to content

Commit c2960c1

Browse files
Lőrinc authored and hauntsaninja committed
Store tokens in u32 instead of usize
And hide it behind a Rank type to make it easier to separate it from other numeric values
1 parent 84d88dc commit c2960c1

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

src/lib.rs

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,31 @@
22
#![allow(clippy::borrow_deref_ref)]
33

44
use std::collections::HashSet;
5+
use std::num::NonZeroU64;
56
use std::thread;
67

78
use fancy_regex::Regex;
89
use pyo3::exceptions;
910
use pyo3::prelude::*;
10-
use pyo3::types::{PyBytes, PyList, PyTuple};
1111
use pyo3::PyResult;
12+
use pyo3::types::{PyBytes, PyList, PyTuple};
1213
use rustc_hash::FxHashMap as HashMap;
1314

15+
type Rank = u32;
16+
1417
fn _byte_pair_merge<T>(
1518
piece: &[u8],
16-
ranks: &HashMap<Vec<u8>, usize>,
19+
ranks: &HashMap<Vec<u8>, Rank>,
1720
f: impl Fn(std::ops::Range<usize>) -> T,
1821
) -> Vec<T> {
1922
// This is a vector of (start, rank).
2023
// The rank is of the byte pair starting at position start.
2124
// The rank of the last item in the vector is not a valid value.
22-
let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
25+
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
2326

2427
let get_rank = {
2528
#[inline(always)]
26-
|parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
29+
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
2730
if (start_idx + skip + 2) < parts.len() {
2831
ranks
2932
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
@@ -39,8 +42,8 @@ fn _byte_pair_merge<T>(
3942
for i in 0..parts.len() - 2 {
4043
match get_rank(&parts, i, 0) {
4144
Some(rank) => {
42-
// usize::MAX is a sentinel value and cannot be a valid rank
43-
debug_assert!(rank != usize::MAX);
45+
// Rank::MAX is a sentinel value and cannot be a valid rank
46+
debug_assert!(rank != Rank::MAX);
4447
parts[i].1 = rank;
4548
}
4649
None => {
@@ -63,26 +66,26 @@ fn _byte_pair_merge<T>(
6366
break;
6467
}
6568

66-
// usize::MAX is a sentinel rank value allowing us to
69+
// Rank::MAX is a sentinel rank value allowing us to
6770
// take the min more quickly
68-
let mut min_rank: (usize, usize) = (usize::MAX, 0);
71+
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
6972
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
7073
if rank < min_rank.0 {
7174
min_rank = (rank, i);
7275
}
7376
}
7477

75-
if min_rank.0 != usize::MAX {
78+
if min_rank.0 != Rank::MAX {
7679
let i = min_rank.1;
7780

7881
// NOTE: We are about to remove parts[i + 1]. We do not do it
7982
// yet because there are cache-locality benefits to updating
8083
// parts[i] and parts[i-1] before removing, which could thrash
8184
// the cache. Thus, we update the rank calculation by skipping over
8285
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
83-
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
86+
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
8487
if i > 0 {
85-
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
88+
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
8689
}
8790

8891
parts.remove(i + 1);
@@ -97,14 +100,14 @@ fn _byte_pair_merge<T>(
97100
out
98101
}
99102

100-
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
103+
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
101104
if piece.len() == 1 {
102105
return vec![ranks[piece]];
103106
}
104107
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
105108
}
106109

107-
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
110+
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
108111
if piece.len() == 1 {
109112
return vec![piece];
110113
}
@@ -152,7 +155,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
152155
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
153156
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
154157

155-
use std::num::NonZeroU64;
156158
pub struct FakeThreadId(NonZeroU64);
157159

158160
fn hash_current_thread() -> usize {
@@ -169,12 +171,13 @@ fn hash_current_thread() -> usize {
169171
}
170172

171173
const MAX_NUM_THREADS: usize = 128;
174+
172175
#[pyclass]
173176
struct CoreBPE {
174-
encoder: HashMap<Vec<u8>, usize>,
175-
special_tokens_encoder: HashMap<String, usize>,
176-
decoder: HashMap<usize, Vec<u8>>,
177-
special_tokens_decoder: HashMap<usize, Vec<u8>>,
177+
encoder: HashMap<Vec<u8>, Rank>,
178+
special_tokens_encoder: HashMap<String, Rank>,
179+
decoder: HashMap<Rank, Vec<u8>>,
180+
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
178181
regex_tls: Vec<Regex>,
179182
special_regex_tls: Vec<Regex>,
180183
sorted_token_bytes: Vec<Vec<u8>>,
@@ -192,7 +195,7 @@ impl CoreBPE {
192195
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
193196
}
194197

195-
fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
198+
fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
196199
let mut ret = Vec::with_capacity(tokens.len() * 2);
197200
for token in tokens {
198201
let token_bytes = self
@@ -204,7 +207,7 @@ impl CoreBPE {
204207
ret
205208
}
206209

207-
fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
210+
fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
208211
// This is the core of the encoding logic; the other functions in here
209212
// just make things complicated :-)
210213
let regex = self._get_tl_regex();
@@ -220,7 +223,7 @@ impl CoreBPE {
220223
ret
221224
}
222225

223-
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
226+
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
224227
let special_regex = self._get_tl_special_regex();
225228
let regex = self._get_tl_regex();
226229
let mut ret = vec![];
@@ -278,9 +281,9 @@ impl CoreBPE {
278281

279282
fn _increase_last_piece_token_len(
280283
&self,
281-
tokens: Vec<usize>,
284+
tokens: Vec<Rank>,
282285
mut last_piece_token_len: usize,
283-
) -> (Vec<usize>, usize) {
286+
) -> (Vec<Rank>, usize) {
284287
// Unfortunately, the locations where our regex splits can be unstable.
285288
// For the purposes of determining unstable tokens, unstable regex splitting
286289
// is only a problem if a split that was present disappears, since this can
@@ -319,7 +322,7 @@ impl CoreBPE {
319322
&self,
320323
text: &str,
321324
allowed_special: &HashSet<&str>,
322-
) -> (Vec<usize>, HashSet<Vec<usize>>) {
325+
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
323326
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
324327
if last_piece_token_len == 0 {
325328
// If last_piece_token_len is zero, the last token was a special token and we have
@@ -436,8 +439,8 @@ impl CoreBPE {
436439
impl CoreBPE {
437440
#[new]
438441
fn new(
439-
encoder: HashMap<Vec<u8>, usize>,
440-
special_tokens_encoder: HashMap<String, usize>,
442+
encoder: HashMap<Vec<u8>, Rank>,
443+
special_tokens_encoder: HashMap<String, Rank>,
441444
pattern: &str,
442445
) -> PyResult<Self> {
443446
let regex = Regex::new(pattern)
@@ -452,15 +455,15 @@ impl CoreBPE {
452455
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
453456
};
454457

455-
let decoder: HashMap<usize, Vec<u8>> =
458+
let decoder: HashMap<Rank, Vec<u8>> =
456459
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
457460

458461
assert!(
459462
encoder.len() == decoder.len(),
460463
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
461464
);
462465

463-
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
466+
let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
464467
.iter()
465468
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
466469
.collect();
@@ -486,15 +489,15 @@ impl CoreBPE {
486489
// Encoding
487490
// ====================
488491

489-
fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
492+
fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
490493
py.allow_threads(|| self._encode_ordinary_native(text))
491494
}
492495

493-
fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
496+
fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<Rank> {
494497
py.allow_threads(|| self._encode_native(text, &allowed_special).0)
495498
}
496499

497-
fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
500+
fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
498501
py.allow_threads(|| {
499502
match std::str::from_utf8(bytes) {
500503
Ok(text) => self._encode_ordinary_native(text),
@@ -534,7 +537,7 @@ impl CoreBPE {
534537
(tokens, py_completions).into_py(py)
535538
}
536539

537-
fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
540+
fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
538541
if let Some(token) = self.encoder.get(piece).copied() {
539542
return Ok(token);
540543
}
@@ -546,7 +549,7 @@ impl CoreBPE {
546549
Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
547550
}
548551

549-
fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
552+
fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
550553
if let Some(token) = self.encoder.get(piece) {
551554
return vec![*token];
552555
}
@@ -557,12 +560,12 @@ impl CoreBPE {
557560
// Decoding
558561
// ====================
559562

560-
fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
563+
fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Py<PyBytes> {
561564
let bytes = py.allow_threads(|| self._decode_native(&tokens));
562565
PyBytes::new(py, &bytes).into()
563566
}
564567

565-
fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
568+
fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
566569
if let Some(bytes) = self.decoder.get(&token) {
567570
return Ok(PyBytes::new(py, bytes).into());
568571
}

0 commit comments

Comments (0)