Skip to content

Chat templates by @maiqingqiang #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into the base branch on Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ DerivedData/
.swiftpm/config/registries.json
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
.netrc
.idea
5 changes: 3 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ let package = Package(
.executable(name: "hub-cli", targets: ["HubCLI"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0")
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/maiqingqiang/Jinja", branch: "main")
],
targets: [
.executableTarget(
Expand All @@ -22,7 +23,7 @@ let package = Package(
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
.target(name: "Tokenizers", dependencies: ["Hub"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
Expand Down
2 changes: 1 addition & 1 deletion Sources/Tokenizers/BPETokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
self.unknownToken = nil
self.unknownTokenId = nil
}

eosToken = tokenizerConfig.eosToken?.stringValue
eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]

Expand Down
22 changes: 12 additions & 10 deletions Sources/Tokenizers/PostProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import Foundation
import Hub

public protocol PostProcessor {
func postProcess(tokens: [String], tokensPair: [String]?) -> [String]
func callAsFunction(tokens: [String], tokensPair: [String]?) -> [String]
func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]
func callAsFunction(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]

init(config: Config)
}

extension PostProcessor {
func callAsFunction(tokens: [String], tokensPair: [String]? = nil) -> [String] {
return postProcess(tokens: tokens, tokensPair: tokensPair)
func callAsFunction(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
return postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens)
}
}

Expand Down Expand Up @@ -53,13 +53,15 @@ class TemplateProcessing: PostProcessor {
self.pair = pair
}

func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] {
func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
let config = tokensPair == nil ? single : pair

var toReturn: [String] = []
for item in config {
if let specialToken = item.SpecialToken {
toReturn.append(specialToken.id!.stringValue!)
if addSpecialTokens {
toReturn.append(specialToken.id!.stringValue!)
}
} else if let sequence = item.Sequence {
if sequence.id?.stringValue == "A" {
toReturn += tokens
Expand All @@ -74,7 +76,7 @@ class TemplateProcessing: PostProcessor {

class ByteLevelPostProcessor: PostProcessor {
required public init(config: Config) {}
func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { tokens }
func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { tokens }
}

class RobertaProcessing: PostProcessor {
Expand All @@ -94,7 +96,7 @@ class RobertaProcessing: PostProcessor {
self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true
}

func postProcess(tokens: [String], tokensPair: [String]?) -> [String] {
func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
var outTokens = tokens
var tokensPair = tokensPair
if trimOffset {
Expand Down
82 changes: 75 additions & 7 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import Hub
import Foundation
import Jinja

enum TokenizerError : Error {
case missingConfig
Expand Down Expand Up @@ -98,7 +99,8 @@ public protocol Tokenizer {

/// Main entry point
func encode(text: String) -> [Int]
func callAsFunction(_ text: String) -> [Int]
func encode(text: String, addSpecialTokens: Bool) -> [Int]
func callAsFunction(_ text: String, addSpecialTokens: Bool) -> [Int]

/// Decode
func decode(tokens: [Int]) -> String
Expand All @@ -115,11 +117,21 @@ public protocol Tokenizer {
var eosTokenId: Int? { get }
var unknownToken: String? { get }
var unknownTokenId: Int? { get }

func applyChatTemplate(messages: [[String: String]]) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
) throws -> [Int]
}

public extension Tokenizer {
func callAsFunction(_ text: String) -> [Int] {
encode(text: text)
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
encode(text: text, addSpecialTokens: addSpecialTokens)
}

func convertTokensToIds(_ tokens: [String]) -> [Int?] {
Expand All @@ -131,6 +143,17 @@ public extension Tokenizer {
}
}

/// Tokenizer-config keys that name special tokens. Entries of the tokenizer
/// config whose key appears here are copied into the Jinja rendering context
/// by `applyChatTemplate`, so chat templates can reference values such as
/// `bos_token` or `eos_token` directly.
let specialTokenAttributes: [String] = [
    "bos_token",
    "eos_token",
    "unk_token",
    "sep_token",
    "pad_token",
    "cls_token",
    "mask_token",
    "additional_special_tokens"
]

public class PreTrainedTokenizer: Tokenizer {
let model: TokenizingModel

Expand All @@ -150,8 +173,11 @@ public class PreTrainedTokenizer: Tokenizer {
private let normalizer: Normalizer?
private let postProcessor: PostProcessor?
private let decoder: Decoder?
private let tokenizerConfig: Config

private let cleanUpTokenizationSpaces: Bool

private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String : Int] = [:]
Expand Down Expand Up @@ -195,7 +221,8 @@ public class PreTrainedTokenizer: Tokenizer {
self.postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor)
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true

self.tokenizerConfig = tokenizerConfig

model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

Expand All @@ -209,9 +236,9 @@ public class PreTrainedTokenizer: Tokenizer {
return normalizer(text: text)
}

func postProcess(_ tokens: [String]) -> [String] {
func postProcess(_ tokens: [String], addSpecialTokens: Bool = true) -> [String] {
guard let postProcessor = postProcessor else { return tokens }
return postProcessor(tokens: tokens)
return postProcessor(tokens: tokens, addSpecialTokens: addSpecialTokens)
}

func decodeTokens(_ tokens: [String]) -> [String] {
Expand Down Expand Up @@ -265,8 +292,12 @@ public class PreTrainedTokenizer: Tokenizer {
}

/// Main entry point
/// Encodes `text` into token ids.
///
/// - Parameters:
///   - text: Input text to tokenize.
///   - addSpecialTokens: Forwarded to the post-processor, which controls
///     whether its configured special tokens are inserted. Passed as `false`
///     by `applyChatTemplate`, which encodes an already-rendered template.
/// - Returns: Token ids for the post-processed token strings.
public func encode(text: String, addSpecialTokens: Bool = true) -> [Int] {
    // NOTE(review): the force-unwrap assumes every post-processed token maps to
    // an id in the model's vocabulary; an unknown token here would crash —
    // confirm the tokenizing model guarantees this.
    return postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! }
}

public func encode(text: String) -> [Int] {
return postProcess(tokenize(text: text)).map { model.convertTokenToId($0)! }
return encode(text: text, addSpecialTokens: true)
}

/// Decode
Expand All @@ -285,6 +316,43 @@ public class PreTrainedTokenizer: Tokenizer {
public func convertIdToToken(_ id: Int) -> String? {
model.convertIdToToken(id)
}

/// Convenience overload: renders `messages` with the tokenizer's configured
/// chat template (or the built-in default), requesting a generation prompt
/// (`add_generation_prompt = true`) and imposing no caller-side length cap.
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
    try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
}

/// Renders `messages` through a Jinja chat template and encodes the result
/// into token ids.
///
/// - Parameters:
///   - messages: Chat turns as dictionaries, exposed to the template as
///     `messages`.
///   - chatTemplate: Template source to use. When `nil`, falls back to the
///     tokenizer config's `chat_template`, then to `defaultChatTemplate`.
///   - addGenerationPrompt: Exposed to the template as
///     `add_generation_prompt`.
///   - truncation: When `true`, the encoded ids are truncated to the
///     effective maximum length; otherwise the full sequence is returned
///     even if it exceeds it.
///   - maxLength: Caller-requested cap on the number of returned ids;
///     additionally capped by the config's `model_max_length` when present.
/// - Returns: Token ids of the rendered conversation. Encoding uses
///   `addSpecialTokens: false`, so the post-processor inserts no extra
///   special tokens.
/// - Throws: Errors from Jinja template parsing or rendering.
public func applyChatTemplate(
    messages: [[String: String]],
    chatTemplate: String?,
    addGenerationPrompt: Bool = false,
    truncation: Bool = false,
    maxLength: Int?
) throws -> [Int] {
    let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
    var context: [String: Any] = [
        "messages": messages,
        "add_generation_prompt": addGenerationPrompt
    ]

    // Copy special-token entries (bos_token, eos_token, ...) from the
    // tokenizer config into the template context.
    // TODO: maybe keep NSString here
    for (key, value) in tokenizerConfig.dictionary as [String : Any] {
        if specialTokenAttributes.contains(key), !(value is NSNull) {
            context[key] = value
        }
    }

    let rendered = try template.render(context)
    var encodedTokens = encode(text: rendered, addSpecialTokens: false)
    // Effective cap: the caller's maxLength (or "no cap" = current count),
    // never exceeding the config's model_max_length when one is set.
    var maxLength = maxLength ?? encodedTokens.count
    maxLength = min(maxLength, tokenizerConfig.modelMaxLength?.intValue ?? maxLength)
    if encodedTokens.count > maxLength {
        // Only shorten the sequence when truncation was explicitly requested.
        if truncation {
            encodedTokens = Array(encodedTokens.prefix(maxLength))
        }
    }

    return encodedTokens
}
}

// MARK: - Building
Expand Down