Skip to content

Commit a7a61a2

Browse files
authored
add StripNormalizer (#133)
1 parent 71963c3 commit a7a61a2

File tree

2 files changed

+102
-45
lines changed

2 files changed

+102
-45
lines changed

Sources/Tokenizers/Normalizer.swift

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//
22
// Normalizer.swift
3-
//
3+
//
44
//
55
// Created by Pedro Cuenca on 17/7/23.
66
//
@@ -11,7 +11,7 @@ import Hub
1111
/// A text normalization step: takes a string and returns a transformed string.
public protocol Normalizer {
1212
/// Returns the normalized form of `text`.
func normalize(text: String) -> String
1313
/// Call-syntax entry point, so a conforming value can be invoked as `normalizer(text: ...)`.
/// NOTE(review): presumably forwards to `normalize(text:)` through a default implementation
/// that is not visible in this view — confirm.
func callAsFunction(text: String) -> String
14-
14+
1515
/// Creates the normalizer from its configuration.
init(config: Config)
1616
}
1717

@@ -33,6 +33,7 @@ enum NormalizerType: String {
3333
case Bert
3434
case Precompiled
3535
case StripAccents
36+
case Strip
3637
case Unknown = ""
3738
}
3839

@@ -43,29 +44,32 @@ struct NormalizerFactory {
4344
let type = NormalizerType(rawValue: typeName)
4445
switch type {
4546
case .Sequence: return NormalizerSequence(config: config)
46-
case .Prepend : return PrependNormalizer(config: config)
47-
case .Replace : return ReplaceNormalizer(config: config)
48-
case .Lowercase : return LowercaseNormalizer(config: config)
49-
case .NFD : return NFDNormalizer(config: config)
50-
case .NFC : return NFCNormalizer(config: config)
51-
case .NFKD : return NFKDNormalizer(config: config)
52-
case .NFKC : return NFKCNormalizer(config: config)
53-
case .Bert : return BertNormalizer(config: config)
54-
case .Precompiled : return PrecompiledNormalizer(config: config)
55-
case .StripAccents : return StripAccentsNormalizer(config: config)
56-
default : fatalError("Unsupported Normalizer type: \(typeName)")
47+
case .Prepend: return PrependNormalizer(config: config)
48+
case .Replace: return ReplaceNormalizer(config: config)
49+
case .Lowercase: return LowercaseNormalizer(config: config)
50+
case .NFD: return NFDNormalizer(config: config)
51+
case .NFC: return NFCNormalizer(config: config)
52+
case .NFKD: return NFKDNormalizer(config: config)
53+
case .NFKC: return NFKCNormalizer(config: config)
54+
case .Bert: return BertNormalizer(config: config)
55+
case .Precompiled: return PrecompiledNormalizer(config: config)
56+
case .StripAccents: return StripAccentsNormalizer(config: config)
57+
case .Strip: return StripNormalizer(config: config)
58+
default: fatalError("Unsupported Normalizer type: \(typeName)")
5759
}
5860
}
5961
}
6062

