@@ -13,80 +13,60 @@ fn _byte_pair_merge(
13
13
piece : & [ u8 ] ,
14
14
) -> Vec < ( usize , Rank ) > {
15
15
// This is a vector of (start, rank).
16
- // The rank is of the byte pair starting at position start.
17
- // The rank of the last item in the vector is not a valid value.
18
- let mut parts: Vec < ( usize , Rank ) > = ( 0 ..piece. len ( ) + 1 ) . map ( |i| ( i, Rank :: MAX ) ) . collect ( ) ;
16
+ // The rank is of the pair starting at position start.
17
+ let mut parts = Vec :: with_capacity ( piece. len ( ) + 1 ) ;
18
+
19
+ // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
20
+ // the way we currently do, this is equivalent. An easy way to break this would be to decouple
21
+ // merge priority from token index or to prevent specific token merges.
22
+ let mut min_rank: ( Rank , usize ) = ( Rank :: MAX , usize:: MAX ) ;
23
+ for i in 0 ..piece. len ( ) - 1 {
24
+ let rank = * ranks. get ( & piece[ i..i + 2 ] ) . unwrap_or ( & Rank :: MAX ) ;
25
+ if rank < min_rank. 0 {
26
+ min_rank = ( rank, i) ;
27
+ }
28
+ parts. push ( ( i, rank) ) ;
29
+ }
30
+ parts. push ( ( piece. len ( ) - 1 , Rank :: MAX ) ) ;
31
+ parts. push ( ( piece. len ( ) , Rank :: MAX ) ) ;
19
32
20
33
let get_rank = {
21
34
#[ inline( always) ]
22
- |parts : & Vec < ( usize , Rank ) > , start_idx : usize , skip : usize | {
23
- if ( start_idx + skip + 2 ) < parts. len ( ) {
24
- ranks
25
- . get ( & piece[ parts[ start_idx] . 0 ..parts[ start_idx + skip + 2 ] . 0 ] )
26
- . copied ( )
35
+ |parts : & Vec < ( usize , Rank ) > , i : usize | {
36
+ if ( i + 3 ) < parts. len ( ) {
37
+ // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
38
+ // parts[i + 1], see comment in the main loop.
39
+ * ranks
40
+ . get ( & piece[ parts[ i] . 0 ..parts[ i + 3 ] . 0 ] )
41
+ . unwrap_or ( & Rank :: MAX )
27
42
} else {
28
- None
43
+ Rank :: MAX
29
44
}
30
45
}
31
46
} ;
32
47
33
- // We look up the ranks once in the beginning and iteratively update
34
- // them during each merge, which reduces the number of rank lookups.
35
- for i in 0 ..parts. len ( ) - 2 {
36
- match get_rank ( & parts, i, 0 ) {
37
- Some ( rank) => {
38
- // Rank::MAX is a sentinel value and cannot be a valid rank
39
- debug_assert ! ( rank != Rank :: MAX ) ;
40
- parts[ i] . 1 = rank;
41
- }
42
- None => {
43
- continue ;
44
- }
45
- } ;
46
- }
47
-
48
48
// If you have n parts and m merges, this does O(mn) work.
49
49
// We could do something with a heap and do O(m log n) work.
50
- // It is important to consider that n is often small (<100), and as such
51
- // the cache-locality benefits outweigh the algorithmic complexity downsides
52
- // of the `parts` vector data structure above.
53
-
54
- // Note that we hash bytes, not token pairs. As long as we train BPE the way we
55
- // currently do, this is equivalent. An easy way to break this would be to decouple
56
- // merge priority from token index or to prevent specific token merges.
57
- loop {
58
- if parts. len ( ) == 1 {
59
- break ;
50
+ // n is often very small so considerations like cache-locality outweigh the algorithmic
51
+ // complexity downsides of the `parts` vector.
52
+ while min_rank. 0 != Rank :: MAX {
53
+ let i = min_rank. 1 ;
54
+ // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
55
+ // `parts.remove(i + 1)` will thrash the cache.
56
+ if i > 0 {
57
+ parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 ) ;
60
58
}
59
+ parts[ i] . 1 = get_rank ( & parts, i) ;
60
+ parts. remove ( i + 1 ) ;
61
61
62
- // Rank::MAX is a sentinel rank value allowing us to
63
- // take the min more quickly
64
- let mut min_rank: ( Rank , usize ) = ( Rank :: MAX , 0 ) ;
62
+ min_rank = ( Rank :: MAX , usize:: MAX ) ;
65
63
for ( i, & ( _, rank) ) in parts[ ..parts. len ( ) - 1 ] . iter ( ) . enumerate ( ) {
66
64
if rank < min_rank. 0 {
67
65
min_rank = ( rank, i) ;
68
66
}
69
67
}
70
-
71
- if min_rank. 0 != Rank :: MAX {
72
- let i = min_rank. 1 ;
73
-
74
- // NOTE: We are about to remove parts[i + 1]. We do not do it
75
- // yet because there are cache-locality benefits to updating
76
- // parts[i] and parts[i-1] before removing, which could thrash
77
- // the cache. Thus, we update the rank calculation by skipping over
78
- // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
79
- parts[ i] . 1 = get_rank ( & parts, i, 1 ) . unwrap_or ( Rank :: MAX ) ;
80
- if i > 0 {
81
- parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 , 1 ) . unwrap_or ( Rank :: MAX ) ;
82
- }
83
-
84
- parts. remove ( i + 1 ) ;
85
- } else {
86
- break ;
87
- }
88
68
}
89
-
69
+
90
70
parts
91
71
}
92
72
0 commit comments