
Commit 2c68d53

Added support for Bert models (#137)
Co-authored-by: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
1 parent 4d25d20 commit 2c68d53

5 files changed: +79 additions, -4 deletions

Sources/Tokenizers/Decoder.swift

Lines changed: 37 additions & 1 deletion
@@ -23,7 +23,7 @@ extension Decoder {
 
 enum DecoderType: String {
     case Sequence
-    // case WordPiece
+    case WordPiece
     case ByteLevel
     case Replace
     case ByteFallback
@@ -47,11 +47,47 @@ struct DecoderFactory {
         case .Fuse      : return FuseDecoder(config: config)
         case .Strip     : return StripDecoder(config: config)
         case .Metaspace : return MetaspaceDecoder(config: config)
+        case .WordPiece : return WordPieceDecoder(config: config)
         default         : fatalError("Unsupported Decoder type: \(typeName)")
         }
     }
 }
 
+class WordPieceDecoder: Decoder {
+    let prefix: String
+    let cleanup: Bool
+
+    required public init(config: Config) {
+        guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
+        self.prefix = prefix
+        self.cleanup = config.cleanup?.boolValue ?? false
+    }
+
+    func decode(tokens: [String]) -> [String] {
+        var newTokens = [String]()
+        newTokens.reserveCapacity(tokens.count)
+        for (index, token) in tokens.enumerated() {
+            var decodedToken = token
+            if index != 0 {
+                if decodedToken.hasPrefix(prefix) {
+                    decodedToken = String(decodedToken.dropFirst(prefix.count))
+                } else {
+                    decodedToken = " \(decodedToken)"
+                }
+            }
+            if cleanup {
+                decodedToken = cleanUpTokenization(decodedToken)
+            }
+            newTokens.append(decodedToken)
+        }
+        return newTokens
+    }
+
+    private func cleanUpTokenization(_ token: String) -> String {
+        return token.trimmingCharacters(in: .whitespacesAndNewlines)
+    }
+}
+
 class DecoderSequence: Decoder {
     let decoders: [Decoder]
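
For orientation, a minimal usage sketch of the new WordPieceDecoder (not part of the commit). It assumes Config can be built from a plain dictionary, as the tests in this commit do with Config([:]), and uses the conventional "##" continuation prefix; cleanup is left at its false default.

    // Hypothetical configuration: "##" is the usual WordPiece continuation prefix.
    let decoder = WordPieceDecoder(config: Config(["prefix": "##"]))

    // Tokens carrying the prefix are glued onto the previous token; every other
    // token after the first gets a leading space.
    let text = decoder.decode(tokens: ["hug", "##ging", "face", "is", "fun", "!"]).joined()
    print(text)  // "hugging face is fun !"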

Sources/Tokenizers/Normalizer.swift

Lines changed: 2 additions & 1 deletion
@@ -31,6 +31,7 @@ enum NormalizerType: String {
     case NFKD
     case NFKC
     case Bert
+    case BertNormalizer
     case Precompiled
     case StripAccents
     case Strip
@@ -51,7 +52,7 @@ struct NormalizerFactory {
         case .NFC: return NFCNormalizer(config: config)
         case .NFKD: return NFKDNormalizer(config: config)
         case .NFKC: return NFKCNormalizer(config: config)
-        case .Bert: return BertNormalizer(config: config)
+        case .Bert, .BertNormalizer: return BertNormalizer(config: config)
         case .Precompiled: return PrecompiledNormalizer(config: config)
         case .StripAccents: return StripAccentsNormalizer(config: config)
         case .Strip: return StripNormalizer(config: config)

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 16 additions & 2 deletions
@@ -46,6 +46,7 @@ enum PreTokenizerType: String {
     case Whitespace
     case WhitespaceSplit
     case Metaspace
+    case BertPreTokenizer
     // Several more to be supported
     case Unknown = ""
 }
@@ -63,11 +64,25 @@ struct PreTokenizerFactory {
         case .Split: return SplitPreTokenizer(config: config)
         case .Whitespace, .WhitespaceSplit: return WhitespacePreTokenizer(config: config)
         case .Metaspace: return MetaspacePreTokenizer(config: config)
+        case .BertPreTokenizer: return BertPreTokenizer(config: config)
         default: fatalError("Unsupported PreTokenizer type: \(typeName)")
         }
     }
 }
 
+class BertPreTokenizer: PreTokenizer {
+    let re: String
+
+    required init(config: Config) {
+        // Ref: https://github.com/huggingface/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1002
+        re = "[^\\s\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]"
+    }
+
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return text.ranges(of: re).map { String(text[$0]) }
+    }
+}
+
 class PreTokenizerSequence: PreTokenizer {
     let preTokenizers: [PreTokenizer]
 
@@ -184,11 +199,10 @@ class ByteLevelPreTokenizer: PreTokenizer {
 }
 
 class PunctuationPreTokenizer: PreTokenizer {
-    let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
     let re: String
 
     required init(config: Config) {
-        re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
+        re = "[^\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]+"
     }
 
     func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
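
A reading of the diff above, with a small sketch that is not part of the commit: BertPreTokenizer and PunctuationPreTokenizer now share Constants.PUNCTUATION_REGEX but wrap it differently. BertPreTokenizer excludes whitespace in its first alternative and matches punctuation one character at a time ("[...]" with no "+"), while PunctuationPreTokenizer keeps whitespace inside its non-punctuation chunks and groups punctuation runs ("[...]+").

    // Hedged sketch using the two pre-tokenizers shown above.
    let bert = BertPreTokenizer(config: Config([:]))
    let punct = PunctuationPreTokenizer(config: Config([:]))

    // Unquantified punctuation class: each "?" / "!" becomes its own token.
    print(bert.preTokenize(text: "you?!?"))   // ["you", "?", "!", "?"]

    // "+"-quantified punctuation class: the run "?!?" stays together.
    print(punct.preTokenize(text: "you?!?"))  // ["you", "?!?"]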

Sources/Tokenizers/Utils.swift

Lines changed: 4 additions & 0 deletions
@@ -77,3 +77,7 @@ struct Utils {
     }
 }
 
+enum Constants {
+    static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
+}
+
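
Note that PUNCTUATION_REGEX is a character-class body rather than a complete pattern; callers wrap it in "[...]" themselves, as both pre-tokenizers above do. A short module-internal sketch (not part of the commit), reusing the same String.ranges(of:) helper that BertPreTokenizer calls with a regex pattern:

    // Match single punctuation characters with the shared character-class body.
    let pattern = "[\(Constants.PUNCTUATION_REGEX)]"
    let sample = "what's up?!"
    print(sample.ranges(of: pattern).map { String(sample[$0]) })  // ["'", "?", "!"]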

Tests/PreTokenizerTests/PreTokenizerTests.swift

Lines changed: 20 additions & 0 deletions
@@ -171,4 +171,24 @@ class PreTokenizerTests: XCTestCase {
             ["▁Hey", "▁my", "▁friend", "", "▁<s>", "▁how", "▁are", "▁you"]
         )
     }
+
+    func testBertPreTokenizer() {
+        let preTokenizer1 = BertPreTokenizer(config: Config([:]))
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: "Hey friend!"),
+            ["Hey", "friend", "!"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: "Hey friend! How are you?!?"),
+            ["Hey", "friend", "!", "How", "are", "you", "?", "!", "?"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: " Hey, friend , what's up? "),
+            ["Hey", ",", "friend", ",", "what", "\'", "s", "up", "?"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: " Hey, friend , 0 99 what's up? "),
+            ["Hey", ",", "friend", ",", "0", "99", "what", "\'", "s", "up", "?"]
+        )
+    }
 }
