
Commit 5d89b5d

pcuenca and johnmai-dev authored
Chat templates by @maiqingqiang (#104)
* add jinja package
* support chat template
* Support `addSpecialTokens`
* Remove padding for now. We need to get back to this to support it consistently.

Co-authored-by: John Mai <maiqingqiang@foxmail.com>
1 parent c088078 commit 5d89b5d
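
As a quick orientation, the snippet below sketches how the API added in this commit might be driven from client code. It is only a sketch: `AutoTokenizer.from(pretrained:)` and the model id are assumptions, not part of this diff; `applyChatTemplate(messages:)` and `decode(tokens:)` are the pieces defined in the Tokenizers sources changed here.

import Tokenizers

// Hypothetical usage sketch; assumes AutoTokenizer.from(pretrained:) exists in the
// Tokenizers module and that the chosen model ships a chat_template in its
// tokenizer_config.json.
func chatTemplateDemo() async throws {
    let tokenizer = try await AutoTokenizer.from(pretrained: "some-org/some-chat-model") // placeholder id

    let messages: [[String: String]] = [
        ["role": "user", "content": "Describe the Swift programming language."]
    ]

    // New in this commit: renders the Jinja chat template, then encodes the result
    // without re-adding special tokens.
    let inputIds = try tokenizer.applyChatTemplate(messages: messages)

    // Round-trip to inspect the rendered prompt.
    print(tokenizer.decode(tokens: inputIds))
}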

File tree

5 files changed: +92 −20 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ DerivedData/
 .swiftpm/config/registries.json
 .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
 .netrc
+.idea

Package.swift

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,8 @@ let package = Package(
         .executable(name: "hub-cli", targets: ["HubCLI"]),
     ],
     dependencies: [
-        .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0")
+        .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
+        .package(url: "https://github.com/maiqingqiang/Jinja", branch: "main")
     ],
     targets: [
         .executableTarget(
@@ -22,7 +23,7 @@ let package = Package(
             .product(name: "ArgumentParser", package: "swift-argument-parser")]),
         .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
         .target(name: "Hub", resources: [.process("FallbackConfigs")]),
-        .target(name: "Tokenizers", dependencies: ["Hub"]),
+        .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
         .target(name: "TensorUtils"),
         .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
         .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),

Sources/Tokenizers/BPETokenizer.swift

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
             self.unknownToken = nil
             self.unknownTokenId = nil
         }
-        
+
         eosToken = tokenizerConfig.eosToken?.stringValue
         eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]
 

Sources/Tokenizers/PostProcessor.swift

Lines changed: 12 additions & 10 deletions
@@ -9,15 +9,15 @@ import Foundation
 import Hub
 
 public protocol PostProcessor {
-    func postProcess(tokens: [String], tokensPair: [String]?) -> [String]
-    func callAsFunction(tokens: [String], tokensPair: [String]?) -> [String]
-    
+    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]
+    func callAsFunction(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]
+
     init(config: Config)
 }
 
 extension PostProcessor {
-    func callAsFunction(tokens: [String], tokensPair: [String]? = nil) -> [String] {
-        return postProcess(tokens: tokens, tokensPair: tokensPair)
+    func callAsFunction(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
+        return postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens)
     }
 }
 
@@ -53,13 +53,15 @@ class TemplateProcessing: PostProcessor {
         self.pair = pair
     }
 
-    func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] {
+    func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
         let config = tokensPair == nil ? single : pair
-        
+
         var toReturn: [String] = []
         for item in config {
             if let specialToken = item.SpecialToken {
-                toReturn.append(specialToken.id!.stringValue!)
+                if addSpecialTokens {
+                    toReturn.append(specialToken.id!.stringValue!)
+                }
             } else if let sequence = item.Sequence {
                 if sequence.id?.stringValue == "A" {
                     toReturn += tokens
@@ -74,7 +76,7 @@ class TemplateProcessing: PostProcessor {
 
 class ByteLevelPostProcessor: PostProcessor {
     required public init(config: Config) {}
-    func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { tokens }
+    func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { tokens }
 }
 
 class RobertaProcessing: PostProcessor {
@@ -94,7 +96,7 @@ class RobertaProcessing: PostProcessor {
         self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true
     }
 
-    func postProcess(tokens: [String], tokensPair: [String]?) -> [String] {
+    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
         var outTokens = tokens
         var tokensPair = tokensPair
         if trimOffset {
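
From a caller's perspective, the new `addSpecialTokens` flag controls whether the post-processor's template tokens are emitted at all. A minimal sketch, assuming `tokenizer` is an already-loaded `Tokenizer` whose post-processor is a `TemplateProcessing` (the concrete special tokens depend on the model):

// Default behaviour is unchanged: special tokens from the template are kept.
let withSpecials = tokenizer("Hello world")

// New: skip the template's special tokens entirely.
let withoutSpecials = tokenizer("Hello world", addSpecialTokens: false)

// applyChatTemplate (added in Tokenizer.swift below) relies on the second form,
// since the rendered chat template already spells out any special tokens it needs.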

Sources/Tokenizers/Tokenizer.swift

Lines changed: 75 additions & 7 deletions
@@ -7,6 +7,7 @@
 
 import Hub
 import Foundation
+import Jinja
 
 enum TokenizerError : Error {
     case missingConfig
@@ -98,7 +99,8 @@ public protocol Tokenizer {
 
     /// Main entry point
     func encode(text: String) -> [Int]
-    func callAsFunction(_ text: String) -> [Int]
+    func encode(text: String, addSpecialTokens: Bool) -> [Int]
+    func callAsFunction(_ text: String, addSpecialTokens: Bool) -> [Int]
 
     /// Decode
     func decode(tokens: [Int]) -> String
@@ -115,11 +117,21 @@ public protocol Tokenizer {
     var eosTokenId: Int? { get }
     var unknownToken: String? { get }
     var unknownTokenId: Int? { get }
+
+    func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
+
+    func applyChatTemplate(
+        messages: [[String: String]],
+        chatTemplate: String?,
+        addGenerationPrompt: Bool,
+        truncation: Bool,
+        maxLength: Int?
+    ) throws -> [Int]
 }
 
 public extension Tokenizer {
-    func callAsFunction(_ text: String) -> [Int] {
-        encode(text: text)
+    func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
+        encode(text: text, addSpecialTokens: addSpecialTokens)
     }
 
     func convertTokensToIds(_ tokens: [String]) -> [Int?] {
@@ -131,6 +143,17 @@ public extension Tokenizer {
     }
 }
 
+let specialTokenAttributes: [String] = [
+    "bos_token",
+    "eos_token",
+    "unk_token",
+    "sep_token",
+    "pad_token",
+    "cls_token",
+    "mask_token",
+    "additional_special_tokens"
+]
+
 public class PreTrainedTokenizer: Tokenizer {
     let model: TokenizingModel
 
@@ -150,8 +173,11 @@ public class PreTrainedTokenizer: Tokenizer {
     private let normalizer: Normalizer?
     private let postProcessor: PostProcessor?
     private let decoder: Decoder?
+    private let tokenizerConfig: Config
 
     private let cleanUpTokenizationSpaces: Bool
+
+    private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 
     required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
         var addedTokens: [String : Int] = [:]
@@ -195,7 +221,8 @@ public class PreTrainedTokenizer: Tokenizer {
         self.postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor)
         self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
         self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true
-
+        self.tokenizerConfig = tokenizerConfig
+
         model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
     }
 
@@ -209,9 +236,9 @@ public class PreTrainedTokenizer: Tokenizer {
         return normalizer(text: text)
     }
 
-    func postProcess(_ tokens: [String]) -> [String] {
+    func postProcess(_ tokens: [String], addSpecialTokens: Bool = true) -> [String] {
         guard let postProcessor = postProcessor else { return tokens }
-        return postProcessor(tokens: tokens)
+        return postProcessor(tokens: tokens, addSpecialTokens: addSpecialTokens)
     }
 
     func decodeTokens(_ tokens: [String]) -> [String] {
@@ -265,8 +292,12 @@ public class PreTrainedTokenizer: Tokenizer {
     }
 
     /// Main entry point
+    public func encode(text: String, addSpecialTokens: Bool = true) -> [Int] {
+        return postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! }
+    }
+
     public func encode(text: String) -> [Int] {
-        return postProcess(tokenize(text: text)).map { model.convertTokenToId($0)! }
+        return encode(text: text, addSpecialTokens: true)
     }
 
     /// Decode
@@ -285,6 +316,43 @@ public class PreTrainedTokenizer: Tokenizer {
     public func convertIdToToken(_ id: Int) -> String? {
         model.convertIdToToken(id)
     }
+
+    public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
+        try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
+    }
+
+    public func applyChatTemplate(
+        messages: [[String: String]],
+        chatTemplate: String?,
+        addGenerationPrompt: Bool = false,
+        truncation: Bool = false,
+        maxLength: Int?
+    ) throws -> [Int] {
+        let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
+        var context: [String: Any] = [
+            "messages": messages,
+            "add_generation_prompt": addGenerationPrompt
+        ]
+
+        // TODO: maybe keep NSString here
+        for (key, value) in tokenizerConfig.dictionary as [String : Any] {
+            if specialTokenAttributes.contains(key), !(value is NSNull) {
+                context[key] = value
+            }
+        }
+
+        let rendered = try template.render(context)
+        var encodedTokens = encode(text: rendered, addSpecialTokens: false)
+        var maxLength = maxLength ?? encodedTokens.count
+        maxLength = min(maxLength, tokenizerConfig.modelMaxLength?.intValue ?? maxLength)
+        if encodedTokens.count > maxLength {
+            if truncation {
+                encodedTokens = Array(encodedTokens.prefix(maxLength))
+            }
+        }
+
+        return encodedTokens
+    }
 }
 
 // MARK: - Building
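
For completeness, the long-form overload exposes the template, generation prompt, truncation and length cap directly. A hedged usage sketch: `tokenizer` is assumed to be an already-loaded `Tokenizer`, and the template string is the same ChatML-style default defined above rather than anything model-specific.

let chatML = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

let promptIds = try tokenizer.applyChatTemplate(
    messages: [["role": "user", "content": "Hi there"]],
    chatTemplate: chatML,      // overrides chat_template from tokenizer_config and the built-in default
    addGenerationPrompt: true, // appends the assistant prefix so generation can continue from it
    truncation: true,          // only trims when the encoded prompt exceeds the effective maxLength
    maxLength: 512             // further clamped by model_max_length from the tokenizer config, if present
)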

0 commit comments
