-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Simplify byte_pair_merge #255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap; | |
|
||
type Rank = u32; | ||
|
||
fn _byte_pair_merge( | ||
ranks: &HashMap<Vec<u8>, Rank>, | ||
piece: &[u8], | ||
) -> Vec<(usize, Rank)> { | ||
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { | ||
// This is a vector of (start, rank). | ||
// The rank is of the byte pair starting at position start. | ||
// The rank of the last item in the vector is not a valid value. | ||
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect(); | ||
// The rank is of the pair starting at position start. | ||
let mut parts = Vec::with_capacity(piece.len() + 1); | ||
|
||
// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE | ||
// the way we currently do, this is equivalent. An easy way to break this would be to decouple | ||
// merge priority from token index or to prevent specific token merges. | ||
let mut min_rank: (Rank, usize) = (Rank::MAX, 0); | ||
for i in 0..piece.len() - 1 { | ||
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); | ||
if rank < min_rank.0 { | ||
min_rank = (rank, i); | ||
} | ||
parts.push((i, rank)); | ||
} | ||
parts.push((piece.len() - 1, Rank::MAX)); | ||
parts.push((piece.len(), Rank::MAX)); | ||
|
||
let get_rank = { | ||
#[inline(always)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you see any effect of the inlining here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question, I dimly remember it being useful in #31 (but it was also used in an additional place then). I can double check :-) Which linter? |
||
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| { | ||
if (start_idx + skip + 2) < parts.len() { | ||
ranks | ||
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) | ||
.copied() | ||
|parts: &Vec<(usize, Rank)>, i: usize| { | ||
if (i + 3) < parts.len() { | ||
hauntsaninja marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted | ||
// parts[i + 1], see comment in the main loop. | ||
*ranks | ||
.get(&piece[parts[i].0..parts[i + 3].0]) | ||
.unwrap_or(&Rank::MAX) | ||
} else { | ||
None | ||
Rank::MAX | ||
} | ||
} | ||
}; | ||
|
||
// We look up the ranks once in the beginning and iteratively update | ||
// them during each merge, which reduces the number of rank lookups. | ||
for i in 0..parts.len() - 2 { | ||
match get_rank(&parts, i, 0) { | ||
Some(rank) => { | ||
// Rank::MAX is a sentinel value and cannot be a valid rank | ||
debug_assert!(rank != Rank::MAX); | ||
parts[i].1 = rank; | ||
} | ||
None => { | ||
continue; | ||
} | ||
}; | ||
} | ||
|
||
// If you have n parts and m merges, this does O(mn) work. | ||
// We could do something with a heap and do O(m log n) work. | ||
// It is important to consider that n is often small (<100), and as such | ||
// the cache-locality benefits outweigh the algorithmic complexity downsides | ||
// of the `parts` vector data structure above. | ||
|
||
// Note that we hash bytes, not token pairs. As long as we train BPE the way we | ||
// currently do, this is equivalent. An easy way to break this would be to decouple | ||
// merge priority from token index or to prevent specific token merges. | ||
loop { | ||
if parts.len() == 1 { | ||
break; | ||
// n is often very small so considerations like cache-locality outweigh the algorithmic | ||
// complexity downsides of the `parts` vector. | ||
while min_rank.0 != Rank::MAX { | ||
let i = min_rank.1; | ||
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since | ||
// `parts.remove(i + 1)` will thrash the cache. | ||
parts[i].1 = get_rank(&parts, i); | ||
if i > 0 { | ||
parts[i - 1].1 = get_rank(&parts, i - 1); | ||
} | ||
parts.remove(i + 1); | ||
hauntsaninja marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Rank::MAX is a sentinel rank value allowing us to | ||
// take the min more quickly | ||
let mut min_rank: (Rank, usize) = (Rank::MAX, 0); | ||
min_rank = (Rank::MAX, 0); | ||
hauntsaninja marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { | ||
if rank < min_rank.0 { | ||
min_rank = (rank, i); | ||
} | ||
} | ||
|
||
if min_rank.0 != Rank::MAX { | ||
let i = min_rank.1; | ||
|
||
// NOTE: We are about to remove parts[i + 1]. We do not do it | ||
// yet because there are cache-locality benefits to updating | ||
// parts[i] and parts[i-1] before removing, which could thrash | ||
// the cache. Thus, we update the rank calculation by skipping over | ||
// parts[i + 1], by invoking `get_rank!` with `skip = 1`. | ||
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX); | ||
if i > 0 { | ||
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX); | ||
} | ||
|
||
parts.remove(i + 1); | ||
} else { | ||
break; | ||
} | ||
} | ||
|
||
parts | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.