
Commit 1a001b5

Add RobertaProcessing (#48)
* Add RobertaProcessing
* Test RobertaProcessing
* Test RobertaProcessing
* Trim spaces from tokens pair
* Document variables
* Comment on trim spaces
1 parent c754d14 commit 1a001b5

4 files changed: 151 additions & 2 deletions

Package.swift

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ let package = Package(
         .testTarget(name: "HubTests", dependencies: ["Hub"]),
         .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
         .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
-        .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"])
+        .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
+        .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
     ]
 )

Sources/Hub/Hub.swift

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,9 @@ public struct Config {
         guard let list = value as? [Any] else { return nil }
         return list.map { Config($0 as! [String : Any]) }
     }
+
+    /// Tuple of token identifier and string value
+    public var tokenValue: (UInt, String)? { value as? (UInt, String) }
 }
 
 public class LanguageModelConfigurationFromHub {
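For context, `tokenValue` gives post-processors typed access to a special-token entry stored in a `Config` as an `(id, token)` tuple. A minimal sketch of that access pattern, using the same `Config` dictionary shape as the tests below (the `"cls"` key and `<s>` string here are illustrative, not taken from the commit):

import Hub

// Hypothetical config entry carrying a class token as an (id, string) tuple.
let config = Config(["cls": (0, "<s>") as (UInt, String)])

// `tokenValue` unwraps the tuple; it is nil when the stored value has another type.
if let (id, token) = config.cls?.tokenValue {
    print(id, token)   // prints: 0 <s>
}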

Sources/Tokenizers/PostProcessor.swift

Lines changed: 63 additions & 1 deletion
@@ -24,7 +24,7 @@ extension PostProcessor {
 enum PostProcessorType: String {
     case TemplateProcessing
     case ByteLevel
-    // case RobertaProcessing
+    case RobertaProcessing
 }
 
 struct PostProcessorFactory {
@@ -35,6 +35,7 @@ struct PostProcessorFactory {
         switch type {
         case .TemplateProcessing: return TemplateProcessing(config: config)
         case .ByteLevel : return ByteLevelPostProcessor(config: config)
+        case .RobertaProcessing : return RobertaProcessing(config: config)
         default : fatalError("Unsupported PostProcessor type: \(typeName)")
         }
     }
@@ -75,3 +76,64 @@ class ByteLevelPostProcessor: PostProcessor {
     required public init(config: Config) {}
     func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { tokens }
 }
+
+class RobertaProcessing: PostProcessor {
+    private let sep: (UInt, String)
+    private let cls: (UInt, String)
+    /// Trim all remaining space, or leave one space character if `addPrefixSpace` is `true`.
+    private let trimOffset: Bool
+    /// Keep one space character on each side. Depends on `trimOffsets` being `true`.
+    private let addPrefixSpace: Bool
+
+    required public init(config: Config) {
+        guard let sep = config.sep?.tokenValue else { fatalError("Missing `sep` processor configuration") }
+        guard let cls = config.cls?.tokenValue else { fatalError("Missing `cls` processor configuration") }
+        self.sep = sep
+        self.cls = cls
+        self.trimOffset = config.trimOffset?.boolValue ?? true
+        self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true
+    }
+
+    func postProcess(tokens: [String], tokensPair: [String]?) -> [String] {
+        var outTokens = tokens
+        var tokensPair = tokensPair
+        if trimOffset {
+            if addPrefixSpace {
+                outTokens = outTokens.map({ trimExtraSpaces(token: $0) })
+                tokensPair = tokensPair?.map({ trimExtraSpaces(token: $0) })
+            } else {
+                outTokens = outTokens.map({ $0.trimmingCharacters(in: .whitespaces) })
+                tokensPair = tokensPair?.map({ $0.trimmingCharacters(in: .whitespaces) })
+            }
+        }
+
+        outTokens = [self.cls.1] + outTokens + [self.sep.1]
+        if let tokensPair = tokensPair, !tokensPair.isEmpty {
+            // Yes, it adds another `sep`.
+            // https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/roberta/hub_interface.py#L58-L65
+            outTokens += [self.sep.1] + tokensPair + [self.sep.1]
+        }
+
+        return outTokens
+    }
+
+    /// Some tokens need one space around them
+    /// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L203-L235
+    private func trimExtraSpaces(token: String) -> String {
+        let prefixOffset = findPrefixIndex(text: token)
+        let suffixOffset = findSuffixIndex(text: token)
+        let prefixIndex = token.index(token.startIndex, offsetBy: prefixOffset)
+        let suffixIndex = token.index(token.startIndex, offsetBy: token.count - suffixOffset)
+        return String(token[prefixIndex..<suffixIndex])
+    }
+
+    private func findPrefixIndex(text: String) -> Int {
+        guard !text.isEmpty, text.first!.isWhitespace else { return 0 }
+        return text.prefix(while: { $0.isWhitespace }).count - 1
+    }
+
+    private func findSuffixIndex(text: String) -> Int {
+        guard !text.isEmpty, text.last!.isWhitespace else { return 0 }
+        return text.reversed().prefix(while: { $0.isWhitespace }).count - 1
+    }
+}
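For a quick sense of what the new processor produces, here is a minimal usage sketch (not part of the commit; the `<s>`/`</s>` strings and ids are illustrative RoBERTa-style values, and `@testable` imports are needed because the type is internal, as in the tests below):

@testable import Tokenizers
@testable import Hub

// Hypothetical special tokens; real values would come from the tokenizer's config.
let config = Config(["cls": (0, "<s>") as (UInt, String),
                     "sep": (2, "</s>") as (UInt, String)])
let processor = RobertaProcessing(config: config)

// Single sequence: wrapped as <s> ... </s>; single leading spaces survive
// because `trimOffset` and `addPrefixSpace` default to true.
let single = processor.postProcess(tokens: [" Hello", " world"], tokensPair: nil)
// single == ["<s>", " Hello", " world", "</s>"]

// Sequence pair: the separator is doubled between the two segments,
// matching the fairseq RoBERTa input format linked in the code above.
let pair = processor.postProcess(tokens: [" Hello"], tokensPair: [" How", " are", " you"])
// pair == ["<s>", " Hello", "</s>", "</s>", " How", " are", " you", "</s>"]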
Tests/PostProcessorTests/PostProcessorTests.swift

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import XCTest
+@testable import Tokenizers
+@testable import Hub
+
+class PostProcessorTests: XCTestCase {
+    func testRobertaProcessing() {
+        let testCases: [(Config, [String], [String]?, [String])] = [
+            // Should keep spaces; uneven spaces; ignore `addPrefixSpace`.
+            (
+                Config(["cls": (0, "[HEAD]") as (UInt, String),
+                        "sep": (0, "[END]") as (UInt, String),
+                        "trimOffset": false,
+                        "addPrefixSpace": true,
+                       ]),
+                [" The", " sun", "sets ", " in ", " the ", "west"],
+                nil,
+                ["[HEAD]", " The", " sun", "sets ", " in ", " the ", "west", "[END]"]
+            ),
+            // Should leave only one space around each token.
+            (
+                Config(["cls": (0, "[START]") as (UInt, String),
+                        "sep": (0, "[BREAK]") as (UInt, String),
+                        "trimOffset": true,
+                        "addPrefixSpace": true,
+                       ]),
+                [" The ", " sun", "sets ", " in ", " the ", "west"],
+                nil,
+                ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"]
+            ),
+            // Should ignore empty tokens pair.
+            (
+                Config(["cls": (0, "[START]") as (UInt, String),
+                        "sep": (0, "[BREAK]") as (UInt, String),
+                        "trimOffset": true,
+                        "addPrefixSpace": true,
+                       ]),
+                [" The ", " sun", "sets ", " in ", " the ", "west"],
+                [],
+                ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"]
+            ),
+            // Should trim all whitespace.
+            (
+                Config(["cls": (0, "[CLS]") as (UInt, String),
+                        "sep": (0, "[SEP]") as (UInt, String),
+                        "trimOffset": true,
+                        "addPrefixSpace": false,
+                       ]),
+                [" The ", " sun", "sets ", " in ", " the ", "west"],
+                nil,
+                ["[CLS]", "The", "sun", "sets", "in", "the", "west", "[SEP]"]
+            ),
+            // Should add tokens.
+            (
+                Config(["cls": (0, "[CLS]") as (UInt, String),
+                        "sep": (0, "[SEP]") as (UInt, String),
+                        "trimOffset": true,
+                        "addPrefixSpace": true,
+                       ]),
+                [" The ", " sun", "sets ", " in ", " the ", "west"],
+                [".", "The", " cat ", " is ", " sitting ", " on", "the ", "mat"],
+                ["[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]",
+                 "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ",
+                 "mat", "[SEP]"]
+            ),
+            (
+                Config(["cls": (0, "[CLS]") as (UInt, String),
+                        "sep": (0, "[SEP]") as (UInt, String),
+                        "trimOffset": true,
+                        "addPrefixSpace": true,
+                       ]),
+                ["", "", ","],
+                ["", "", "!"],
+                ["[CLS]", "", "", ",", "[SEP]", "[SEP]", "", "", "!", "[SEP]"]
+            ),
+        ]
+
+        for (config, tokens, tokensPair, expect) in testCases {
+            let processor = RobertaProcessing(config: config)
+            let output = processor.postProcess(tokens: tokens, tokensPair: tokensPair)
+            XCTAssertEqual(output, expect)
+        }
+    }
+}
