Commit 0a606f5

Added Tokens (#93)
* Split by regexp with capture groups. The other split helpers we have don't work for capture groups; we had to resort to raw `NSRegularExpression`s.
* Build added tokens split regexp, shortcut before pre-tokenization.
* Update PreTokenizers so Metaspace can conditionally act.
* Create LlamaPreTrainedTokenizer subclass. We need some custom behaviour that's not in the config :(
* Rename test.
* Replace with enum for future extensibility.
1 parent 5e02089 commit 0a606f5
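
For a quick sense of what this commit enables, here is a usage sketch (model id and expected behaviour taken from the tests at the bottom of this diff; the call needs an async context): added tokens such as <|end|> are now matched up front, kept whole, and bypass normalization and pre-tokenization.

import Tokenizers

let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")

// "<|end|>" is an added token: it is split out before pre-tokenization and
// encoded as a single id (32007 in the Phi test's expected output).
let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")

// And it survives a decoding round-trip:
let decoded = tokenizer.decode(tokens: inputIds)
// "<s> This is the <|end|>. My only friend, the <|end|>"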

File tree

3 files changed: +200 −26 lines changed

* Sources/Tokenizers/PreTokenizer.swift
* Sources/Tokenizers/Tokenizer.swift
* AddedTokensTests.swift

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 62 additions & 21 deletions
@@ -8,28 +8,33 @@
 import Foundation
 import Hub
 
+public enum PreTokenizerOption: String {
+    case firstSection
+}
+
+public typealias PreTokenizerOptions = Set<PreTokenizerOption>
+
 public protocol PreTokenizer {
-    func preTokenize(text: String) -> [String]
-    func preTokenize(texts: [String]) -> [String]
-    func callAsFunction(texts: [String]) -> [String]
-    func callAsFunction(text: String) -> [String]
+    func preTokenize(text: String, options: PreTokenizerOptions) -> [String]
+    func preTokenize(texts: [String], options: PreTokenizerOptions) -> [String]
+    func callAsFunction(texts: [String], options: PreTokenizerOptions) -> [String]
+    func callAsFunction(text: String, options: PreTokenizerOptions) -> [String]
 
     init(config: Config)
 }
 
 extension PreTokenizer {
-    func preTokenize(texts: [String]) -> [String] {
-        texts.flatMap { preTokenize(text: $0) }
+    func preTokenize(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        texts.flatMap { preTokenize(text: $0, options: options) }
     }
 
-    func callAsFunction(texts: [String]) -> [String] {
-        return preTokenize(texts: texts)
+    func callAsFunction(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return preTokenize(texts: texts, options: options)
     }
 
-    func callAsFunction(text: String) -> [String] {
-        return preTokenize(text: text)
+    func callAsFunction(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return preTokenize(text: text, options: options)
     }
-
 }
 
 enum PreTokenizerType: String {
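
The protocol now threads a PreTokenizerOptions set through every call, with [.firstSection] as the default so existing call sites keep their behaviour. A toy conformance, written as if it lived inside the Tokenizers module (it is not part of the library, and only illustrates how the options flow):

import Foundation
import Hub

// Hypothetical example type, not in the library.
struct ShoutingPreTokenizer: PreTokenizer {
    init(config: Config) {}

    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
        // Act only on the first section, the same gate Metaspace uses below.
        guard options.contains(.firstSection) else { return [text] }
        return [text.uppercased()]
    }
}

// The texts/callAsFunction variants come for free from the protocol extension,
// e.g. shouter(texts: ["a", "b"], options: []).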
@@ -71,9 +76,9 @@ class PreTokenizerSequence: PreTokenizer {
         preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) }
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         preTokenizers.reduce([text]) { current, preTokenizer in
-            preTokenizer(texts: current)
+            preTokenizer(texts: current, options: options)
         }
     }
 }
@@ -85,7 +90,7 @@ class WhitespacePreTokenizer: PreTokenizer {
         re = #"\S+"#
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         return text.ranges(of: re).map { String(text[$0]) }
     }
 }
@@ -125,7 +130,7 @@ class MetaspacePreTokenizer: PreTokenizer {
 
     // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114
     // https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         let normalized = text.replacingOccurrences(of: " ", with: stringReplacement)
 
         // We add a prefix space if:
@@ -141,7 +146,7 @@ class MetaspacePreTokenizer: PreTokenizer {
         if prependScheme == .always {
             prepend = stringReplacement
         }
-        if prependScheme == .first /* && first_section */ {
+        if prependScheme == .first && options.contains(.firstSection) {
             prepend = stringReplacement
         }
     }
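
In the upstream tokenizers library, prepend_scheme "first" prepends the replacement character only to the first fragment of a split input; this change makes the previously stubbed-out condition real by consulting the options set. An illustrative sketch (assumed shapes, not verified library output, taking "▁" as the replacement):

// preTokenize(text: "Hello world", options: [.firstSection])
//   -> prefix prepended, roughly "▁Hello▁world" before splitting
// preTokenize(text: "Hello world", options: [])
//   -> no prefix for sections that follow an added token: "Hello▁world"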
@@ -164,7 +169,7 @@ class ByteLevelPreTokenizer: PreTokenizer {
         useRegex = config.useRegex?.boolValue ?? true
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         // Split on whitespace and punctuation
         let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text]
         return tokens.map { token in
@@ -186,7 +191,7 @@ class PunctuationPreTokenizer: PreTokenizer {
         re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         // Ref: https://github.com/xenova/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1138
         return text.ranges(of: re).map { String(text[$0]) }
     }
@@ -200,7 +205,7 @@ class DigitsPreTokenizer: PreTokenizer {
         re = "[^\\d]+|\\d\(individualDigits ? "" : "+")"
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         return text.ranges(of: re).map { String(text[$0]) }
     }
 }
@@ -214,7 +219,7 @@ class SplitPreTokenizer: PreTokenizer {
         invert = config.invert?.boolValue ?? false
     }
 
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
         guard let pattern = pattern else { return [text] }
         return pattern.split(text, invert: invert)
     }
@@ -248,7 +253,7 @@ extension StringSplitPattern {
     }
 }
 
-extension String {
+public extension String {
     func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range<Index>] {
         var result: [Range<Index>] = []
         var start = startIndex
@@ -277,6 +282,42 @@ extension String {
         return result
     }
 
+    /// This version supports capture groups, whereas the one above doesn't
+    func split(by captureRegex: NSRegularExpression) -> [String] {
+        // Find the matching capture groups
+        let selfRange = NSRange(startIndex..<endIndex, in: self)
+        let matches = captureRegex.matches(in: self, options: [], range: selfRange)
+
+        if matches.first == nil { return [self] }
+
+        var result: [String] = []
+        var start = startIndex
+        for match in matches {
+            // Append prefix before matched separator
+            let prefixEnd = index(startIndex, offsetBy: match.range.lowerBound)
+            if start < prefixEnd {
+                result.append(String(self[start..<prefixEnd]))
+            }
+            start = index(startIndex, offsetBy: match.range.upperBound)
+
+            // Append separator, supporting capture groups
+            for r in (0..<match.numberOfRanges).reversed() {
+                let matchRange = match.range(at: r)
+                if let sepRange = Range(matchRange, in: self) {
+                    result.append(String(self[sepRange]))
+                    break
+                }
+            }
+        }
+
+        // Append remaining suffix
+        let beginningOfEnd = index(startIndex, offsetBy: matches.last!.range.upperBound)
+        if beginningOfEnd < endIndex {
+            result.append(String(self[beginningOfEnd...]))
+        }
+
+        return result
+    }
 }
 
 public enum SplitDelimiterBehavior {
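
A usage note for split(by:), with expected outputs lifted from the tests at the end of this diff: because the separator is wrapped in a capture group, it is kept in the output rather than dropped, and any \s* around it is swallowed with the match.

import Foundation

let captureRegex = try! NSRegularExpression(pattern: #"(<\|end\|>)\s*|(<\|raw\|>)\s*"#, options: [])

let a = "start <|end|> for real".split(by: captureRegex)
// ["start ", "<|end|>", "for real"]

let b = "This string doesn't have those separators".split(by: captureRegex)
// ["This string doesn't have those separators"]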

Sources/Tokenizers/Tokenizer.swift

Lines changed: 71 additions & 5 deletions
@@ -141,6 +141,7 @@ public class PreTrainedTokenizer: Tokenizer {
 
     private let addedTokens: Set<String>
     private let specialTokens: [String: Int]
+    private let addedTokensRegex: NSRegularExpression?
 
     private let preTokenizer: PreTokenizer?
     private let normalizer: Normalizer?
@@ -161,6 +162,16 @@ public class PreTrainedTokenizer: Tokenizer {
                 specialTokens[content] = id
             }
         }
+
+        let addedTokensRegexString = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
+            guard let content = addedToken.content?.stringValue else { return nil }
+            let prefix = (addedToken.lstrip?.boolValue ?? false ? #"\s*"# : "")
+            let suffix = (addedToken.rstrip?.boolValue ?? false ? #"\s*"# : "")
+            let token = NSRegularExpression.escapedPattern(for: content)
+            return "\(prefix)(\(token))\(suffix)"
+        }.joined(separator: "|")
+        addedTokensRegex = try? NSRegularExpression(pattern: addedTokensRegexString, options: [])
+
         // TODO: specialTokens are stored but never used
         self.specialTokens = specialTokens
         self.addedTokens = Set(addedTokens.keys)
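
A worked example of the pattern this builds, for a hypothetical token list: an added token <|end|> with lstrip false and rstrip true contributes (<\|end\|>)\s*, and per-token patterns are OR-ed together with |.

import Foundation

let content = "<|end|>"
let token = NSRegularExpression.escapedPattern(for: content)   // <\|end\|>
let pattern = "(\(token))\\s*"                                 // lstrip: false, rstrip: true
// A second token would join as: (<\|end\|>)\s*|(<\|raw\|>)\s*  (cf. the tests below)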
@@ -174,9 +185,9 @@ public class PreTrainedTokenizer: Tokenizer {
         model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
     }
 
-    func preTokenize(_ text: String) -> [String] {
+    func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
         guard let preTokenizer = preTokenizer else { return [text] }
-        return preTokenizer(text: text)
+        return preTokenizer(text: text, options: options)
     }
 
     func normalize(_ text: String) -> String {
@@ -211,7 +222,17 @@ public class PreTrainedTokenizer: Tokenizer {
     }
 
     public func tokenize(text: String) -> [String] {
-        preTokenize(normalize(text)).flatMap { model($0) }
+        // Take care of special tokens first
+        let sections: [String]
+        if let regex = self.addedTokensRegex {
+            sections = text.split(by: regex)
+        } else {
+            sections = [text]
+        }
+        return sections.enumerated().map { section, x in
+            if addedTokens.contains(x) { return [x] }
+            return preTokenize(normalize(x), options: section == 0 ? [.firstSection] : []).flatMap { model($0) }
+        }.flatMap { $0 }
     }
 
     /// Main entry point
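
This is the shortcut before pre-tokenization mentioned in the commit message. An illustrative trace (section contents assumed, not actual model output), supposing <|end|> is an added token:

// tokenize(text: "Hi <|end|> there")
//   sections          == ["Hi ", "<|end|>", "there"]
//   section 0 "Hi "   -> normalize + preTokenize(options: [.firstSection]) -> model
//   section 1         -> added token, returned verbatim: ["<|end|>"]
//   section 2 "there" -> normalize + preTokenize(options: []) -> model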
@@ -241,9 +262,32 @@
 
 public struct AutoTokenizer {}
 
+struct PreTrainedTokenizerClasses {
+    /// Class overrides for custom behaviour
+    /// Not to be confused with the TokenizerModel classes defined in TokenizerModel
+    static let tokenizerClasses: [String : PreTrainedTokenizer.Type] = [
+        "LlamaTokenizer": LlamaPreTrainedTokenizer.self
+    ]
+}
+
 extension AutoTokenizer {
+    static func tokenizerClass(for tokenizerConfig: Config) -> PreTrainedTokenizer.Type {
+        guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else {
+            return PreTrainedTokenizer.self
+        }
+
+        // Some tokenizer_class entries use a Fast suffix
+        let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
+        if let tokenizerClass = PreTrainedTokenizerClasses.tokenizerClasses[tokenizerName] {
+            return tokenizerClass
+        }
+
+        return PreTrainedTokenizer.self
+    }
+
     public static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
-        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        let tokenizerClass = tokenizerClass(for: tokenizerConfig)
+        return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
     }
 
     public static func from(
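
So a tokenizer_class of "LlamaTokenizerFast" loses its Fast suffix, matches the "LlamaTokenizer" entry in the table above, and instantiates LlamaPreTrainedTokenizer; any name without an override falls back to PreTrainedTokenizer. A sketch (assuming Config can be built from a plain dictionary, as the Llama subclass below does):

let llamaConfig = Config(["tokenizer_class": "LlamaTokenizerFast"])
let otherConfig = Config(["tokenizer_class": "GPT2Tokenizer"])

AutoTokenizer.tokenizerClass(for: llamaConfig)   // LlamaPreTrainedTokenizer.self
AutoTokenizer.tokenizerClass(for: otherConfig)   // PreTrainedTokenizer.self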
@@ -254,7 +298,7 @@ extension AutoTokenizer {
         guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
         let tokenizerData = try await config.tokenizerData
 
-        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
     }
 
     public static func from(
@@ -281,3 +325,25 @@ class CodeLlamaTokenizer: BPETokenizer {}
 class CohereTokenizer : BPETokenizer {}
 
 class T5Tokenizer : UnigramTokenizer {}
+
+
+// MARK: - PreTrainedTokenizer classes
+
+let sentencePieceUnderline = "▁"
+
+// See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions
+class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
+    let isLegacy: Bool
+
+    required init(tokenizerConfig: Config, tokenizerData: Config) throws {
+        isLegacy = tokenizerConfig.legacy?.boolValue ?? true
+        var configDictionary = tokenizerData.dictionary
+        if !isLegacy {
+            configDictionary.removeValue(forKey: "normalizer")
+            configDictionary["pre_tokenizer"] = ["type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, "prepend_scheme": "first"]
+        }
+        let updatedData = Config(configDictionary)
+
+        try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
+    }
+}
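
The legacy handling mirrors the transformers.js exceptions linked above: when the tokenizer config carries "legacy": false, the normalizer is dropped and pre-tokenization is forced to Metaspace with prepend_scheme "first", the exact scheme the new .firstSection option exists for. A sketch of the config shape that triggers it (hypothetical excerpt):

// tokenizer_config.json:  { "tokenizer_class": "LlamaTokenizerFast", "legacy": false }
//
// With legacy == false, the init above rewrites the tokenizer data so that:
//   pre_tokenizer == ["type": "Metaspace", "replacement": "▁",
//                     "add_prefix_space": true, "prepend_scheme": "first"]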
AddedTokensTests.swift

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+//
+//  AddedTokensTests.swift
+//
+//
+//  Created by Pedro Cuenca on 20240426.
+//
+
+import XCTest
+import Tokenizers
+import Hub
+
+class AddedTokensTests: XCTestCase {
+    func testPhiAddedTokens() async throws {
+        let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")
+        let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
+        XCTAssertEqual(inputIds, [1, 910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])
+
+        let decoded = tokenizer.decode(tokens: inputIds)
+        XCTAssertEqual(decoded, "<s> This is the <|end|>. My only friend, the <|end|>")
+    }
+
+    func testSplitWithCaptureGroups() {
+        let addedTokensRegexp = #"(<\|end\|>)\s*|(<\|raw\|>)\s*"#
+        let captureRegex = try! NSRegularExpression(pattern: addedTokensRegexp, options: [])
+
+        XCTAssertEqual(
+            "eating <|raw|> meat <|end|> That's all".split(by: captureRegex),
+            ["eating ", "<|raw|>", "meat ", "<|end|>", "That's all"]
+        )
+
+        XCTAssertEqual(
+            "<|raw|>".split(by: captureRegex),
+            ["<|raw|>"]
+        )
+
+        XCTAssertEqual(
+            "This string doesn't have those separators".split(by: captureRegex),
+            ["This string doesn't have those separators"]
+        )
+
+        XCTAssertEqual(
+            "start <|end|>".split(by: captureRegex),
+            ["start ", "<|end|>"]
+        )
+
+        XCTAssertEqual(
+            "start <|end|> ".split(by: captureRegex),
+            ["start ", "<|end|>"]
+        )
+
+        XCTAssertEqual(
+            "start <|end|> ".split(by: captureRegex),
+            ["start ", "<|end|>"]
+        )
+
+        XCTAssertEqual(
+            "start <|end|> for real".split(by: captureRegex),
+            ["start ", "<|end|>", "for real"]
+        )
+
+        XCTAssertEqual(
+            "<|raw|><|end|>".split(by: captureRegex),
+            ["<|raw|>", "<|end|>"]
+        )
+    }
+}