6163
class NormalizerSequence: Normalizer {
6264
let normalizers: [Normalizer]
63-
65+
6466
required public init(config: Config) {
65-
guard let configs = config.normalizers?.arrayValue else { fatalError("No normalizers in Sequence") }
67+
guard let configs = config.normalizers?.arrayValue else {
68+
fatalError("No normalizers in Sequence")
69+
}
6670
normalizers = configs.compactMap { NormalizerFactory.fromConfig(config: $0) }
6771
}
68-
72+
6973
public func normalize(text: String) -> String {
7074
normalizers.reduce(text) { current, normalizer in
7175
normalizer(text: current)
@@ -75,23 +79,23 @@ class NormalizerSequence: Normalizer {
7579

7680
/// Normalizer that puts a fixed string in front of the input text.
class PrependNormalizer: Normalizer {
7781
/// The string placed before every normalized text.
let prepend: String
78-
82+
7983
/// Reads the `prepend` string from `config`; falls back to the empty string when it is
/// missing or has no string value, which makes `normalize(text:)` a no-op.
required public init(config: Config) {
8084
prepend = config.prepend?.stringValue ?? ""
8185
}
82-
86+
8387
/// Returns `text` with `prepend` concatenated in front of it.
public func normalize(text: String) -> String {
8488
return prepend + text
8589
}
8690
}
8791

8892
class ReplaceNormalizer: Normalizer {
8993
let pattern: StringReplacePattern?
90-
94+
9195
required public init(config: Config) {
9296
self.pattern = StringReplacePattern.from(config: config)
9397
}
94-
98+
9599
public func normalize(text: String) -> String {
96100
guard let pattern = pattern else { return text }
97101
return pattern.replace(text)
@@ -106,7 +110,7 @@ class LowercaseNormalizer: Normalizer {
106110
}
107111
}
108112

109-
class NFDNormalizer: Normalizer {
113+
class NFDNormalizer: Normalizer {
110114
required public init(config: Config) {}
111115

112116
public func normalize(text: String) -> String {
@@ -122,7 +126,7 @@ class NFCNormalizer: Normalizer {
122126
}
123127
}
124128

125-
class NFKDNormalizer: Normalizer {
129+
class NFKDNormalizer: Normalizer {
126130
required init(config: Config) {}
127131

128132
func normalize(text: String) -> String {
@@ -172,15 +176,13 @@ class BertNormalizer: Normalizer {
172176
private func cleanText(text: String) -> String {
173177
text.map { c in
174178
guard let scalar = c.unicodeScalars.first,
175-
scalar.value != 0x0,
176-
scalar.value != 0xFFFD,
177-
!isControl(scalar)
179+
scalar.value != 0x0,
180+
scalar.value != 0xFFFD,
181+
!isControl(scalar)
178182
else { return "\(c)" }
179183

180184
// Replace whitespace: \t, \n, \r
181-
if scalar.value == 0x009 ||
182-
scalar.value == 0x00A ||
183-
scalar.value == 0x000D {
185+
if scalar.value == 0x009 || scalar.value == 0x00A || scalar.value == 0x000D {
184186
return " "
185187
} else {
186188
return "\(c)"
@@ -201,29 +203,27 @@ class BertNormalizer: Normalizer {
201203
}
202204

203205
private func isOther(_ c: Unicode.GeneralCategory) -> Bool {
204-
c == .control ||
205-
c == .format ||
206-
c == .surrogate ||
207-
c == .privateUse ||
208-
c == .unassigned
206+
c == .control || c == .format || c == .surrogate || c == .privateUse || c == .unassigned
209207
}
210208

211209
private func handleChineseChars(text: String) -> String {
212210
text.map { c in
213211
if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) {
214212
" \(c) "
215213
} else {
216-
"\(c)"
214+
"\(c)"
217215
}
218216
}
219217
.joined()
220218
}
221219

222220
private func stripAccents(text: String) -> String {
223221
text.decomposedStringWithCanonicalMapping
224-
.filter { $0.unicodeScalars.allSatisfy { scalar in
225-
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
226-
}}
222+
.filter {
223+
$0.unicodeScalars.allSatisfy { scalar in
224+
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
225+
}
226+
}
227227
}
228228
}
229229

@@ -245,7 +245,8 @@ class PrecompiledNormalizer: Normalizer {
245245
case 0x0001...0x0008, 0x000B, 0x000E...0x001F, 0x007F, 0x008F, 0x009F:
246246
// Non-printing control characters
247247
output.append("")
248-
case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581, 0xFEFF, 0xFFFD:
248+
case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581,
249+
0xFEFF, 0xFFFD:
249250
// Separators
250251
output.append(" ")
251252
case 0xFF5E:
@@ -257,7 +258,8 @@ class PrecompiledNormalizer: Normalizer {
257258
}
258259

259260
if hasFullwidthTilde {
260-
return output
261+
return
262+
output
261263
.split(by: "\u{FF5E}")
262264
.map({ $0.precomposedStringWithCompatibilityMapping })
263265
.joined(separator: "\u{FF5E}")
@@ -275,6 +277,30 @@ class StripAccentsNormalizer: Normalizer {
275277
}
276278
}
277279

280+
/// Normalizer that trims whitespace from the start and/or the end of the input text.
///
/// Which sides are trimmed is read from the config keys `stripLeft` and `stripRight`;
/// a side is trimmed when its key is absent or carries no boolean value.
class StripNormalizer: Normalizer {
281+
/// Whether leading whitespace is removed.
let leftStrip: Bool
282+
/// Whether trailing whitespace is removed.
let rightStrip: Bool
283+
284+
/// Reads both strip flags from `config`, defaulting each one to `true`.
required init(config: Config) {
285+
self.leftStrip = config.stripLeft?.boolValue ?? true
286+
self.rightStrip = config.stripRight?.boolValue ?? true
287+
}
288+
289+
/// Returns `text` with whitespace removed from the enabled sides.
/// Whitespace in the interior of the string is left untouched.
func normalize(text: String) -> String {
290+
var result = text
291+
292+
if leftStrip {
293+
// Drop characters from the front for as long as they are whitespace.
result = String(result.drop(while: { $0.isWhitespace }))
294+
}
295+
296+
if rightStrip {
297+
// Reverse so the trailing whitespace becomes leading, drop it, then reverse back.
result = String(result.reversed().drop(while: { $0.isWhitespace }).reversed())
298+
}
299+
300+
return result
301+
}
302+
}
303+
278304
enum StringReplacePattern {
279305
case regexp(regexp: NSRegularExpression, replacement: String)
280306
case string(pattern: String, replacement: String)
@@ -285,7 +311,8 @@ extension StringReplacePattern {
285311
switch self {
286312
case .regexp(let regexp, let replacement):
287313
let range = NSRange(text.startIndex..., in: text)
288-
let replaced = regexp.stringByReplacingMatches(in: text, options: [], range: range, withTemplate: replacement)
314+
let replaced = regexp.stringByReplacingMatches(
315+
in: text, options: [], range: range, withTemplate: replacement)
289316
return replaced
290317
case .string(let toReplace, let replacement):
291318
return text.replacingOccurrences(of: toReplace, with: replacement)

Tests/NormalizerTests/NormalizerTests.swift

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import XCTest
2-
@testable import Tokenizers
2+
33
@testable import Hub
4+
@testable import Tokenizers
45

56
class NormalizerTests: XCTestCase {
67

@@ -22,7 +23,7 @@ class NormalizerTests: XCTestCase {
2223
let normalizer = LowercaseNormalizer(config: config)
2324
XCTAssertEqual(normalizer.normalize(text: arg), expect)
2425
}
25-
26+
2627
let config = Config(["type": NormalizerType.Lowercase.rawValue])
2728
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? LowercaseNormalizer)
2829
}
@@ -68,11 +69,11 @@ class NormalizerTests: XCTestCase {
6869
let normalizer = NFCNormalizer(config: config)
6970
XCTAssertEqual(normalizer.normalize(text: arg), expect)
7071
}
71-
72+
7273
let config = Config(["type": NormalizerType.NFC.rawValue])
7374
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFCNormalizer)
7475
}
75-
76+
7677
func testNFKDNormalizer() {
7778
let testCases: [(String, String)] = [
7879
("café", "cafe\u{301}"),
@@ -118,7 +119,7 @@ class NormalizerTests: XCTestCase {
118119
let config = Config(["type": NormalizerType.NFKC.rawValue])
119120
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFKCNormalizer)
120121
}
121-
122+
122123
func testBertNormalizer() {
123124
let testCases: [(String, String)] = [
124125
("Café", "café"),
@@ -141,6 +142,7 @@ class NormalizerTests: XCTestCase {
141142
let config = Config(["type": NormalizerType.Bert.rawValue])
142143
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? BertNormalizer)
143144
}
145+
144146
func testPrecompiledNormalizer() {
145147
let testCases: [(String, String)] = [
146148
("café", "café"),
@@ -188,4 +190,32 @@ class NormalizerTests: XCTestCase {
188190
let config = Config(["type": NormalizerType.StripAccents.rawValue])
189191
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? StripAccentsNormalizer)
190192
}
193+
194+
/// Checks `StripNormalizer` for every combination of the left/right flags, for tab/newline
/// whitespace, and for whitespace-only and empty input; then checks that the factory
/// builds a `StripNormalizer` for the `Strip` type name.
func testStripNormalizer() {
195+
// Each case is (input, expected output, stripLeft, stripRight).
let testCases: [(String, String, Bool, Bool)] = [
196+
(" hello ", "hello", true, true),
197+
(" hello ", "hello ", true, false),
198+
(" hello ", " hello", false, true),
199+
(" hello ", " hello ", false, false),
200+
("\t\nHello\t\n", "Hello", true, true),
201+
(" ", "", true, true),
202+
("", "", true, true),
203+
]
204+
205+
for (input, expected, leftStrip, rightStrip) in testCases {
206+
let config = Config([
207+
"type": NormalizerType.Strip.rawValue,
208+
"stripLeft": leftStrip,
209+
"stripRight": rightStrip,
210+
])
211+
let normalizer = StripNormalizer(config: config)
212+
XCTAssertEqual(
213+
normalizer.normalize(text: input), expected,
214+
"Failed for input: '\(input)', leftStrip: \(leftStrip), rightStrip: \(rightStrip)")
215+
}
216+
217+
// A config carrying only the type name must be resolved to StripNormalizer by the factory.
let config = Config(["type": NormalizerType.Strip.rawValue])
218+
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? StripNormalizer)
219+
}
220+
191221
}

0 commit comments

Comments
 (0)