@@ -25,6 +25,8 @@ enum PostProcessorType: String {
25
25
case TemplateProcessing
26
26
case ByteLevel
27
27
case RobertaProcessing
28
+ case BertProcessing
29
+ case Sequence
28
30
}
29
31
30
32
struct PostProcessorFactory {
@@ -33,10 +35,12 @@ struct PostProcessorFactory {
33
35
guard let typeName = config. type? . stringValue else { return nil }
34
36
let type = PostProcessorType ( rawValue: typeName)
35
37
switch type {
36
- case . TemplateProcessing: return TemplateProcessing ( config: config)
37
- case . ByteLevel : return ByteLevelPostProcessor ( config: config)
38
- case . RobertaProcessing : return RobertaProcessing ( config: config)
39
- default : fatalError ( " Unsupported PostProcessor type: \( typeName) " )
38
+ case . TemplateProcessing : return TemplateProcessing ( config: config)
39
+ case . ByteLevel : return ByteLevelPostProcessor ( config: config)
40
+ case . RobertaProcessing : return RobertaProcessing ( config: config)
41
+ case . BertProcessing : return BertProcessing ( config: config)
42
+ case . Sequence : return SequenceProcessing ( config: config)
43
+ default : fatalError ( " Unsupported PostProcessor type: \( typeName) " )
40
44
}
41
45
}
42
46
}
@@ -139,3 +143,51 @@ class RobertaProcessing: PostProcessor {
139
143
return text. reversed ( ) . prefix ( while: { $0. isWhitespace } ) . count - 1
140
144
}
141
145
}
146
+
147
+ class BertProcessing : PostProcessor {
148
+ private let sep : ( UInt , String )
149
+ private let cls : ( UInt , String )
150
+
151
+ required public init ( config: Config ) {
152
+ guard let sep = config. sep? . tokenValue else { fatalError ( " Missing `sep` processor configuration " ) }
153
+ guard let cls = config. cls? . tokenValue else { fatalError ( " Missing `cls` processor configuration " ) }
154
+ self . sep = sep
155
+ self . cls = cls
156
+ }
157
+
158
+ func postProcess( tokens: [ String ] , tokensPair: [ String ] ? , addSpecialTokens: Bool = true ) -> [ String ] {
159
+ guard addSpecialTokens else { return tokens + ( tokensPair ?? [ ] ) }
160
+
161
+ var outTokens = [ self . cls. 1 ] + tokens + [ self . sep. 1 ]
162
+ if let tokensPair = tokensPair, !tokensPair. isEmpty {
163
+ outTokens += tokensPair + [ self . sep. 1 ]
164
+ }
165
+
166
+ return outTokens
167
+ }
168
+ }
169
+
170
+ class SequenceProcessing : PostProcessor {
171
+ private let processors : [ PostProcessor ]
172
+
173
+ required public init ( config: Config ) {
174
+ guard let processorConfigs = config. processors? . arrayValue else {
175
+ fatalError ( " Missing `processors` configuration " )
176
+ }
177
+
178
+ self . processors = processorConfigs. compactMap { PostProcessorFactory . fromConfig ( config: $0) }
179
+ }
180
+
181
+ func postProcess( tokens: [ String ] , tokensPair: [ String ] ? , addSpecialTokens: Bool = true ) -> [ String ] {
182
+ var currentTokens = tokens
183
+ var currentTokensPair = tokensPair
184
+
185
+ for processor in processors {
186
+ let processed = processor. postProcess ( tokens: currentTokens, tokensPair: currentTokensPair, addSpecialTokens: addSpecialTokens)
187
+ currentTokens = processed
188
+ currentTokensPair = nil // After the first processor, we no longer have a separate pair
189
+ }
190
+
191
+ return currentTokens
192
+ }
193
+ }
0 commit comments