Skip to content

Commit 0ecd650

Browse files
Lőrinc authored and tmm1 committed
Inline custom mapping function in _byte_pair_merge
(cherry-picked from openai/tiktoken@6defed5)
1 parent 48e7eda commit 0ecd650

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

src/corebpe.rs

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,10 @@ use std::sync::Arc;
88

99
pub type Rank = u32;
1010

11-
fn _byte_pair_merge<T>(
12-
piece: &[u8],
11+
fn _byte_pair_merge(
1312
ranks: &HashMap<Vec<u8>, Rank>,
14-
f: impl Fn(std::ops::Range<usize>) -> T,
15-
) -> Vec<T> {
13+
piece: &[u8],
14+
) -> Vec<(usize, Rank)> {
1615
// This is a vector of (start, rank).
1716
// The rank is of the byte pair starting at position start.
1817
// The rank of the last item in the vector is not a valid value.
@@ -87,25 +86,24 @@ fn _byte_pair_merge<T>(
8786
break;
8887
}
8988
}
90-
let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
91-
for i in 0..parts.len() - 1 {
92-
out.push(f(parts[i].0..parts[i + 1].0));
93-
}
94-
out
89+
90+
parts
9591
}
9692

9793
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
98-
if piece.len() == 1 {
99-
return vec![ranks[piece]];
100-
}
101-
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
94+
assert!(piece.len() > 1);
95+
_byte_pair_merge(&ranks, &piece)
96+
.windows(2)
97+
.map(|part| ranks[&piece[part[0].0..part[1].0]])
98+
.collect()
10299
}
103100

104101
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
105-
if piece.len() == 1 {
106-
return vec![piece];
107-
}
108-
_byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
102+
assert!(piece.len() > 1);
103+
_byte_pair_merge(&ranks, &piece)
104+
.windows(2)
105+
.map(|part| &piece[part[0].0..part[1].0])
106+
.collect()
109107
}
110108

111109
// Various performance notes:

0 commit comments

Comments
 (0)