Skip to content

Commit c5b0160

Browse files
committed
Replace with enum for future extensibility
1 parent 750e2e9 commit c5b0160

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,32 @@
88
import Foundation
99
import Hub
1010

/// A named option that can be passed through the pre-tokenization pipeline.
/// Backed by `String` raw values so additional options can be introduced
/// later without breaking existing call sites.
public enum PreTokenizerOption: String {
    case firstSection
}

/// The set of options in effect for a pre-tokenization pass.
public typealias PreTokenizerOptions = Set<PreTokenizerOption>
/// A stage that splits raw text into smaller pieces ahead of tokenization.
///
/// Conformers only need to provide `preTokenize(text:options:)`; the
/// list-based and `callAsFunction` variants have default implementations
/// supplied by a protocol extension.
public protocol PreTokenizer {
    func preTokenize(text: String, options: PreTokenizerOptions) -> [String]
    func preTokenize(texts: [String], options: PreTokenizerOptions) -> [String]
    func callAsFunction(texts: [String], options: PreTokenizerOptions) -> [String]
    func callAsFunction(text: String, options: PreTokenizerOptions) -> [String]

    init(config: Config)
}
1925

2026
extension PreTokenizer {
21-
func preTokenize(texts: [String], firstSection: Bool = true) -> [String] {
22-
texts.flatMap { preTokenize(text: $0, firstSection: firstSection) }
27+
func preTokenize(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
28+
texts.flatMap { preTokenize(text: $0, options: options) }
2329
}
2430

25-
func callAsFunction(texts: [String], firstSection: Bool = true) -> [String] {
26-
return preTokenize(texts: texts, firstSection: firstSection)
31+
func callAsFunction(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
32+
return preTokenize(texts: texts, options: options)
2733
}
2834

29-
func callAsFunction(text: String, firstSection: Bool = true) -> [String] {
30-
return preTokenize(text: text, firstSection: firstSection)
35+
func callAsFunction(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
36+
return preTokenize(text: text, options: options)
3137
}
3238
}
3339

@@ -70,9 +76,9 @@ class PreTokenizerSequence: PreTokenizer {
7076
preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) }
7177
}
7278

73-
/// Runs each pre-tokenizer in order, feeding every stage's output pieces
/// into the next stage. The pipeline starts from the whole input text as
/// a single piece.
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
    var pieces = [text]
    for preTokenizer in preTokenizers {
        pieces = preTokenizer(texts: pieces, options: options)
    }
    return pieces
}
7884
}
@@ -84,7 +90,7 @@ class WhitespacePreTokenizer: PreTokenizer {
8490
re = #"\S+"#
8591
}
8692

87-
/// Splits the text into runs of non-whitespace characters (per `re`),
/// discarding the whitespace itself. `options` is accepted for protocol
/// conformance but not consulted here.
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
    text.ranges(of: re).map { String(text[$0]) }
}
9096
}
@@ -124,7 +130,7 @@ class MetaspacePreTokenizer: PreTokenizer {
124130

125131
// https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114
126132
// https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153
127-
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
133+
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
128134
let normalized = text.replacingOccurrences(of: " ", with: stringReplacement)
129135

130136
// We add a prefix space if:
@@ -140,7 +146,7 @@ class MetaspacePreTokenizer: PreTokenizer {
140146
if prependScheme == .always {
141147
prepend = stringReplacement
142148
}
143-
if prependScheme == .first && firstSection {
149+
if prependScheme == .first && options.contains(.firstSection) {
144150
prepend = stringReplacement
145151
}
146152
}
@@ -163,7 +169,7 @@ class ByteLevelPreTokenizer: PreTokenizer {
163169
useRegex = config.useRegex?.boolValue ?? true
164170
}
165171

166-
func preTokenize(text: String, firstSection: Bool = true) -> [String] {
172+
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
167173
// Split on whitespace and punctuation
168174
let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text]
169175
return tokens.map { token in
@@ -185,7 +191,7 @@ class PunctuationPreTokenizer: PreTokenizer {
185191
re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
186192
}
187193

188-
/// Splits the text into alternating runs of punctuation and non-punctuation
/// characters (per `re`); no characters are dropped.
/// Ref: https://github.com/xenova/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1138
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
    text.ranges(of: re).map { String(text[$0]) }
}
@@ -199,7 +205,7 @@ class DigitsPreTokenizer: PreTokenizer {
199205
re = "[^\\d]+|\\d\(individualDigits ? "" : "+")"
200206
}
201207

202-
/// Separates digit runs from non-digit runs using `re` (which encodes
/// whether digits are kept individually or as contiguous runs).
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
    text.ranges(of: re).map { String(text[$0]) }
}
205211
}
@@ -213,7 +219,7 @@ class SplitPreTokenizer: PreTokenizer {
213219
invert = config.invert?.boolValue ?? false
214220
}
215221

216-
/// Splits the text with the configured pattern (honoring `invert`).
/// When no pattern is configured, returns the text unchanged as a
/// single piece.
func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
    guard let pattern = pattern else { return [text] }
    return pattern.split(text, invert: invert)
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ public class PreTrainedTokenizer: Tokenizer {
185185
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
186186
}
187187

188-
/// Applies the configured pre-tokenizer to `text`, forwarding `options`.
/// Falls back to returning the untouched text as a single piece when no
/// pre-tokenizer is configured.
func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
    guard let preTokenizer = preTokenizer else { return [text] }
    return preTokenizer(text: text, options: options)
}
192192

193193
func normalize(_ text: String) -> String {
@@ -231,7 +231,7 @@ public class PreTrainedTokenizer: Tokenizer {
231231
}
232232
return sections.enumerated().map { section, x in
233233
if addedTokens.contains(x) { return [x] }
234-
return preTokenize(normalize(x), firstSection: section == 0).flatMap { model($0) }
234+
return preTokenize(normalize(x), options: section == 0 ? [.firstSection] : []).flatMap { model($0) }
235235
}.flatMap { $0 }
236236
}
237237

0 commit comments

Comments
 (0)