1
- use std:: borrow:: Borrow ;
2
- use std:: borrow:: Cow ;
3
1
use std:: collections:: HashSet ;
4
2
use std:: num:: NonZeroU64 ;
5
3
use std:: thread;
@@ -14,6 +12,131 @@ mod py;
14
12
15
13
/// Integer type of a token rank: the value type of the `ranks` maps and the
/// element type of the encoded output produced in this file.
pub type Rank = u32 ;
16
14
15
+ use std:: collections:: BinaryHeap ;
16
+
17
+ #[ derive( Eq , PartialEq , Clone , Copy ) ]
18
+ struct Merge {
19
+ start : usize ,
20
+ rank : Rank ,
21
+ }
22
+
23
+ impl Ord for Merge {
24
+ #[ inline]
25
+ fn cmp ( & self , other : & Self ) -> std:: cmp:: Ordering {
26
+ other
27
+ . rank
28
+ . cmp ( & self . rank )
29
+ . then_with ( || other. start . cmp ( & self . start ) )
30
+ }
31
+ }
32
+
33
+ impl PartialOrd for Merge {
34
+ fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
35
+ Some ( self . cmp ( other) )
36
+ }
37
+ }
38
+
39
+ struct State {
40
+ prev : usize ,
41
+ end : usize ,
42
+ next_end : usize ,
43
+ next_rank : Rank ,
44
+ cur_rank : Rank ,
45
+ }
46
+
47
+ fn _byte_pair_merge_large ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < Rank > {
48
+ let mut state = Vec :: with_capacity ( piece. len ( ) ) ;
49
+ state. push ( State {
50
+ prev : usize:: MAX ,
51
+ end : 1 ,
52
+ next_end : 2 ,
53
+ next_rank : Rank :: MAX ,
54
+ cur_rank : Rank :: MAX ,
55
+ } ) ;
56
+
57
+ let mut heap = BinaryHeap :: with_capacity ( piece. len ( ) ) ;
58
+ for i in 0 ..piece. len ( ) - 1 {
59
+ if let Some ( & rank) = ranks. get ( & piece[ i..i + 2 ] ) {
60
+ heap. push ( Merge { start : i, rank } ) ;
61
+ state[ i] . next_rank = rank;
62
+ }
63
+ // note this is happening offset by 1
64
+ state. push ( State {
65
+ prev : i,
66
+ end : i + 2 ,
67
+ next_end : i + 3 ,
68
+ next_rank : Rank :: MAX ,
69
+ cur_rank : Rank :: MAX ,
70
+ } ) ;
71
+ }
72
+
73
+ // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
74
+ // starts at `start` and ends at `state[start].end` with the (right) token that starts at
75
+ // `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
76
+ // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
77
+ // new potential merges to the heap.
78
+
79
+ let potential_merge = {
80
+ #[ inline( always) ]
81
+ |state : & mut Vec < State > ,
82
+ heap : & mut BinaryHeap < Merge > ,
83
+ start : usize ,
84
+ next_end_item : usize | {
85
+ state[ start] . next_end = next_end_item;
86
+ state[ start] . next_rank = Rank :: MAX ; // Always invalidate the old merge
87
+ if next_end_item <= piece. len ( ) {
88
+ if let Some ( & rank) = ranks. get ( & piece[ start..next_end_item] ) {
89
+ // We have a valid potential merge!
90
+ heap. push ( Merge { start, rank } ) ;
91
+ state[ start] . next_rank = rank;
92
+ }
93
+ }
94
+ }
95
+ } ;
96
+
97
+ while let Some ( left) = heap. pop ( ) {
98
+ if left. rank == Rank :: MAX {
99
+ break ;
100
+ }
101
+ if left. rank != state[ left. start ] . next_rank {
102
+ continue ; // This merge was invalidated, ignore it
103
+ }
104
+
105
+ let left_start = left. start ;
106
+ let right_start = state[ left_start] . end ;
107
+ let right_end = state[ left_start] . next_end ;
108
+ debug_assert ! ( right_end == state[ right_start] . end) ;
109
+ let right_next_end = state[ right_start] . next_end ;
110
+
111
+ // Merge left and right into a single token
112
+ state[ left_start] . cur_rank = state[ left_start] . next_rank ;
113
+ state[ left_start] . end = right_end;
114
+ potential_merge ( & mut state, & mut heap, left_start, right_next_end) ;
115
+ if right_end < state. len ( ) {
116
+ state[ right_end] . prev = left_start;
117
+ }
118
+ // Update the merge that ends at left_start
119
+ if left_start > 0 {
120
+ let prev_start = state[ left_start] . prev ;
121
+ potential_merge ( & mut state, & mut heap, prev_start, right_end) ;
122
+ }
123
+ // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
124
+ state[ right_start] . next_rank = Rank :: MAX ;
125
+ }
126
+
127
+ let mut result = Vec :: new ( ) ;
128
+ let mut i = 0 ;
129
+ while i < state. len ( ) {
130
+ if state[ i] . cur_rank != Rank :: MAX {
131
+ result. push ( state[ i] . cur_rank ) ;
132
+ } else {
133
+ result. push ( ranks[ & piece[ i..state[ i] . end ] ] ) ;
134
+ }
135
+ i = state[ i] . end ;
136
+ }
137
+ result
138
+ }
139
+
17
140
fn _byte_pair_merge ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < ( usize , Rank ) > {
18
141
// This is a vector of (start, rank).
19
142
// The rank is of the pair starting at position start.
@@ -73,21 +196,27 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
73
196
}
74
197
75
198
pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
76
- if piece. len ( ) == 1 {
199
+ let piece_len = piece. len ( ) ;
200
+
201
+ if piece_len == 1 {
77
202
return vec ! [ ranks[ piece] ] ;
78
203
}
79
- _byte_pair_merge ( ranks, piece)
80
- . windows ( 2 )
81
- . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
82
- . collect ( )
204
+ if piece_len < 100 {
205
+ return _byte_pair_merge ( ranks, piece)
206
+ . windows ( 2 )
207
+ . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
208
+ . collect ( ) ;
209
+ }
210
+ return _byte_pair_merge_large ( ranks, piece) ;
83
211
}
84
212
85
213
pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
86
214
assert ! ( piece. len( ) > 1 ) ;
87
- _byte_pair_merge ( ranks, piece)
215
+ return _byte_pair_merge ( ranks, piece)
88
216
. windows ( 2 )
89
217
. map ( |part| & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] )
90
- . collect ( )
218
+ . collect ( ) ;
219
+ // TODO: _byte_pair_merge_large
91
220
}
92
221
93
222
// Various performance notes:
@@ -521,7 +650,7 @@ impl CoreBPE {
521
650
522
651
#[ cfg( test) ]
523
652
mod tests {
524
- use fancy_regex :: Regex ;
653
+
525
654
use rustc_hash:: FxHashMap as HashMap ;
526
655
527
656
use crate :: { byte_pair_split, Rank } ;
0 commit comments