Skip to content

Commit 9da338e

Browse files
committed
Build added tokens split regexp, shortcut before pre-tokenization
1 parent 92611e4 commit 9da338e

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,7 @@ public class PreTrainedTokenizer: Tokenizer {
141141

142142
private let addedTokens: Set<String>
143143
private let specialTokens: [String: Int]
144+
private let addedTokensRegex: NSRegularExpression?
144145

145146
private let preTokenizer: PreTokenizer?
146147
private let normalizer: Normalizer?
@@ -161,6 +162,16 @@ public class PreTrainedTokenizer: Tokenizer {
161162
specialTokens[content] = id
162163
}
163164
}
165+
166+
let addedTokensRegexString = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
167+
guard let content = addedToken.content?.stringValue else { return nil }
168+
let prefix = (addedToken.lstrip?.boolValue ?? false ? #"\s*"# : "")
169+
let suffix = (addedToken.rstrip?.boolValue ?? false ? #"\s*"# : "")
170+
let token = NSRegularExpression.escapedPattern(for: content)
171+
return "\(prefix)(\(token))\(suffix)"
172+
}.joined(separator: "|")
173+
addedTokensRegex = try? NSRegularExpression(pattern: addedTokensRegexString, options: [])
174+
164175
// TODO: specialTokens are stored but never used
165176
self.specialTokens = specialTokens
166177
self.addedTokens = Set(addedTokens.keys)
@@ -211,7 +222,17 @@ public class PreTrainedTokenizer: Tokenizer {
211222
}
212223

213224
public func tokenize(text: String) -> [String] {
214-
preTokenize(normalize(text)).flatMap { model($0) }
225+
// Take care of special tokens first
226+
let sections: [String]
227+
if let regex = self.addedTokensRegex {
228+
sections = text.split(by: regex)
229+
} else {
230+
sections = [text]
231+
}
232+
return sections.map { x in
233+
if addedTokens.contains(x) { return [x] }
234+
return preTokenize(normalize(x)).flatMap { model($0) }
235+
}.flatMap { $0 }
215236
}
216237

217238
/// Main entry point

0 commit comments

Comments
 (0)