Skip to content

Commit f40333c

Browse files
committed
Store tokens in u32 instead of usize
based on upstream commit openai@c2960c1 cc openai#251
1 parent 241dee1 commit f40333c

File tree

5 files changed

+73
-67
lines changed

build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ fn main() {
7272
fn generate(
7373
name: &str,
7474
file: &mut File,
75-
mergeable_ranks: &HashMap<Vec<u8>, usize>,
75+
mergeable_ranks: &HashMap<Vec<u8>, Rank>,
7676
) {
7777
writeln!(
7878
file,

src/corebpe.rs

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
1+
use std::num::NonZeroU64;
12
use std::thread;
23

34
use fancy_regex::Regex;
45
use rustc_hash::FxHashMap as HashMap;
56
use rustc_hash::FxHashSet as HashSet;
67
use std::sync::Arc;
78

9+
pub type Rank = u32;
10+
811
fn _byte_pair_merge<T>(
912
piece: &[u8],
10-
ranks: &HashMap<Vec<u8>, usize>,
13+
ranks: &HashMap<Vec<u8>, Rank>,
1114
f: impl Fn(std::ops::Range<usize>) -> T,
1215
) -> Vec<T> {
1316
// This is a vector of (start, rank).
1417
// The rank is of the byte pair starting at position start.
1518
// The rank of the last item in the vector is not a valid value.
16-
let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
19+
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
1720

1821
let get_rank = {
1922
#[inline(always)]
20-
|parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
23+
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
2124
if (start_idx + skip + 2) < parts.len() {
2225
ranks
2326
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
@@ -33,8 +36,8 @@ fn _byte_pair_merge<T>(
3336
for i in 0..parts.len() - 2 {
3437
match get_rank(&parts, i, 0) {
3538
Some(rank) => {
36-
// usize::MAX is a sentinel value and cannot be a valid rank
37-
debug_assert!(rank != usize::MAX);
39+
// Rank::MAX is a sentinel value and cannot be a valid rank
40+
debug_assert!(rank != Rank::MAX);
3841
parts[i].1 = rank;
3942
}
4043
None => {
@@ -57,26 +60,26 @@ fn _byte_pair_merge<T>(
5760
break;
5861
}
5962

60-
// usize::MAX is a sentinel rank value allowing us to
63+
// Rank::MAX is a sentinel rank value allowing us to
6164
// take the min more quickly
62-
let mut min_rank: (usize, usize) = (usize::MAX, 0);
65+
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
6366
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
6467
if rank < min_rank.0 {
6568
min_rank = (rank, i);
6669
}
6770
}
6871

69-
if min_rank.0 != usize::MAX {
72+
if min_rank.0 != Rank::MAX {
7073
let i = min_rank.1;
7174

7275
// NOTE: We are about to remove parts[i + 1]. We do not do it
7376
// yet because there are cache-locality benefits to updating
7477
// parts[i] and parts[i-1] before removing, which could thrash
7578
// the cache. Thus, we update the rank calculation by skipping over
7679
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
77-
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
80+
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
7881
if i > 0 {
79-
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
82+
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
8083
}
8184

8285
parts.remove(i + 1);
@@ -91,14 +94,14 @@ fn _byte_pair_merge<T>(
9194
out
9295
}
9396

94-
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
97+
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
9598
if piece.len() == 1 {
9699
return vec![ranks[piece]];
97100
}
98101
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
99102
}
100103

101-
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
104+
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
102105
if piece.len() == 1 {
103106
return vec![piece];
104107
}
@@ -146,7 +149,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
146149
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
147150
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
148151

149-
use std::num::NonZeroU64;
150152
pub struct FakeThreadId(NonZeroU64);
151153

152154
fn hash_current_thread() -> usize {
@@ -166,10 +168,10 @@ const MAX_NUM_THREADS: usize = 8;
166168

167169
#[derive(Debug)]
168170
pub struct CoreBPE {
169-
pub encoder: HashMap<Vec<u8>, usize>,
170-
special_tokens_encoder: HashMap<String, usize>,
171-
decoder: HashMap<usize, &'static [u8]>,
172-
special_tokens_decoder: HashMap<usize, Vec<u8>>,
171+
pub encoder: HashMap<Vec<u8>, Rank>,
172+
special_tokens_encoder: HashMap<String, Rank>,
173+
decoder: HashMap<Rank, &'static [u8]>,
174+
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
173175
regex_tls: Arc<[Regex]>,
174176
special_regex_tls: Arc<[Regex]>,
175177
sorted_token_bytes: Vec<&'static [u8]>,
@@ -187,7 +189,7 @@ impl CoreBPE {
187189
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
188190
}
189191

190-
fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
192+
fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
191193
let mut ret = Vec::with_capacity(tokens.len() * 2);
192194
for token in tokens {
193195
let token_bytes = self
@@ -200,7 +202,7 @@ impl CoreBPE {
200202
ret
201203
}
202204

203-
fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
205+
fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
204206
// This is the core of the encoding logic; the other functions in here
205207
// just make things complicated :-)
206208
let regex = self._get_tl_regex();
@@ -216,7 +218,7 @@ impl CoreBPE {
216218
ret
217219
}
218220

219-
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
221+
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
220222
let special_regex = self._get_tl_special_regex();
221223
let regex = self._get_tl_regex();
222224
let mut ret = vec![];
@@ -274,9 +276,9 @@ impl CoreBPE {
274276

275277
fn _increase_last_piece_token_len(
276278
&self,
277-
tokens: Vec<usize>,
279+
tokens: Vec<Rank>,
278280
mut last_piece_token_len: usize,
279-
) -> (Vec<usize>, usize) {
281+
) -> (Vec<Rank>, usize) {
280282
// Unfortunately, the locations where our regex splits can be unstable.
281283
// For the purposes of determining unstable tokens, unstable regex splitting
282284
// is only a problem if a split that was present disappears, since this can
@@ -315,7 +317,7 @@ impl CoreBPE {
315317
&self,
316318
text: &str,
317319
allowed_special: &HashSet<&str>,
318-
) -> (Vec<usize>, HashSet<Vec<usize>>) {
320+
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
319321
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
320322
if last_piece_token_len == 0 {
321323
// If last_piece_token_len is zero, the last token was a special token and we have
@@ -430,8 +432,8 @@ impl CoreBPE {
430432

431433
impl CoreBPE {
432434
pub fn new(
433-
encoder: HashMap<Vec<u8>, usize>,
434-
special_tokens_encoder: HashMap<String, usize>,
435+
encoder: HashMap<Vec<u8>, Rank>,
436+
special_tokens_encoder: HashMap<String, Rank>,
435437
pattern: &str,
436438
) -> Result<Self, fancy_regex::Error> {
437439
let regex = Regex::new(pattern)?;
@@ -445,7 +447,7 @@ impl CoreBPE {
445447
};
446448

447449
// Use unsafe to extend the lifetime of references to the encoder's keys
448-
let decoder: HashMap<usize, &'static [u8]> = encoder
450+
let decoder: HashMap<Rank, &'static [u8]> = encoder
449451
.iter()
450452
.map(|(k, v)| {
451453
let bytes: &[u8] = k.as_slice();
@@ -459,7 +461,7 @@ impl CoreBPE {
459461
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
460462
);
461463

462-
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
464+
let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
463465
.iter()
464466
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
465467
.collect();
@@ -497,15 +499,15 @@ impl CoreBPE {
497499
// Encoding
498500
// ====================
499501

500-
pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
502+
pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
501503
self._encode_ordinary_native(text)
502504
}
503505

504-
pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<usize> {
506+
pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
505507
self._encode_native(text, &allowed_special).0
506508
}
507509

508-
pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
510+
pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<Rank> {
509511
match std::str::from_utf8(bytes) {
510512
Ok(text) => self._encode_ordinary_native(text),
511513
Err(e) => {
@@ -534,11 +536,11 @@ impl CoreBPE {
534536
&self,
535537
text: &str,
536538
allowed_special: &HashSet<&str>,
537-
) -> (Vec<usize>, HashSet<Vec<usize>>) {
539+
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
538540
self._encode_unstable_native(text, &allowed_special)
539541
}
540542

541-
pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize, Vec<u8>> {
543+
pub fn encode_single_token(&self, piece: &[u8]) -> Result<Rank, Vec<u8>> {
542544
if let Some(token) = self.encoder.get(piece).copied() {
543545
return Ok(token);
544546
}
@@ -550,7 +552,7 @@ impl CoreBPE {
550552
Err(piece.to_owned())
551553
}
552554

553-
pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
555+
pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
554556
if let Some(token) = self.encoder.get(piece) {
555557
return vec![*token];
556558
}
@@ -561,11 +563,11 @@ impl CoreBPE {
561563
// Decoding
562564
// ====================
563565

564-
pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
566+
pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
565567
self._decode_native(&tokens)
566568
}
567569

568-
pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>, usize> {
570+
pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, Rank> {
569571
if let Some(bytes) = self.decoder.get(&token) {
570572
return Ok(bytes.to_vec());
571573
}

0 commit comments

Comments (0)