Skip to content

Commit 0a46366

Browse files
committed
subquadratic
1 parent bb5805d commit 0a46366

File tree

1 file changed

+137
-10
lines changed

1 file changed

+137
-10
lines changed

src/lib.rs

Lines changed: 137 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::borrow::Borrow;
2-
use std::borrow::Cow;
31
use std::collections::HashSet;
42
use std::num::NonZeroU64;
53
use std::thread;
@@ -14,6 +12,131 @@ mod py;
1412

1513
pub type Rank = u32;
1614

15+
use std::collections::BinaryHeap;
16+
17+
#[derive(Eq, PartialEq, Clone, Copy)]
18+
struct Merge {
19+
start: usize,
20+
rank: Rank,
21+
}
22+
23+
impl Ord for Merge {
24+
#[inline]
25+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
26+
other
27+
.rank
28+
.cmp(&self.rank)
29+
.then_with(|| other.start.cmp(&self.start))
30+
}
31+
}
32+
33+
impl PartialOrd for Merge {
34+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
35+
Some(self.cmp(other))
36+
}
37+
}
38+
39+
struct State {
40+
prev: usize,
41+
end: usize,
42+
next_end: usize,
43+
next_rank: Rank,
44+
cur_rank: Rank,
45+
}
46+
47+
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
48+
let mut state = Vec::with_capacity(piece.len());
49+
state.push(State {
50+
prev: usize::MAX,
51+
end: 1,
52+
next_end: 2,
53+
next_rank: Rank::MAX,
54+
cur_rank: Rank::MAX,
55+
});
56+
57+
let mut heap = BinaryHeap::with_capacity(piece.len());
58+
for i in 0..piece.len() - 1 {
59+
if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
60+
heap.push(Merge { start: i, rank });
61+
state[i].next_rank = rank;
62+
}
63+
// note this is happening offset by 1
64+
state.push(State {
65+
prev: i,
66+
end: i + 2,
67+
next_end: i + 3,
68+
next_rank: Rank::MAX,
69+
cur_rank: Rank::MAX,
70+
});
71+
}
72+
73+
// Repeatedly find the valid merge with smallest rank. We merge the (left) token that
74+
// starts at `start` and ends at `state[start].end` with the (right) token that starts at
75+
// `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
76+
// (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
77+
// new potential merges to the heap.
78+
79+
let potential_merge = {
80+
#[inline(always)]
81+
|state: &mut Vec<State>,
82+
heap: &mut BinaryHeap<Merge>,
83+
start: usize,
84+
next_end_item: usize| {
85+
state[start].next_end = next_end_item;
86+
state[start].next_rank = Rank::MAX; // Always invalidate the old merge
87+
if next_end_item <= piece.len() {
88+
if let Some(&rank) = ranks.get(&piece[start..next_end_item]) {
89+
// We have a valid potential merge!
90+
heap.push(Merge { start, rank });
91+
state[start].next_rank = rank;
92+
}
93+
}
94+
}
95+
};
96+
97+
while let Some(left) = heap.pop() {
98+
if left.rank == Rank::MAX {
99+
break;
100+
}
101+
if left.rank != state[left.start].next_rank {
102+
continue; // This merge was invalidated, ignore it
103+
}
104+
105+
let left_start = left.start;
106+
let right_start = state[left_start].end;
107+
let right_end = state[left_start].next_end;
108+
debug_assert!(right_end == state[right_start].end);
109+
let right_next_end = state[right_start].next_end;
110+
111+
// Merge left and right into a single token
112+
state[left_start].cur_rank = state[left_start].next_rank;
113+
state[left_start].end = right_end;
114+
potential_merge(&mut state, &mut heap, left_start, right_next_end);
115+
if right_end < state.len() {
116+
state[right_end].prev = left_start;
117+
}
118+
// Update the merge that ends at left_start
119+
if left_start > 0 {
120+
let prev_start = state[left_start].prev;
121+
potential_merge(&mut state, &mut heap, prev_start, right_end);
122+
}
123+
// Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
124+
state[right_start].next_rank = Rank::MAX;
125+
}
126+
127+
let mut result = Vec::new();
128+
let mut i = 0;
129+
while i < state.len() {
130+
if state[i].cur_rank != Rank::MAX {
131+
result.push(state[i].cur_rank);
132+
} else {
133+
result.push(ranks[&piece[i..state[i].end]]);
134+
}
135+
i = state[i].end;
136+
}
137+
result
138+
}
139+
17140
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
18141
// This is a vector of (start, rank).
19142
// The rank is of the pair starting at position start.
@@ -73,21 +196,25 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
73196
}
74197

75198
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
76-
if piece.len() == 1 {
77-
return vec![ranks[piece]];
78-
}
79-
_byte_pair_merge(ranks, piece)
199+
assert!(piece.len() > 1);
200+
_byte_pair_merge_large(ranks, piece)
201+
/*
80202
.windows(2)
81-
.map(|part| ranks[&piece[part[0].0..part[1].0]])
203+
// .map(|part| ranks[&piece[dbg!(part[0].0..part[1].0)]])
204+
.map(|part| ranks[&piece[part[0]..part[1]]])
82205
.collect()
206+
*/
83207
}
84208

85-
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
209+
pub fn byte_pair_split<'a>(piece: &'a [u8], _ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
86210
assert!(piece.len() > 1);
87-
_byte_pair_merge(ranks, piece)
211+
panic!("Not implemented");
212+
/*
213+
_byte_pair_merge_large(&ranks, &piece)
88214
.windows(2)
89215
.map(|part| &piece[part[0].0..part[1].0])
90216
.collect()
217+
*/
91218
}
92219

93220
// Various performance notes:
@@ -521,7 +648,7 @@ impl CoreBPE {
521648

522649
#[cfg(test)]
523650
mod tests {
524-
use fancy_regex::Regex;
651+
525652
use rustc_hash::FxHashMap as HashMap;
526653

527654
use crate::{byte_pair_split, Rank};

0 commit comments

Comments
 (0)