Skip to content

Commit 2eb5995

Browse files
committed
Update PreTokenizers so Metaspace can conditionally act
1 parent 9da338e commit 2eb5995

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 19 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -9,27 +9,26 @@ import Foundation
99
import Hub
1010

1111
public protocol PreTokenizer {
12-
func preTokenize(text: String) -> [String]
13-
func preTokenize(texts: [String]) -> [String]
14-
func callAsFunction(texts: [String]) -> [String]
15-
func callAsFunction(text: String) -> [String]
12+
func preTokenize(text: String, firstSection: Bool) -> [String]
13+
func preTokenize(texts: [String], firstSection: Bool) -> [String]
14+
func callAsFunction(texts: [String], firstSection: Bool) -> [String]
15+
func callAsFunction(text: String, firstSection: Bool) -> [String]
1616

1717
init(config: Config)
1818
}
1919

2020
extension PreTokenizer {
21-
func preTokenize(texts: [String]) -> [String] {
22-
texts.flatMap { preTokenize(text: $0) }
21+
func preTokenize(texts: [String], firstSection: Bool = true) -> [String] {
22+
texts.flatMap { preTokenize(text: $0, firstSection: firstSection) }
2323
}
2424

25-
func callAsFunction(texts: [String]) -> [String] {
26-
return preTokenize(texts: texts)
25+
func callAsFunction(texts: [String], firstSection: Bool = true) -> [String] {
26+
return preTokenize(texts: texts, firstSection: firstSection)
2727
}
2828

29-
func callAsFunction(text: String) -> [String] {
30-
return preTokenize(text: text)
29+
func callAsFunction(text: String, firstSection: Bool = true) -> [String] {
30+
return preTokenize(text: text, firstSection: firstSection)
3131
}
32-
3332
}
3433

3534
enum PreTokenizerType: String {
@@ -71,9 +70,9 @@ class PreTokenizerSequence: PreTokenizer {
7170
preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) }
7271
}
7372

74-
func preTokenize(text: String) -> [String] {
73+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
7574
preTokenizers.reduce([text]) { current, preTokenizer in
76-
preTokenizer(texts: current)
75+
preTokenizer(texts: current, firstSection: firstSection)
7776
}
7877
}
7978
}
@@ -85,7 +84,7 @@ class WhitespacePreTokenizer: PreTokenizer {
8584
re = #"\S+"#
8685
}
8786

88-
func preTokenize(text: String) -> [String] {
87+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
8988
return text.ranges(of: re).map { String(text[$0]) }
9089
}
9190
}
@@ -125,7 +124,7 @@ class MetaspacePreTokenizer: PreTokenizer {
125124

126125
// https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114
127126
// https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153
128-
func preTokenize(text: String) -> [String] {
127+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
129128
let normalized = text.replacingOccurrences(of: " ", with: stringReplacement)
130129

131130
// We add a prefix space if:
@@ -141,7 +140,7 @@ class MetaspacePreTokenizer: PreTokenizer {
141140
if prependScheme == .always {
142141
prepend = stringReplacement
143142
}
144-
if prependScheme == .first /* && first_section */ {
143+
if prependScheme == .first && firstSection {
145144
prepend = stringReplacement
146145
}
147146
}
@@ -164,7 +163,7 @@ class ByteLevelPreTokenizer: PreTokenizer {
164163
useRegex = config.useRegex?.boolValue ?? true
165164
}
166165

167-
func preTokenize(text: String) -> [String] {
166+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
168167
// Split on whitespace and punctuation
169168
let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text]
170169
return tokens.map { token in
@@ -186,7 +185,7 @@ class PunctuationPreTokenizer: PreTokenizer {
186185
re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
187186
}
188187

189-
func preTokenize(text: String) -> [String] {
188+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
190189
// Ref: https://github.com/xenova/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1138
191190
return text.ranges(of: re).map { String(text[$0]) }
192191
}
@@ -200,7 +199,7 @@ class DigitsPreTokenizer: PreTokenizer {
200199
re = "[^\\d]+|\\d\(individualDigits ? "" : "+")"
201200
}
202201

203-
func preTokenize(text: String) -> [String] {
202+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
204203
return text.ranges(of: re).map { String(text[$0]) }
205204
}
206205
}
@@ -214,7 +213,7 @@ class SplitPreTokenizer: PreTokenizer {
214213
invert = config.invert?.boolValue ?? false
215214
}
216215

217-
func preTokenize(text: String) -> [String] {
216+
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
218217
guard let pattern = pattern else { return [text] }
219218
return pattern.split(text, invert: invert)
220219
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -185,9 +185,9 @@ public class PreTrainedTokenizer: Tokenizer {
185185
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
186186
}
187187

188-
func preTokenize(_ text: String) -> [String] {
188+
func preTokenize(_ text: String, firstSection: Bool) -> [String] {
189189
guard let preTokenizer = preTokenizer else { return [text] }
190-
return preTokenizer(text: text)
190+
return preTokenizer(text: text, firstSection: firstSection)
191191
}
192192

193193
func normalize(_ text: String) -> String {
@@ -229,9 +229,9 @@ public class PreTrainedTokenizer: Tokenizer {
229229
} else {
230230
sections = [text]
231231
}
232-
return sections.map { x in
232+
return sections.enumerated().map { section, x in
233233
if addedTokens.contains(x) { return [x] }
234-
return preTokenize(normalize(x)).flatMap { model($0) }
234+
return preTokenize(normalize(x), firstSection: section == 0).flatMap { model($0) }
235235
}.flatMap { $0 }
236236
}
237237

0 commit comments

Comments
 (0)