
Commit 62687ca

hauntsaninja and l0rinc authored and committed
Simplify byte_pair_merge
Based on suggestion in openai#239 (specifically openai/tiktoken@8f5dd7d)

Like that commit, this:
- Does the init in a single loop and saves a loop if there are no merges
- Simplifies get_rank and no longer uses it in init (so you don't need multiple skip values)

Unlike that commit:
- We drop optimisations enabled by ignoring single tokens. These didn't show any benefit on benchmarks for me (this makes sense given typical piece sizes, but let me know if that's unexpected!). Given this, I opted for the simpler version.
- I preserve some of the comments from the original that I think are still useful

Co-authored-by: @paplorinc

---------

Co-authored-by: Lőrinc Pap <1841944+paplorinc@users.noreply.github.com>
1 parent 2692606 commit 62687ca
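
To see the simplified flow end to end before reading the diff, here is a minimal, self-contained sketch of the algorithm this commit arrives at. The HashMap-based vocabulary, the toy ranks, and the main driver are illustrative only; the real crate defines its own Rank type and loads real merge ranks.

// A sketch of the simplified byte_pair_merge, assuming Rank = u32 and a plain
// HashMap vocabulary. The real crate passes its own ranks map and Rank type.
use std::collections::HashMap;

type Rank = u32;

fn byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // Single init loop: build (start, rank) parts and track the minimum rank
    // as we go. The caller is assumed to guarantee piece.len() >= 2.
    let mut parts = Vec::with_capacity(piece.len() + 1);
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    // Rank of the pair that starts at parts[i], looked up while parts[i + 1]
    // still exists (hence the +3 rather than +2).
    let get_rank = |parts: &Vec<(usize, Rank)>, i: usize| {
        if i + 3 < parts.len() {
            *ranks.get(&piece[parts[i].0..parts[i + 3].0]).unwrap_or(&Rank::MAX)
        } else {
            Rank::MAX
        }
    };

    // If no pair is mergeable, min_rank is already Rank::MAX and we never loop.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update the ranks around i before the cache-thrashing remove.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the next-lowest rank.
        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}

fn main() {
    // Illustrative ranks only: "ab" merges first, then "cd", then "abcd".
    let ranks: HashMap<Vec<u8>, Rank> =
        [(b"ab".to_vec(), 0), (b"cd".to_vec(), 1), (b"abcd".to_vec(), 2)]
            .into_iter()
            .collect();
    let parts = byte_pair_merge(&ranks, b"abcd");
    // Only the outer boundaries remain: the whole piece became one token.
    assert_eq!(parts.iter().map(|&(s, _)| s).collect::<Vec<_>>(), vec![0, 4]);
}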

File tree

1 file changed (+36, −56)


src/corebpe.rs

Lines changed: 36 additions & 56 deletions
@@ -13,80 +13,60 @@ fn _byte_pair_merge(
     piece: &[u8],
 ) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
+    // The rank is of the pair starting at position start.
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+
+    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+    // merge priority from token index or to prevent specific token merges.
+    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
+    for i in 0..piece.len() - 1 {
+        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+        if rank < min_rank.0 {
+            min_rank = (rank, i);
+        }
+        parts.push((i, rank));
+    }
+    parts.push((piece.len() - 1, Rank::MAX));
+    parts.push((piece.len(), Rank::MAX));
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
+        |parts: &Vec<(usize, Rank)>, i: usize| {
+            if (i + 3) < parts.len() {
+                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                // parts[i + 1], see comment in the main loop.
+                *ranks
+                    .get(&piece[parts[i].0..parts[i + 3].0])
+                    .unwrap_or(&Rank::MAX)
             } else {
-                None
+                Rank::MAX
             }
         }
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
     // If you have n parts and m merges, this does O(mn) work.
     // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    // n is often very small so considerations like cache-locality outweigh the algorithmic
+    // complexity downsides of the `parts` vector.
+    while min_rank.0 != Rank::MAX {
+        let i = min_rank.1;
+        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+        // `parts.remove(i + 1)` will thrash the cache.
+        if i > 0 {
+            parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts[i].1 = get_rank(&parts, i);
+        parts.remove(i + 1);
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, usize::MAX);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
-
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
-            }
-
-            parts.remove(i + 1);
-        } else {
-            break;
-        }
     }
-
+
     parts
 }
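
The returned vector holds split points rather than tokens: adjacent start indices delimit each merged token's bytes. A hypothetical segments helper (name and signature mine, not the crate's) makes the mapping explicit:

// Hypothetical helper, not from the crate: pair adjacent starts in `parts`
// to recover each merged token's byte range.
fn segments<'a>(piece: &'a [u8], parts: &[(usize, u32)]) -> Vec<&'a [u8]> {
    parts.windows(2).map(|w| &piece[w[0].0..w[1].0]).collect()
}

// With the sketch above: segments(b"abcd", &[(0, u32::MAX), (4, u32::MAX)])
// returns a single slice, b"abcd".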