Skip to content

Commit e491a33

Browse files
Add Sequence and Bert processors (#129)
1 parent 104e8ce commit e491a33

File tree

1 file changed

+56
-4
lines changed

1 file changed

+56
-4
lines changed

Sources/Tokenizers/PostProcessor.swift

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,8 @@ enum PostProcessorType: String {
2525
case TemplateProcessing
2626
case ByteLevel
2727
case RobertaProcessing
28+
case BertProcessing
29+
case Sequence
2830
}
2931

3032
struct PostProcessorFactory {
@@ -33,10 +35,12 @@ struct PostProcessorFactory {
3335
guard let typeName = config.type?.stringValue else { return nil }
3436
let type = PostProcessorType(rawValue: typeName)
3537
switch type {
36-
case .TemplateProcessing: return TemplateProcessing(config: config)
37-
case .ByteLevel : return ByteLevelPostProcessor(config: config)
38-
case .RobertaProcessing : return RobertaProcessing(config: config)
39-
default : fatalError("Unsupported PostProcessor type: \(typeName)")
38+
case .TemplateProcessing : return TemplateProcessing(config: config)
39+
case .ByteLevel : return ByteLevelPostProcessor(config: config)
40+
case .RobertaProcessing : return RobertaProcessing(config: config)
41+
case .BertProcessing : return BertProcessing(config: config)
42+
case .Sequence : return SequenceProcessing(config: config)
43+
default : fatalError("Unsupported PostProcessor type: \(typeName)")
4044
}
4145
}
4246
}
@@ -139,3 +143,51 @@ class RobertaProcessing: PostProcessor {
139143
return text.reversed().prefix(while: { $0.isWhitespace }).count - 1
140144
}
141145
}
146+
147+
/// Post-processor that surrounds token sequences with the configured `cls` and `sep` tokens.
///
/// With special tokens enabled the output is `cls + tokens + sep`, followed by
/// `tokensPair + sep` when a non-empty pair is supplied.
class BertProcessing: PostProcessor {
    // Only the `String` element (`.1`) of each tuple is used in this class;
    // the `UInt` element is presumably the token id — TODO confirm against `tokenValue`.
    private let sep: (UInt, String)
    private let cls: (UInt, String)

    /// Reads the `sep` and `cls` entries from `config`.
    ///
    /// Traps with `fatalError` when either entry has no `tokenValue`; `sep` is checked first.
    required public init(config: Config) {
        guard let separator = config.sep?.tokenValue else {
            fatalError("Missing `sep` processor configuration")
        }
        guard let classifier = config.cls?.tokenValue else {
            fatalError("Missing `cls` processor configuration")
        }
        sep = separator
        cls = classifier
    }

    /// Combines `tokens` and the optional `tokensPair` into one token list.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the second sequence, if any.
    ///   - addSpecialTokens: When `false`, the sequences are concatenated with nothing inserted.
    /// - Returns: The combined token list.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
        let pair = tokensPair ?? []
        if !addSpecialTokens {
            return tokens + pair
        }

        let wrappedFirst = [cls.1] + tokens + [sep.1]

        // A missing or empty pair contributes nothing, not even a trailing separator.
        guard !pair.isEmpty else { return wrappedFirst }
        return wrappedFirst + pair + [sep.1]
    }
}
169+
170+
/// Post-processor that runs a chain of post-processors, feeding each one's output into the next.
class SequenceProcessing: PostProcessor {
    /// Processors applied in the order they appear in the configuration.
    private let processors: [PostProcessor]

    /// Builds the chain from the `processors` entry of `config`.
    ///
    /// Entries for which `PostProcessorFactory.fromConfig(config:)` returns `nil` are
    /// left out of the chain. Traps with `fatalError` when `config.processors?.arrayValue` is `nil`.
    required public init(config: Config) {
        guard let processorConfigs = config.processors?.arrayValue else {
            fatalError("Missing `processors` configuration")
        }

        self.processors = processorConfigs.compactMap { PostProcessorFactory.fromConfig(config: $0) }
    }

    /// Applies every processor in order and returns the final token list.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the second sequence, if any. Only the first processor receives it.
    ///   - addSpecialTokens: Forwarded unchanged to every processor in the chain.
    /// - Returns: The output of the last processor, or `tokens` followed by `tokensPair`
    ///   when the chain is empty.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
        // With no processors there is nothing to merge the pair for us; concatenate it
        // ourselves so the second sequence is not silently dropped.
        guard let firstProcessor = processors.first else {
            return tokens + (tokensPair ?? [])
        }

        let merged = firstProcessor.postProcess(
            tokens: tokens,
            tokensPair: tokensPair,
            addSpecialTokens: addSpecialTokens
        )

        // After the first processor the pair has been folded into a single list,
        // so the remaining processors receive no separate pair.
        return processors.dropFirst().reduce(merged) { current, processor in
            processor.postProcess(tokens: current, tokensPair: nil, addSpecialTokens: addSpecialTokens)
        }
    }
}

0 commit comments

Comments
 (0)