Skip to content

Commit 0a951f9

Browse files
tmm1, hauntsaninja, l0rinc
authored
Simplify byte_pair_merge (#17)
Backport of openai#255.

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
Co-authored-by: Lőrinc Pap <1841944+paplorinc@users.noreply.github.com>
1 parent 2692606 commit 0a951f9

File tree

1 file changed

+36
-56
lines changed

1 file changed

+36
-56
lines changed

src/corebpe.rs

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,80 +13,60 @@ fn _byte_pair_merge(
1313
piece: &[u8],
1414
) -> Vec<(usize, Rank)> {
1515
// This is a vector of (start, rank).
16-
// The rank is of the byte pair starting at position start.
17-
// The rank of the last item in the vector is not a valid value.
18-
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
16+
// The rank is of the pair starting at position start.
17+
let mut parts = Vec::with_capacity(piece.len() + 1);
18+
19+
// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
20+
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
21+
// merge priority from token index or to prevent specific token merges.
22+
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
23+
for i in 0..piece.len() - 1 {
24+
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
25+
if rank < min_rank.0 {
26+
min_rank = (rank, i);
27+
}
28+
parts.push((i, rank));
29+
}
30+
parts.push((piece.len() - 1, Rank::MAX));
31+
parts.push((piece.len(), Rank::MAX));
1932

2033
let get_rank = {
2134
#[inline(always)]
22-
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
23-
if (start_idx + skip + 2) < parts.len() {
24-
ranks
25-
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
26-
.copied()
35+
|parts: &Vec<(usize, Rank)>, i: usize| {
36+
if (i + 3) < parts.len() {
37+
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
38+
// parts[i + 1], see comment in the main loop.
39+
*ranks
40+
.get(&piece[parts[i].0..parts[i + 3].0])
41+
.unwrap_or(&Rank::MAX)
2742
} else {
28-
None
43+
Rank::MAX
2944
}
3045
}
3146
};
3247

33-
// We look up the ranks once in the beginning and iteratively update
34-
// them during each merge, which reduces the number of rank lookups.
35-
for i in 0..parts.len() - 2 {
36-
match get_rank(&parts, i, 0) {
37-
Some(rank) => {
38-
// Rank::MAX is a sentinel value and cannot be a valid rank
39-
debug_assert!(rank != Rank::MAX);
40-
parts[i].1 = rank;
41-
}
42-
None => {
43-
continue;
44-
}
45-
};
46-
}
47-
4848
// If you have n parts and m merges, this does O(mn) work.
4949
// We could do something with a heap and do O(m log n) work.
50-
// It is important to consider that n is often small (<100), and as such
51-
// the cache-locality benefits outweigh the algorithmic complexity downsides
52-
// of the `parts` vector data structure above.
53-
54-
// Note that we hash bytes, not token pairs. As long as we train BPE the way we
55-
// currently do, this is equivalent. An easy way to break this would be to decouple
56-
// merge priority from token index or to prevent specific token merges.
57-
loop {
58-
if parts.len() == 1 {
59-
break;
50+
// n is often very small so considerations like cache-locality outweigh the algorithmic
51+
// complexity downsides of the `parts` vector.
52+
while min_rank.0 != Rank::MAX {
53+
let i = min_rank.1;
54+
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since
55+
// `parts.remove(i + 1)` will thrash the cache.
56+
if i > 0 {
57+
parts[i - 1].1 = get_rank(&parts, i - 1);
6058
}
59+
parts[i].1 = get_rank(&parts, i);
60+
parts.remove(i + 1);
6161

62-
// Rank::MAX is a sentinel rank value allowing us to
63-
// take the min more quickly
64-
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
62+
min_rank = (Rank::MAX, usize::MAX);
6563
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
6664
if rank < min_rank.0 {
6765
min_rank = (rank, i);
6866
}
6967
}
70-
71-
if min_rank.0 != Rank::MAX {
72-
let i = min_rank.1;
73-
74-
// NOTE: We are about to remove parts[i + 1]. We do not do it
75-
// yet because there are cache-locality benefits to updating
76-
// parts[i] and parts[i-1] before removing, which could thrash
77-
// the cache. Thus, we update the rank calculation by skipping over
78-
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
79-
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
80-
if i > 0 {
81-
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
82-
}
83-
84-
parts.remove(i + 1);
85-
} else {
86-
break;
87-
}
8868
}
89-
69+
9070
parts
9171
}
9272

0 commit comments

Comments (0)