1
- use std:: borrow:: Borrow ;
2
- use std:: borrow:: Cow ;
3
1
use std:: collections:: HashSet ;
4
2
use std:: num:: NonZeroU64 ;
5
3
use std:: thread;
@@ -14,6 +12,131 @@ mod py;
14
12
15
13
pub type Rank = u32 ;
16
14
15
+ use std:: collections:: BinaryHeap ;
16
+
17
+ #[ derive( Eq , PartialEq , Clone , Copy ) ]
18
+ struct Merge {
19
+ start : usize ,
20
+ rank : Rank ,
21
+ }
22
+
23
+ impl Ord for Merge {
24
+ #[ inline]
25
+ fn cmp ( & self , other : & Self ) -> std:: cmp:: Ordering {
26
+ other
27
+ . rank
28
+ . cmp ( & self . rank )
29
+ . then_with ( || other. start . cmp ( & self . start ) )
30
+ }
31
+ }
32
+
33
+ impl PartialOrd for Merge {
34
+ fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
35
+ Some ( self . cmp ( other) )
36
+ }
37
+ }
38
+
39
+ struct State {
40
+ prev : usize ,
41
+ end : usize ,
42
+ next_end : usize ,
43
+ next_rank : Rank ,
44
+ cur_rank : Rank ,
45
+ }
46
+
47
+ fn _byte_pair_merge_large ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < Rank > {
48
+ let mut state = Vec :: with_capacity ( piece. len ( ) ) ;
49
+ state. push ( State {
50
+ prev : usize:: MAX ,
51
+ end : 1 ,
52
+ next_end : 2 ,
53
+ next_rank : Rank :: MAX ,
54
+ cur_rank : Rank :: MAX ,
55
+ } ) ;
56
+
57
+ let mut heap = BinaryHeap :: with_capacity ( piece. len ( ) ) ;
58
+ for i in 0 ..piece. len ( ) - 1 {
59
+ if let Some ( & rank) = ranks. get ( & piece[ i..i + 2 ] ) {
60
+ heap. push ( Merge { start : i, rank } ) ;
61
+ state[ i] . next_rank = rank;
62
+ }
63
+ // note this is happening offset by 1
64
+ state. push ( State {
65
+ prev : i,
66
+ end : i + 2 ,
67
+ next_end : i + 3 ,
68
+ next_rank : Rank :: MAX ,
69
+ cur_rank : Rank :: MAX ,
70
+ } ) ;
71
+ }
72
+
73
+ // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
74
+ // starts at `start` and ends at `state[start].end` with the (right) token that starts at
75
+ // `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
76
+ // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
77
+ // new potential merges to the heap.
78
+
79
+ let potential_merge = {
80
+ #[ inline( always) ]
81
+ |state : & mut Vec < State > ,
82
+ heap : & mut BinaryHeap < Merge > ,
83
+ start : usize ,
84
+ next_end_item : usize | {
85
+ state[ start] . next_end = next_end_item;
86
+ state[ start] . next_rank = Rank :: MAX ; // Always invalidate the old merge
87
+ if next_end_item <= piece. len ( ) {
88
+ if let Some ( & rank) = ranks. get ( & piece[ start..next_end_item] ) {
89
+ // We have a valid potential merge!
90
+ heap. push ( Merge { start, rank } ) ;
91
+ state[ start] . next_rank = rank;
92
+ }
93
+ }
94
+ }
95
+ } ;
96
+
97
+ while let Some ( left) = heap. pop ( ) {
98
+ if left. rank == Rank :: MAX {
99
+ break ;
100
+ }
101
+ if left. rank != state[ left. start ] . next_rank {
102
+ continue ; // This merge was invalidated, ignore it
103
+ }
104
+
105
+ let left_start = left. start ;
106
+ let right_start = state[ left_start] . end ;
107
+ let right_end = state[ left_start] . next_end ;
108
+ debug_assert ! ( right_end == state[ right_start] . end) ;
109
+ let right_next_end = state[ right_start] . next_end ;
110
+
111
+ // Merge left and right into a single token
112
+ state[ left_start] . cur_rank = state[ left_start] . next_rank ;
113
+ state[ left_start] . end = right_end;
114
+ potential_merge ( & mut state, & mut heap, left_start, right_next_end) ;
115
+ if right_end < state. len ( ) {
116
+ state[ right_end] . prev = left_start;
117
+ }
118
+ // Update the merge that ends at left_start
119
+ if left_start > 0 {
120
+ let prev_start = state[ left_start] . prev ;
121
+ potential_merge ( & mut state, & mut heap, prev_start, right_end) ;
122
+ }
123
+ // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
124
+ state[ right_start] . next_rank = Rank :: MAX ;
125
+ }
126
+
127
+ let mut result = Vec :: new ( ) ;
128
+ let mut i = 0 ;
129
+ while i < state. len ( ) {
130
+ if state[ i] . cur_rank != Rank :: MAX {
131
+ result. push ( state[ i] . cur_rank ) ;
132
+ } else {
133
+ result. push ( ranks[ & piece[ i..state[ i] . end ] ] ) ;
134
+ }
135
+ i = state[ i] . end ;
136
+ }
137
+ result
138
+ }
139
+
17
140
fn _byte_pair_merge ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < ( usize , Rank ) > {
18
141
// This is a vector of (start, rank).
19
142
// The rank is of the pair starting at position start.
@@ -73,21 +196,25 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
73
196
}
74
197
75
198
pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
76
- if piece. len ( ) == 1 {
77
- return vec ! [ ranks[ piece] ] ;
78
- }
79
- _byte_pair_merge ( ranks, piece)
199
+ assert ! ( piece. len( ) > 1 ) ;
200
+ _byte_pair_merge_large ( ranks, piece)
201
+ /*
80
202
.windows(2)
81
- . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
203
+ // .map(|part| ranks[&piece[dbg!(part[0].0..part[1].0)]])
204
+ .map(|part| ranks[&piece[part[0]..part[1]]])
82
205
.collect()
206
+ */
83
207
}
84
208
85
- pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
209
+ pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , _ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
86
210
assert ! ( piece. len( ) > 1 ) ;
87
- _byte_pair_merge ( ranks, piece)
211
+ panic ! ( "Not implemented" ) ;
212
+ /*
213
+ _byte_pair_merge_large(&ranks, &piece)
88
214
.windows(2)
89
215
.map(|part| &piece[part[0].0..part[1].0])
90
216
.collect()
217
+ */
91
218
}
92
219
93
220
// Various performance notes:
@@ -521,7 +648,7 @@ impl CoreBPE {
521
648
522
649
#[ cfg( test) ]
523
650
mod tests {
524
- use fancy_regex :: Regex ;
651
+
525
652
use rustc_hash:: FxHashMap as HashMap ;
526
653
527
654
use crate :: { byte_pair_split, Rank } ;
0 commit comments