From 9e7637ad7583f3cb3e9010d0aee5c7c5c24322f2 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Fri, 3 Jan 2025 14:24:08 +0100 Subject: [PATCH 1/3] Add swift-format config --- .swift-format | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .swift-format diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..ae77f3e --- /dev/null +++ b/.swift-format @@ -0,0 +1,9 @@ +{ + "version": 1, + "indentation": { + "spaces": 4 + }, + "lineLength": 120, + "multiElementCollectionTrailingCommas": true, + "spacesAroundRangeFormationOperators": true +} From bf03133d5a4bfdc4a6083251d6f683fdc221fdeb Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Fri, 3 Jan 2025 14:24:18 +0100 Subject: [PATCH 2/3] Add .pre-commit-config.yaml --- .pre-commit-config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..28c3445 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/slessans/pre-commit-swift-format + rev: "fd627de92bdf84a75c924ed95691336d14e94cf1" + hooks: + - id: swift-format + args: ["--configuration", ".swift-format"] From e404f264567c1cf8cda0162896ba8b098593491d Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sun, 19 Jan 2025 20:03:36 +0100 Subject: [PATCH 3/3] Format with swift-format --- Package.swift | 16 +- Sources/Generation/Generation.swift | 34 ++- Sources/Generation/GenerationConfig.swift | 26 +- Sources/Hub/Downloader.swift | 23 +- Sources/Hub/Hub.swift | 49 ++-- Sources/Hub/HubApi.swift | 235 ++++++++++-------- Sources/HubCLI/HubCLI.swift | 21 +- Sources/Models/LanguageModel.swift | 114 +++++---- Sources/Models/LanguageModelTypes.swift | 21 +- .../TemperatureLogitsWarper.swift | 2 +- .../LogitsWarper/TopKLogitsWarper.swift | 4 +- Sources/TensorUtils/MLMultiArray+Utils.swift | 80 +++--- Sources/TensorUtils/MLShapedArray+Utils.swift | 18 +- Sources/TensorUtils/Math.swift | 38 +-- Sources/TensorUtils/Weights.swift | 13 +- Sources/Tokenizers/BPETokenizer.swift | 49 ++-- Sources/Tokenizers/BertTokenizer.swift | 77 +++--- Sources/Tokenizers/ByteEncoder.swift | 2 +- Sources/Tokenizers/Decoder.swift | 91 ++++--- Sources/Tokenizers/Normalizer.swift | 11 +- Sources/Tokenizers/PostProcessor.swift | 33 +-- Sources/Tokenizers/PreTokenizer.swift | 79 +++--- Sources/Tokenizers/TokenLattice.swift | 57 +++-- Sources/Tokenizers/Tokenizer.swift | 143 ++++++----- Sources/Tokenizers/Trie.swift | 28 +-- Sources/Tokenizers/UnigramTokenizer.swift | 36 +-- Sources/Tokenizers/Utils.swift | 29 +-- Sources/TransformersCLI/main.swift | 19 +- Tests/HubTests/HubApiTests.swift | 47 ++-- Tests/HubTests/HubTests.swift | 16 +- Tests/NormalizerTests/NormalizerTests.swift | 6 +- .../PostProcessorTests.swift | 79 +++--- .../PreTokenizerTests/PreTokenizerTests.swift | 29 ++- .../TensorUtilsTests/LogitsWarperTests.swift | 14 +- Tests/TensorUtilsTests/TensorUtilsTests.swift | 7 +- Tests/TensorUtilsTests/WeightsTests.swift | 13 +- Tests/TokenizersTests/AddedTokensTests.swift | 4 +- .../TokenizersTests/BertTokenizerTests.swift | 98 ++++---- Tests/TokenizersTests/ChatTemplateTests.swift | 21 +- Tests/TokenizersTests/DecoderTests.swift | 16 +- Tests/TokenizersTests/FactoryTests.swift | 19 +- Tests/TokenizersTests/SplitTests.swift | 24 +- Tests/TokenizersTests/SquadDataset.swift | 7 +- Tests/TokenizersTests/TokenizerTests.swift | 83 ++++--- Tests/TokenizersTests/TrieTests.swift | 17 +- 45 files changed, 1008 
insertions(+), 840 deletions(-) diff --git a/Package.swift b/Package.swift index 704e5c6..9841097 100644 --- a/Package.swift +++ b/Package.swift @@ -13,24 +13,30 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"), - .package(url: "https://github.com/maiqingqiang/Jinja", from: "1.0.6") + .package(url: "https://github.com/maiqingqiang/Jinja", from: "1.0.6"), ], targets: [ .executableTarget( name: "TransformersCLI", dependencies: [ "Models", "Generation", "Tokenizers", - .product(name: "ArgumentParser", package: "swift-argument-parser")]), - .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]), + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ]), + .executableTarget( + name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]), .target(name: "Hub", resources: [.process("FallbackConfigs")]), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), .target(name: "TensorUtils"), .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]), .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]), - .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), + .testTarget( + name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], + resources: [.process("Resources"), .process("Vocabs")]), .testTarget(name: "HubTests", dependencies: ["Hub"]), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]), + .testTarget( + name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")] + ), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]), ] diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 6cfd8ab..93a21b2 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -1,13 +1,13 @@ // // Generation.swift -// +// // // Created by Pedro Cuenca on 7/5/23. // -import Tokenizers import CoreML import TensorUtils +import Tokenizers public enum GenerationMode { case contrastiveSearch @@ -29,13 +29,20 @@ public typealias PredictionStringCallback = (String) -> Void // TODO: callbacks (for streaming) public protocol Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput - - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String + func greedySearch( + config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? + ) async -> GenerationOutput + + func generate( + config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, + callback: PredictionStringCallback? + ) async -> String } -public extension Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? 
= nil) async -> GenerationOutput { +extension Generation { + public func greedySearch( + config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil + ) async -> GenerationOutput { // Iterate until we find the eos token or reach the max length // TODO: additional stopping criteria var outputTokens = tokens @@ -48,9 +55,11 @@ public extension Generation { } return outputTokens } - + /// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552 - func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { + public func sample( + config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil + ) async -> GenerationOutput { // Iterate until we find the eos token or reach the max length // TODO: additional stopping criteria var outputTokens = tokens @@ -68,7 +77,10 @@ public extension Generation { return outputTokens } - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String { + public func generate( + config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, + callback: PredictionStringCallback? = nil + ) async -> String { let tokens = tokenizer.encode(text: prompt) var generationConfig = config generationConfig.maxLength = config.maxNewTokens + tokens.count @@ -86,7 +98,7 @@ public extension Generation { default: fatalError("Generation mode \(generationConfig.generationMode) not implemented yet") } - + return tokenizer.decode(tokens: output) } diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift index a9eee7b..93c7437 100644 --- a/Sources/Generation/GenerationConfig.swift +++ b/Sources/Generation/GenerationConfig.swift @@ -1,6 +1,6 @@ // // GenerationConfig.swift -// +// // // Created by Pedro Cuenca on 7/5/23. // @@ -19,12 +19,16 @@ public struct GenerationConfig { public var topK = 50 public var topP = 1.0 public var repetitionPenalty = 1.0 - + public var padTokenId: Int? = nil public var bosTokenId: Int? = nil public var eosTokenId: Int? = nil - - public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) { + + public init( + maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, + penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, + repetitionPenalty: Double = 1.0 + ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens self.doSample = doSample @@ -38,19 +42,19 @@ public struct GenerationConfig { } } -public extension GenerationConfig { - var generationMode: GenerationMode { +extension GenerationConfig { + public var generationMode: GenerationMode { // Exclude this case from the pattern matching below if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! 
> 0 { return .contrastiveSearch } - + switch (numBeams, numBeamGroups, doSample) { - case (1, 1, false) : return .greedy - case (1, 1, true) : return .sample + case (1, 1, false): return .greedy + case (1, 1, true): return .sample case (2..., 1, false): return .beam - case (2..., 2..., _) : return .groupBeam - default : return .unsupported + case (2..., 2..., _): return .groupBeam + default: return .unsupported } } } diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 1f3b8c6..aed0abf 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -6,8 +6,8 @@ // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE // -import Foundation import Combine +import Foundation class Downloader: NSObject, ObservableObject { private(set) var destination: URL @@ -86,16 +86,16 @@ class Downloader: NSObject, ObservableObject { stateSubscriber = downloadState.sink { state in switch state { case .completed: semaphore.signal() - case .failed: semaphore.signal() - default: break + case .failed: semaphore.signal() + default: break } } semaphore.wait() switch downloadState.value { case .completed(let url): return url - case .failed(let error): throw error - default: throw DownloadError.unexpectedError + case .failed(let error): throw error + default: throw DownloadError.unexpectedError } } @@ -105,7 +105,10 @@ class Downloader: NSObject, ObservableObject { } extension Downloader: URLSessionDownloadDelegate { - func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) { + func urlSession( + _: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, + totalBytesExpectedToWrite: Int64 + ) { downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)) } @@ -122,10 +125,10 @@ extension Downloader: URLSessionDownloadDelegate { func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { if let error = error { downloadState.value = .failed(error) -// } else if let response = task.response as? HTTPURLResponse { -// print("HTTP response status code: \(response.statusCode)") -// let headers = response.allHeaderFields -// print("HTTP response headers: \(headers)") + // } else if let response = task.response as? HTTPURLResponse { + // print("HTTP response status code: \(response.statusCode)") + // let headers = response.allHeaderFields + // print("HTTP response headers: \(headers)") } } } diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 4116dcb..bb53922 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -1,6 +1,6 @@ // // Hub.swift -// +// // // Created by Pedro Cuenca on 18/5/23. // @@ -9,24 +9,24 @@ import Foundation public struct Hub {} -public extension Hub { - enum HubClientError: Error { +extension Hub { + public enum HubClientError: Error { case parse case authorizationRequired case unexpectedError case httpStatusCode(Int) } - - enum RepoType: String { + + public enum RepoType: String { case models case datasets case spaces } - - struct Repo { + + public struct Repo { public let id: String public let type: RepoType - + public init(id: String, type: RepoType = .models) { self.id = id self.type = type @@ -45,17 +45,18 @@ public struct Config { } func camelCase(_ string: String) -> String { - return string + return + string .split(separator: "_") .enumerated() .map { $0.offset == 0 ? 
$0.element.lowercased() : $0.element.capitalized } .joined() } - + func uncamelCase(_ string: String) -> String { let scalars = string.unicodeScalars var result = "" - + var previousCharacterIsLowercase = false for scalar in scalars { if CharacterSet.uppercaseLetters.contains(scalar) { @@ -70,11 +71,10 @@ public struct Config { previousCharacterIsLowercase = true } } - + return result } - public subscript(dynamicMember member: String) -> Config? { let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString if let value = dictionary[key] as? [NSString: Any] { @@ -88,17 +88,17 @@ public struct Config { public var value: Any? { return dictionary["value"] } - + public var intValue: Int? { value as? Int } public var boolValue: Bool? { value as? Bool } public var stringValue: String? { value as? String } - + // Instead of doing this we could provide custom classes and decode to them public var arrayValue: [Config]? { guard let list = value as? [Any] else { return nil } - return list.map { Config($0 as! [NSString : Any]) } + return list.map { Config($0 as! [NSString: Any]) } } - + /// Tuple of token identifier and string value public var tokenValue: (UInt, String)? { value as? (UInt, String) } } @@ -120,7 +120,7 @@ public class LanguageModelConfigurationFromHub { return try await self.loadConfig(modelName: modelName, hubApi: hubApi) } } - + public init( modelFolder: URL, hubApi: HubApi = .shared @@ -140,12 +140,13 @@ public class LanguageModelConfigurationFromHub { get async throws { if let hubConfig = try await configPromise!.value.tokenizerConfig { // Try to guess the class if it's not present and the modelType is - if let _ = hubConfig.tokenizerClass?.stringValue { return hubConfig } + if hubConfig.tokenizerClass?.stringValue != nil { return hubConfig } guard let modelType = try await modelType else { return hubConfig } // If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it if let fallbackConfig = Self.fallbackTokenizerConfig(for: modelType) { - let configuration = fallbackConfig.dictionary.merging(hubConfig.dictionary, uniquingKeysWith: { current, _ in current }) + let configuration = fallbackConfig.dictionary.merging( + hubConfig.dictionary, uniquingKeysWith: { current, _ in current }) return Config(configuration) } @@ -183,7 +184,7 @@ public class LanguageModelConfigurationFromHub { return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi) } - + func loadConfig( modelFolder: URL, hubApi: HubApi = .shared @@ -192,7 +193,7 @@ public class LanguageModelConfigurationFromHub { let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json")) let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json")) let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json")) - + let configs = Configurations( modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, @@ -202,7 +203,9 @@ public class LanguageModelConfigurationFromHub { } static func fallbackTokenizerConfig(for modelType: String) -> Config? 
{ - guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else { return nil } + guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else { + return nil + } do { let data = try Data(contentsOf: url) let parsed = try JSONSerialization.jsonObject(with: data, options: []) diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index fdf1256..0694e47 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -15,8 +15,11 @@ public struct HubApi { public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) { + + public init( + downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", + useBackgroundSession: Bool = false + ) { self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -27,12 +30,12 @@ public struct HubApi { self.endpoint = endpoint self.useBackgroundSession = useBackgroundSession } - + public static let shared = HubApi() } -private extension HubApi { - static func hfTokenFromEnv() -> String? { +extension HubApi { + fileprivate static func hfTokenFromEnv() -> String? { let possibleTokens = [ { ProcessInfo.processInfo.environment["HF_TOKEN"] }, { ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"] }, @@ -52,7 +55,7 @@ private extension HubApi { ) } }, - { try? String(contentsOf: .homeDirectory.appendingPathComponent(".huggingface/token"), encoding: .utf8) } + { try? String(contentsOf: .homeDirectory.appendingPathComponent(".huggingface/token"), encoding: .utf8) }, ] return possibleTokens .lazy @@ -63,35 +66,35 @@ private extension HubApi { } /// File retrieval -public extension HubApi { +extension HubApi { /// Model data for parsed filenames - struct Sibling: Codable { + public struct Sibling: Codable { let rfilename: String } - - struct SiblingsResponse: Codable { + + public struct SiblingsResponse: Codable { let siblings: [Sibling] } - + /// Throws error if the response code is not 20X - func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { + public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) if let hfToken = hfToken { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") } let (data, response) = try await URLSession.shared.data(for: request) guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } - + switch response.statusCode { - case 200..<300: break - case 400..<500: throw Hub.HubClientError.authorizationRequired + case 200 ..< 300: break + case 400 ..< 500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } return (data, response) } - - func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { + + public func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) request.httpMethod = "HEAD" if let hfToken = hfToken { @@ -102,54 +105,54 @@ public extension HubApi { guard let response = response as? 
HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } switch response.statusCode { - case 200..<300: break - case 400..<500: throw Hub.HubClientError.authorizationRequired + case 200 ..< 300: break + case 400 ..< 500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } - + return (data, response) } - - func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] { + + public func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] { // Read repo info and only parse "siblings" let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")! let (data, _) = try await httpGet(for: url) let response = try JSONDecoder().decode(SiblingsResponse.self, from: data) let filenames = response.siblings.map { $0.rfilename } guard globs.count > 0 else { return filenames } - + var selected: Set = [] for glob in globs { selected = selected.union(filenames.matching(glob: glob)) } return Array(selected) } - - func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + + public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { return try await getFilenames(from: Repo(id: repoId), matching: globs) } - - func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + + public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { return try await getFilenames(from: repo, matching: [glob]) } - - func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + + public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { return try await getFilenames(from: Repo(id: repoId), matching: [glob]) } } /// Configuration loading helpers -public extension HubApi { +extension HubApi { /// Assumes the file has already been downloaded. /// `filename` is relative to the download base. - func configuration(from filename: String, in repo: Repo) throws -> Config { + public func configuration(from filename: String, in repo: Repo) throws -> Config { let fileURL = localRepoLocation(repo).appending(path: filename) return try configuration(fileURL: fileURL) } - + /// Assumes the file is already present at local url. /// `fileURL` is a complete local file path for the given model - func configuration(fileURL: URL) throws -> Config { + public func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse } @@ -158,10 +161,10 @@ public extension HubApi { } /// Whoami -public extension HubApi { - func whoami() async throws -> Config { +extension HubApi { + public func whoami() async throws -> Config { guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired } - + let url = URL(string: "\(endpoint)/api/whoami-v2")! 
let (data, _) = try await httpGet(for: url) @@ -172,12 +175,12 @@ public extension HubApi { } /// Snaphsot download -public extension HubApi { - func localRepoLocation(_ repo: Repo) -> URL { +extension HubApi { + public func localRepoLocation(_ repo: Repo) -> URL { downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id) } - - struct HubFileDownloader { + + public struct HubFileDownloader { let repo: Repo let repoDestination: URL let relativeFilename: String @@ -192,22 +195,23 @@ public extension HubApi { url = url.appending(component: repo.type.rawValue) } url = url.appending(path: repo.id) - url = url.appending(path: "resolve/main") // TODO: revisions + url = url.appending(path: "resolve/main") // TODO: revisions url = url.appending(path: relativeFilename) return url } - + var destination: URL { repoDestination.appending(path: relativeFilename) } - + var downloaded: Bool { FileManager.default.fileExists(atPath: destination.path) } - + func prepareDestination() throws { let directoryURL = destination.deletingLastPathComponent() - try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) + try FileManager.default.createDirectory( + at: directoryURL, withIntermediateDirectories: true, attributes: nil) } // Note we go from Combine in Downloader to callback-based progress reporting @@ -232,7 +236,9 @@ public extension HubApi { } @discardableResult - func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { let filenames = try await getFilenames(from: repo, matching: globs) let progress = Progress(totalUnitCount: Int64(filenames.count)) let repoDestination = localRepoLocation(repo) @@ -255,36 +261,42 @@ public extension HubApi { progressHandler(progress) return repoDestination } - + @discardableResult - func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } - + @discardableResult - func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler) } - + @discardableResult - func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler) } } /// Metadata -public extension HubApi { +extension HubApi { /// A structure representing metadata for a remote file - struct FileMetadata { + public struct FileMetadata { /// The file's Git commit hash public let commitHash: String? 
- + /// Server-provided ETag for caching public let etag: String? - + /// Stringified URL location of the file public let location: String - + /// The file's size in bytes public let size: Int? } @@ -293,105 +305,122 @@ public extension HubApi { guard let etag = etag else { return nil } return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\"")) } - - func getFileMetadata(url: URL) async throws -> FileMetadata { + + public func getFileMetadata(url: URL) async throws -> FileMetadata { let (_, response) = try await httpHead(for: url) - + return FileMetadata( commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), etag: normalizeEtag( (response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag")) ), location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString, - size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "") + size: Int( + response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value( + forHTTPHeaderField: "Content-Length") ?? "") ) } - - func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] { + + public func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] { let files = try await getFilenames(from: repo, matching: globs) - let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions - var selectedMetadata: Array = [] + let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions + var selectedMetadata: [FileMetadata] = [] for file in files { let fileURL = url.appending(path: file) selectedMetadata.append(try await getFileMetadata(url: fileURL)) } return selectedMetadata } - - func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] { + + public func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] { return try await getFileMetadata(from: Repo(id: repoId), matching: globs) } - - func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] { + + public func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] { return try await getFileMetadata(from: repo, matching: [glob]) } - - func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] { + + public func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] { return try await getFileMetadata(from: Repo(id: repoId), matching: [glob]) } } /// Stateless wrappers that use `HubApi` instances -public extension Hub { - static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { +extension Hub { + public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { return try await HubApi.shared.getFilenames(from: repo, matching: globs) } - - static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + + public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) } - - static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + + public static func getFilenames(from repo: Repo, matching glob: String) async throws 
-> [String] { return try await HubApi.shared.getFilenames(from: repo, matching: glob) } - - static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + + public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) } - - static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + + public static func snapshot( + from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) } - - static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) + + public static func snapshot( + from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { + return try await HubApi.shared.snapshot( + from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } - - static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + + public static func snapshot( + from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) } - - static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) + + public static func snapshot( + from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { + return try await HubApi.shared.snapshot( + from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) } - - static func whoami(token: String) async throws -> Config { + + public static func whoami(token: String) async throws -> Config { return try await HubApi(hfToken: token).whoami() } - - static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { + + public static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { return try await HubApi.shared.getFileMetadata(url: fileURL) } - - static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + + public static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi + .FileMetadata] + { return try await HubApi.shared.getFileMetadata(from: repo, matching: globs) } - - static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + + public static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi + .FileMetadata] + { return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs) } - - static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { + + public 
static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob]) } - - static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] { + + public static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] + { return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob]) } } -public extension [String] { - func matching(glob: String) -> [String] { +extension [String] { + public func matching(glob: String) -> [String] { filter { fnmatch(glob, $0, 0) == 0 } } } diff --git a/Sources/HubCLI/HubCLI.swift b/Sources/HubCLI/HubCLI.swift index fb0cc72..99d27f9 100644 --- a/Sources/HubCLI/HubCLI.swift +++ b/Sources/HubCLI/HubCLI.swift @@ -1,6 +1,5 @@ import ArgumentParser import Foundation - import Hub let defaultTokenLocation = NSString("~/.cache/huggingface/token").expandingTildeInPath @@ -33,7 +32,7 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { case model case dataset case space - + var asHubApiRepoType: HubApi.RepoType { switch self { case .model: return .models @@ -42,7 +41,7 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { } } } - + @Argument(help: "Repo ID") var repo: String @@ -51,17 +50,19 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { @Option(help: "Glob patterns for files to include") var include: [String] = [] - + @Option(help: "Hugging Face token. If empty, will attempt to read from the filesystem at \(defaultTokenLocation)") var token: String? = nil - + func run() async throws { let hubApi = HubApi(hfToken: hfToken) let repo = Hub.Repo(id: repo, type: repoType.asHubApiRepoType) let downloadedTo = try await hubApi.snapshot(from: repo, matching: include) { progress in DispatchQueue.main.async { let totalPercent = 100 * progress.fractionCompleted - print("\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", terminator: "\r") + print( + "\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", + terminator: "\r") fflush(stdout) } } @@ -71,16 +72,16 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { struct Whoami: AsyncParsableCommand, SubcommandWithToken { static let configuration = CommandConfiguration(abstract: "whoami") - + @Option(help: "Hugging Face token. If empty, will attempt to read from the filesystem at \(defaultTokenLocation)") var token: String? = nil - + func run() async throws { let hubApi = HubApi(hfToken: hfToken) let userInfo = try await hubApi.whoami() if let name = userInfo.name?.stringValue, - let fullname = userInfo.fullname?.stringValue, - let email = userInfo.email?.stringValue + let fullname = userInfo.fullname?.stringValue, + let email = userInfo.email?.stringValue { print("\(name) (\(fullname) <\(email)>)") } else { diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 457755a..0e87e00 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -1,24 +1,24 @@ // // LanguageModel.swift -// +// // // Created by Pedro Cuenca on 7/5/23. 
// import CoreML -import Tokenizers import Generation import Hub +import Tokenizers public class LanguageModel { public let model: MLModel - + public let minContextLength: Int public let maxContextLength: Int - + let input_ids = "input_ids" let attention_mask = "attention_mask" - + struct Configurations { var modelConfig: Config var tokenizerConfig: Config? @@ -30,15 +30,15 @@ public class LanguageModel { public required init(model: MLModel) { self.model = model - + // We assume inputs named "input_ids" with shape (1, seq_length) // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"] - + guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else { fatalError("Cannot obtain shape information") } - + switch shapeConstraint.type { case .enumerated: // TODO: support a set of fixed shapes (keeping the first one here) @@ -55,13 +55,13 @@ public class LanguageModel { minContextLength = 128 maxContextLength = 128 } - + self.configuration = LanguageModelConfigurationFromHub(modelName: modelName) } } -public extension LanguageModel { - static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel { +extension LanguageModel { + public static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel { let config = MLModelConfiguration() config.computeUnits = computeUnits let model = try MLModel(contentsOf: url, configuration: config) @@ -69,63 +69,67 @@ public extension LanguageModel { } } -public extension LanguageModel { - var description: String { +extension LanguageModel { + public var description: String { if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String, - !description.isEmpty { + !description.isEmpty + { return description } return model.configuration.modelDisplayName ?? "" } - + /// `name_or_path` in the Python world - var modelName: String { - if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String : String], - let name = userFields["co.huggingface.exporters.name"] { + public var modelName: String { + if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String], + let name = userFields["co.huggingface.exporters.name"] + { return name } // This is usually the basename of the file, that's our best bet if no metadata exists - guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") } + guard let modelName = model.configuration.modelDisplayName else { + fatalError("Models must have a name that identifies them") + } return modelName } - - var inputIdsDescription: MLFeatureDescription { + + public var inputIdsDescription: MLFeatureDescription { model.modelDescription.inputDescriptionsByName[input_ids]! 
} - - var inputIdsName: String { + + public var inputIdsName: String { inputIdsDescription.name } - + /// The expected shape of the models latent sample input - var inputIdsShape: [Int] { + public var inputIdsShape: [Int] { inputIdsDescription.multiArrayConstraint!.shape.map { $0.intValue } } - - var requiresAttention: Bool { + + public var requiresAttention: Bool { model.modelDescription.inputDescriptionsByName[attention_mask] != nil } - + // MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice - func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { + public func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { // TODO: exceptions - + // Maybe pad or truncate let maxTokens = min(tokens.count, maxContextLength) - let padLength = maxTokens >= minContextLength ? 0 : minContextLength-maxTokens - let inputTokens = Array(tokens[0..= minContextLength ? 0 : minContextLength - maxTokens + let inputTokens = Array(tokens[0 ..< maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength) + let inputIds = MLShapedArray(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape) var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)] if requiresAttention { let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength) - let attentionMask = MLShapedArray(scalars: mask.map{ Int32($0) }, shape: inputIdsShape) + let attentionMask = MLShapedArray(scalars: mask.map { Int32($0) }, shape: inputIdsShape) inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask) } let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary) - + let output = try! model.prediction(from: input) - + // TODO: maybe try to support models with "token_scores" too (after the softmax) assert(output.featureNames.first! == "logits") @@ -136,61 +140,63 @@ public extension LanguageModel { } /// async properties downloaded from the configuration -public extension LanguageModel { - var modelConfig: Config { +extension LanguageModel { + public var modelConfig: Config { get async throws { try await configuration!.modelConfig } } - - var tokenizerConfig: Config? { + + public var tokenizerConfig: Config? { get async throws { try await configuration!.tokenizerConfig } } - - var tokenizerData: Config { + + public var tokenizerData: Config { get async throws { try await configuration!.tokenizerData } } - - var modelType: String? { + + public var modelType: String? { get async throws { try await modelConfig.modelType?.stringValue } } - - var textGenerationParameters: Config? { + + public var textGenerationParameters: Config? { get async throws { try await modelConfig.taskSpecificParams?.textGeneration } } - - var defaultDoSample: Bool { + + public var defaultDoSample: Bool { get async throws { try await textGenerationParameters?.doSample?.boolValue ?? true } } - var bosTokenId: Int? { + public var bosTokenId: Int? { get async throws { let modelConfig = try await modelConfig return modelConfig.bosTokenId?.intValue } } - - var eosTokenId: Int? { + + public var eosTokenId: Int? { get async throws { let modelConfig = try await modelConfig return modelConfig.eosTokenId?.intValue } } - - var tokenizer: Tokenizer { + + public var tokenizer: Tokenizer { get async throws { guard _tokenizer == nil else { return _tokenizer! 
} - guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" } + guard let tokenizerConfig = try await tokenizerConfig else { + throw "Cannot retrieve Tokenizer configuration" + } let tokenizerData = try await tokenizerData _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) return _tokenizer! diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift index 08d7d48..b22e40d 100644 --- a/Sources/Models/LanguageModelTypes.swift +++ b/Sources/Models/LanguageModelTypes.swift @@ -1,13 +1,13 @@ // // LanguageModelTypes.swift -// +// // // Created by Pedro Cuenca on 8/5/23. // import CoreML -import Tokenizers import Generation +import Tokenizers public protocol LanguageModelProtocol { /// `name_or_path` in the Python world @@ -15,16 +15,16 @@ public protocol LanguageModelProtocol { var tokenizer: Tokenizer { get async throws } var model: MLModel { get } - + init(model: MLModel) - + /// Make prediction callable (this works like __call__ in Python) func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol } -public extension LanguageModelProtocol { - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { +extension LanguageModelProtocol { + public func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { predictNextTokenScores(tokens, config: config) } } @@ -34,9 +34,12 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol { func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String } -public extension TextGenerationModel { +extension TextGenerationModel { @discardableResult - func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String { - try await self.generate(config: config, prompt: prompt, model: self.callAsFunction, tokenizer: self.tokenizer, callback: callback) + public func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) + async throws -> String + { + try await self.generate( + config: config, prompt: prompt, model: self.callAsFunction, tokenizer: self.tokenizer, callback: callback) } } diff --git a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift index 53dc0db..9eb1630 100644 --- a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift +++ b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift @@ -2,7 +2,7 @@ import Foundation public struct TemperatureLogitsWarper: LogitsWarper { public var temperature: Float - + public init(temperature: Float) { self.temperature = temperature } diff --git a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift index a236d84..580b199 100644 --- a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift +++ b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift @@ -1,5 +1,5 @@ -import Foundation import Accelerate +import Foundation /// Top-K. /// Select the k most-probable element indices from `arr` @@ -7,7 +7,7 @@ import Accelerate /// and their probabilities. 
public struct TopKLogitsWarper: LogitsWarper { public var k: Int - + public init(k: Int) { self.k = k } diff --git a/Sources/TensorUtils/MLMultiArray+Utils.swift b/Sources/TensorUtils/MLMultiArray+Utils.swift index ddb2760..c4eda3b 100644 --- a/Sources/TensorUtils/MLMultiArray+Utils.swift +++ b/Sources/TensorUtils/MLMultiArray+Utils.swift @@ -6,12 +6,12 @@ // Copyright © 2019 Hugging Face. All rights reserved. // -import Foundation import CoreML +import Foundation -public extension MLMultiArray { +extension MLMultiArray { /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { + public static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { var shape = Array(repeating: 1, count: dims) shape[shape.count - 1] = arr.count /// Examples: @@ -25,9 +25,9 @@ public extension MLMultiArray { } return o } - + /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray { + public static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray { var shape = Array(repeating: 1, count: dims) shape[shape.count - 1] = arr.count /// Examples: @@ -41,31 +41,31 @@ public extension MLMultiArray { } return o } - + /// This will concatenate all dimensions into one one-dim array. - static func toIntArray(_ o: MLMultiArray) -> [Int] { + public static func toIntArray(_ o: MLMultiArray) -> [Int] { var arr = Array(repeating: 0, count: o.count) let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Int] { Self.toIntArray(self) } - + + public func toIntArray() -> [Int] { Self.toIntArray(self) } + /// This will concatenate all dimensions into one one-dim array. - static func toDoubleArray(_ o: MLMultiArray) -> [Double] { + public static func toDoubleArray(_ o: MLMultiArray) -> [Double] { var arr: [Double] = Array(repeating: 0, count: o.count) let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Double] { Self.toDoubleArray(self) } - + + public func toDoubleArray() -> [Double] { Self.toDoubleArray(self) } + /// Helper to construct a sequentially-indexed multi array, /// useful for debugging and unit tests /// Example in 3 dimensions: @@ -77,29 +77,28 @@ public extension MLMultiArray { /// [ 16, 17, 18, 19 ], /// [ 20, 21, 22, 23 ]]] /// ``` - static func testTensor(shape: [Int]) -> MLMultiArray { + public static func testTensor(shape: [Int]) -> MLMultiArray { let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double) let ptr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. MLMultiArray { + public static func slice(_ o: MLMultiArray, indexing: [Indexing]) -> MLMultiArray { assert( indexing.count == o.shape.count ) @@ -118,12 +117,12 @@ public extension MLMultiArray { selectDims: selectDims ) } - + /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on) /// and a dictionary of `dim` to `index`. /// /// You must select all other dimensions than the slice dimension (cf. the assert). 
- static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { + public static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { assert( selectDims.count + 1 == o.shape.count ) @@ -131,13 +130,13 @@ public extension MLMultiArray { shape[sliceDim] = o.shape[sliceDim] /// print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)") let arr = try! MLMultiArray(shape: shape, dataType: .double) - + /// let srcPtr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) /// TODO: use srcPtr instead of array subscripting. let dstPtr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. String { return String(repeating: " ", count: x) } - + // This function is called recursively for every dimension. // Add an entry for this dimension to the end of the array. var indices = indices + [0] - - let d = indices.count - 1 // the current dimension - let N = shape[d].intValue // how many elements in this dimension + + let d = indices.count - 1 // the current dimension + let N = shape[d].intValue // how many elements in this dimension var s = "[" - if indices.count < shape.count { // not last dimension yet? - for i in 0.. { - var floats: [Float] { +extension MLShapedArray { + public var floats: [Float] { guard self.strides.first == 1, self.strides.count == 1 else { // For some reason this path is slow. // If strides is not 1, we can write a Metal kernel to copy the values properly. return self.scalars } - + // Fast path: memcpy let mlArray = MLMultiArray(self) return mlArray.floats ?? self.scalars } } -public extension MLShapedArraySlice { - var floats: [Float] { +extension MLShapedArraySlice { + public var floats: [Float] { guard self.strides.first == 1, self.strides.count == 1 else { // For some reason this path is slow. // If strides is not 1, we can write a Metal kernel to copy the values properly. @@ -35,10 +35,10 @@ public extension MLShapedArraySlice { } } -public extension MLMultiArray { - var floats: [Float]? { +extension MLMultiArray { + public var floats: [Float]? { guard self.dataType == .float32 else { return nil } - + var result: [Float] = Array(repeating: 0, count: self.count) return self.withUnsafeBytes { ptr in guard let source = ptr.baseAddress else { return nil } diff --git a/Sources/TensorUtils/Math.swift b/Sources/TensorUtils/Math.swift index 4050ac1..12ad579 100644 --- a/Sources/TensorUtils/Math.swift +++ b/Sources/TensorUtils/Math.swift @@ -6,9 +6,9 @@ // Copyright © 2019 Hugging Face. All rights reserved. // -import Foundation import Accelerate import CoreML +import Foundation /// /// From M.I. Hollemans @@ -16,10 +16,10 @@ import CoreML /// https://github.com/hollance/CoreMLHelpers /// public struct Math { - + /** Returns the index and value of the largest element in the array. - + - Parameters: - ptr: Pointer to the first element in memory. - count: How many elements to look at. @@ -31,7 +31,7 @@ public struct Math { vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + /** Returns the index and value of the largest element in the array. 
- Parameters: @@ -45,14 +45,14 @@ public struct Math { vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + public static func argmax32(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { var maxValue: Float = 0 var maxIndex: vDSP_Length = 0 vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + /// MLMultiArray helper. /// Works in our specific use case. public static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) { @@ -60,7 +60,7 @@ public struct Math { let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) return Math.argmax(ptr, count: multiArray.count) } - + /// MLMultiArray helper. /// Works in our specific use case. public static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) { @@ -88,7 +88,7 @@ public struct Math { let i = randomNumber(probabilities: probs) return indexes[i] } - + /** Computes the "softmax" function over an array. Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/ @@ -103,31 +103,31 @@ public struct Math { public static func softmax(_ x: [Float]) -> [Float] { var x = x let len = vDSP_Length(x.count) - + // Find the maximum value in the input array. var max: Float = 0 vDSP_maxv(x, 1, &max, len) - + // Subtract the maximum from all the elements in the array. // Now the highest value in the array is 0. max = -max vDSP_vsadd(x, 1, &max, &x, 1, len) - + // Exponentiate all the elements in the array. var count = Int32(x.count) vvexpf(&x, x, &count) - + // Compute the sum of all exponentiated values. var sum: Float = 0 vDSP_sve(x, 1, &sum, len) - + // Divide each element by the sum. This normalizes the array contents // so that they all add up to 1. vDSP_vsdiv(x, 1, &sum, &x, 1, len) - + return x } - + /// Multinomial sampling /// /// From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution @@ -152,16 +152,16 @@ public struct Math { // MLShapedArray versions -public extension Math { - static func argmax(_ shapedArray: MLShapedArray) -> (Int, Float) { +extension Math { + public static func argmax(_ shapedArray: MLShapedArray) -> (Int, Float) { shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") return Math.argmax32(ptr.baseAddress!, count: shapedArray.count, stride: strides.first!) } } - + // TODO: handle Double, etc. - static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { + public static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") let floatsPtr = ptr.baseAddress as! 
UnsafePointer diff --git a/Sources/TensorUtils/Weights.swift b/Sources/TensorUtils/Weights.swift index 2050e01..8529671 100644 --- a/Sources/TensorUtils/Weights.swift +++ b/Sources/TensorUtils/Weights.swift @@ -1,6 +1,5 @@ import CoreML - public struct Weights { enum WeightsError: Error { @@ -21,7 +20,7 @@ public struct Weights { else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) - switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) { + switch ([UInt8](data.subdata(in: 0 ..< 4)), [UInt8](data.subdata(in: 4 ..< 6))) { case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf")) case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx") default: return try Safetensor.from(data: data) @@ -62,15 +61,15 @@ struct Safetensor { } static func from(data: Data) throws -> Weights { - let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) }) + let headerSize: Int = data.subdata(in: 0 ..< 8).withUnsafeBytes({ $0.load(as: Int.self) }) guard headerSize < data.count else { throw Error.invalidFile } - let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8))) + let header = try Header.from(data: data.subdata(in: 8 ..< (headerSize + 8))) var dict = [String: MLMultiArray]() for (key, point) in header { guard let offsets = point?.dataOffsets, offsets.count == 2, - let shape = point?.shape as? [NSNumber], - let dType = try point?.dataType + let shape = point?.shape as? [NSNumber], + let dType = try point?.dataType else { continue } let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in @@ -78,7 +77,7 @@ struct Safetensor { } let start = 8 + offsets[0] + headerSize let end = 8 + offsets[1] + headerSize - let tensorData = data.subdata(in: start.. Bool { return lhs.a == rhs.a && lhs.b == rhs.b } @@ -30,9 +30,8 @@ struct BytePair: Hashable { } } - class BPETokenizer: PreTrainedTokenizerModel { - let bpeRanks: Dictionary + let bpeRanks: [BytePair: Int] private let tokensToIds: [NSString: Int] private let idsToTokens: [Int: NSString] @@ -60,21 +59,23 @@ class BPETokenizer: PreTrainedTokenizerModel { } } - required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { - guard let merges = Self.mergesFromConfig(tokenizerData.model?.merges) else { fatalError("BPETokenizer requires merges") } + required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { + guard let merges = Self.mergesFromConfig(tokenizerData.model?.merges) else { + fatalError("BPETokenizer requires merges") + } guard let vocab = tokenizerData.model?.vocab?.dictionary as? [NSString: Int] else { throw TokenizerError.missingVocab } - var bpeRanks: Dictionary = [:] + var bpeRanks: [BytePair: Int] = [:] for (i, merge) in merges.enumerated() { let bp = BytePair(tuple: merge) bpeRanks[bp] = i } self.bpeRanks = bpeRanks - - self.tokensToIds = vocab.merging(addedTokens as [NSString : Int]) { $1 } + + self.tokensToIds = vocab.merging(addedTokens as [NSString: Int]) { $1 } self.idsToTokens = Utils.invert(self.tokensToIds) - + // Populate tokens if let unknownToken = TokenizerModel.unknownToken(from: tokenizerConfig) { self.unknownToken = unknownToken @@ -83,7 +84,7 @@ class BPETokenizer: PreTrainedTokenizerModel { self.unknownToken = nil self.unknownTokenId = nil } - + eosToken = tokenizerConfig.eosToken?.stringValue eosTokenId = eosToken == nil ? 
nil : tokensToIds[eosToken! as NSString] @@ -96,7 +97,7 @@ class BPETokenizer: PreTrainedTokenizerModel { func convertTokenToId(_ token: String) -> Int? { return tokensToIds[token as NSString] ?? self.unknownTokenId } - + func convertIdToToken(_ id: Int) -> String? { return idsToTokens[id] as String? } @@ -108,7 +109,7 @@ class BPETokenizer: PreTrainedTokenizerModel { return Array(token.utf8).compactMap { byteEncoder[$0] }.joined() } } - + func hexaEncode(text: String) -> [String] { let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# let tokens = text.ranges(of: RE).map { String(text[$0]) } @@ -116,27 +117,27 @@ class BPETokenizer: PreTrainedTokenizerModel { return Array(token.utf8).map { String(format: "<0x%02X>", $0) } } } - + private func getPairs(word: [String]) -> Set { var s = Set() - for i in 0.. String { if token.count <= 1 { return token } - + var word = Array(token).map { String($0) } var pairs = Array(getPairs(word: word)) - + while true { let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil } if bigrams.count == 0 { @@ -150,16 +151,16 @@ class BPETokenizer: PreTrainedTokenizerModel { var newWord: [String] = [] var i = 0 while i < word.count { - if let j = word[i.. [String] { let text = tokenizeChineseCharsIfNeed(text) var tokens: [String] = [] @@ -69,7 +72,7 @@ public class BertTokenizer { } return tokens } - + private func convertTokensToIds(tokens: [String]) throws -> [Int] { if tokens.count > maxLen { throw TokenizerError.tooLong( @@ -82,26 +85,26 @@ public class BertTokenizer { } return tokens.compactMap { vocab[$0] } } - + /// Main entry point func tokenizeToIds(text: String) -> [Int] { return try! convertTokensToIds(tokens: tokenize(text: text)) } - + func tokenToId(token: String) -> Int { return vocab[token]! } - + /// Un-tokenization: get tokens from tokenIds func unTokenize(tokens: [Int]) -> [String] { return tokens.compactMap { ids_to_tokens[$0] } } - + /// Un-tokenization: func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String { var tokenList: [String] = [] var individualToken: String = "" - + for token in wordpieceTokenList { if token.starts(with: "##") { individualToken += String(token.suffix(token.count - 2)) @@ -109,21 +112,21 @@ public class BertTokenizer { if individualToken.count > 0 { tokenList.append(individualToken) } - + individualToken = token } } - + tokenList.append(individualToken) - + return tokenList.joined(separator: " ") } - + private func tokenizeChineseCharsIfNeed(_ text: String) -> String { guard tokenizeChineseChars else { return text } - + return text.map { c in if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) { " \(c) " @@ -134,28 +137,26 @@ public class BertTokenizer { } } - extension BertTokenizer: PreTrainedTokenizerModel { public var unknownToken: String? { wordpieceTokenizer.unkToken } public var unknownTokenId: Int? { vocab[unknownToken!] } func encode(text: String) -> [Int] { tokenizeToIds(text: text) } - + func decode(tokens: [Int]) -> String { let tokens = unTokenize(tokens: tokens) return convertWordpieceToBasicTokenList(tokens) } - + public func convertTokenToId(_ token: String) -> Int? { return vocab[token] ?? unknownTokenId } - + public func convertIdToToken(_ id: Int) -> String? 
{ return ids_to_tokens[id] } } - class BasicTokenizer { let doLowerCase: Bool @@ -164,7 +165,7 @@ class BasicTokenizer { } let neverSplit = [ - "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]" + "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", ] func maybeStripAccents(_ text: String) -> String { @@ -211,11 +212,11 @@ extension Character { if isPunctuation { return true } if let value = unicodeScalars.first?.value { switch value { - case 33...47: return true - case 58...64: return true - case 91...96: return true - case 123...126: return true - default: return false + case 33 ... 47: return true + case 58 ... 64: return true + case 91 ... 96: return true + case 123 ... 126: return true + default: return false } } return false @@ -226,11 +227,11 @@ class WordpieceTokenizer { let unkToken = "[UNK]" private let maxInputCharsPerWord = 100 private let vocab: [String: Int] - + init(vocab: [String: Int]) { self.vocab = vocab } - + /// `word`: A single token. /// Warning: this differs from the `pytorch-transformers` implementation. /// This should have already been passed through `BasicTokenizer`. @@ -246,7 +247,7 @@ class WordpieceTokenizer { var end = word.count var cur_substr: String? = nil while start < end { - var substr = Utils.substr(word, start.. 0 { substr = "##\(substr)" } diff --git a/Sources/Tokenizers/ByteEncoder.swift b/Sources/Tokenizers/ByteEncoder.swift index e636ca6..1d0532a 100644 --- a/Sources/Tokenizers/ByteEncoder.swift +++ b/Sources/Tokenizers/ByteEncoder.swift @@ -8,7 +8,7 @@ import Foundation -let byteEncoder: Dictionary = [ +let byteEncoder: [UTF8.CodeUnit: String] = [ 33: "!", 34: "\"", 35: "#", diff --git a/Sources/Tokenizers/Decoder.swift b/Sources/Tokenizers/Decoder.swift index 81d0d60..9f1a152 100644 --- a/Sources/Tokenizers/Decoder.swift +++ b/Sources/Tokenizers/Decoder.swift @@ -1,6 +1,6 @@ // // Decoder.swift -// +// // // Created by Pedro Cuenca on 17/7/23. // @@ -11,7 +11,7 @@ import Hub public protocol Decoder { func decode(tokens: [String]) -> [String] func callAsFunction(tokens: [String]) -> [String] - + init(config: Config) } @@ -40,15 +40,15 @@ struct DecoderFactory { guard let typeName = config.type?.stringValue else { return nil } let type = DecoderType(rawValue: typeName) switch type { - case .Sequence : return DecoderSequence(config: config) - case .ByteLevel : return ByteLevelDecoder(config: config, addedTokens: addedTokens) - case .Replace : return ReplaceDecoder(config: config) + case .Sequence: return DecoderSequence(config: config) + case .ByteLevel: return ByteLevelDecoder(config: config, addedTokens: addedTokens) + case .Replace: return ReplaceDecoder(config: config) case .ByteFallback: return ByteFallbackDecoder(config: config) - case .Fuse : return FuseDecoder(config: config) - case .Strip : return StripDecoder(config: config) - case .Metaspace : return MetaspaceDecoder(config: config) - case .WordPiece : return WordPieceDecoder(config: config) - default : fatalError("Unsupported Decoder type: \(typeName)") + case .Fuse: return FuseDecoder(config: config) + case .Strip: return StripDecoder(config: config) + case .Metaspace: return MetaspaceDecoder(config: config) + case .WordPiece: return WordPieceDecoder(config: config) + default: fatalError("Unsupported Decoder type: \(typeName)") } } } @@ -61,17 +61,22 @@ class WordPieceDecoder: Decoder { private let re = try! 
NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'\\s|n't|'m|'s|'ve|'re)", options: []) required public init(config: Config) { - guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") } + guard let prefix = config.prefix?.stringValue else { + fatalError("Missing `prefix` configuration for WordPieceDecoder.") + } self.prefix = prefix self.cleanup = config.cleanup?.boolValue ?? false } func decode(tokens: [String]) -> [String] { let firstToken = cleanup ? cleanUpTokenization(tokens.first!) : tokens.first! - return [firstToken] + tokens.dropFirst().map { token in - let token = token.hasPrefix(prefix) ? token.replacingCharacters(in: token.range(of: prefix)!, with: "") : " \(token)" - return cleanup ? cleanUpTokenization(token) : token - } + return [firstToken] + + tokens.dropFirst().map { token in + let token = + token.hasPrefix(prefix) + ? token.replacingCharacters(in: token.range(of: prefix)!, with: "") : " \(token)" + return cleanup ? cleanUpTokenization(token) : token + } } // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40 @@ -84,12 +89,12 @@ class WordPieceDecoder: Decoder { class DecoderSequence: Decoder { let decoders: [Decoder] - + required public init(config: Config) { guard let configs = config.decoders?.arrayValue else { fatalError("No decoders in Sequence") } decoders = configs.compactMap { DecoderFactory.fromConfig(config: $0) } } - + func decode(tokens: [String]) -> [String] { decoders.reduce(tokens) { current, decoder in decoder(tokens: current) @@ -99,26 +104,26 @@ class DecoderSequence: Decoder { class ByteLevelDecoder: Decoder { let addedTokens: Set - + required public init(config: Config) { self.addedTokens = [] } - + init(config: Config, addedTokens: Set?) { self.addedTokens = addedTokens ?? [] } - + func decode(tokens: [String]) -> [String] { var subTexts: [String] = [] var currentSubText: [String] = [] - + func convertTokensToString(_ tokens: [String]) -> String { let text = tokens.joined(separator: "") - + let utfCodepoints = text.map { byteDecoder[String($0)]! } return String(decoding: utfCodepoints, as: UTF8.self) } - + for token in tokens { if addedTokens.contains(token) { if !currentSubText.isEmpty { @@ -130,22 +135,22 @@ class ByteLevelDecoder: Decoder { currentSubText.append(token) } } - + if !currentSubText.isEmpty { subTexts.append(convertTokensToString(currentSubText)) } - + return subTexts } } class ReplaceDecoder: Decoder { let pattern: StringReplacePattern? - + required public init(config: Config) { self.pattern = StringReplacePattern.from(config: config) } - + func decode(tokens: [String]) -> [String] { guard let pattern = pattern else { return tokens } return tokens.map { pattern.replace($0) } @@ -154,7 +159,7 @@ class ReplaceDecoder: Decoder { class ByteFallbackDecoder: Decoder { required public init(config: Config) {} - + func decode(tokens: [String]) -> [String] { var newTokens: [String] = [] var byteTokens: [Int] = [] @@ -165,9 +170,9 @@ class ByteFallbackDecoder: Decoder { } let startIndex = token.index(token.startIndex, offsetBy: 3) let endIndex = token.index(token.startIndex, offsetBy: 5) - return Int(token[startIndex.. 
[String] { [tokens.joined(separator: "")] } @@ -197,16 +202,22 @@ class StripDecoder: Decoder { let content: String let start: Int let stop: Int - + required public init(config: Config) { - guard let content = config.content?.stringValue else { fatalError("Incorrect StripDecoder configuration: can't parse `content`.") } - guard let start = config.start?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `start`.") } - guard let stop = config.stop?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") } + guard let content = config.content?.stringValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `content`.") + } + guard let start = config.start?.intValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `start`.") + } + guard let stop = config.stop?.intValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") + } self.content = content self.start = start self.stop = stop } - + func decode(tokens: [String]) -> [String] { tokens.map { token in token.trimmingFromStart(upto: start).trimmingFromEnd(upto: stop) @@ -217,7 +228,7 @@ class StripDecoder: Decoder { class MetaspaceDecoder: Decoder { let addPrefixSpace: Bool let replacement: String - + required public init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false replacement = config.replacement?.stringValue ?? "_" @@ -235,8 +246,8 @@ class MetaspaceDecoder: Decoder { } // We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once) -public extension String { - func trimmingFromStart(character: Character = " ", upto: Int) -> String { +extension String { + public func trimmingFromStart(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 while trimmed < upto && result.first == character { @@ -246,7 +257,7 @@ public extension String { return result } - func trimmingFromEnd(character: Character = " ", upto: Int) -> String { + public func trimmingFromEnd(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 while trimmed < upto && result.last == character { diff --git a/Sources/Tokenizers/Normalizer.swift b/Sources/Tokenizers/Normalizer.swift index e9730c8..fcc3121 100644 --- a/Sources/Tokenizers/Normalizer.swift +++ b/Sources/Tokenizers/Normalizer.swift @@ -220,9 +220,10 @@ class BertNormalizer: Normalizer { private func stripAccents(text: String) -> String { // This might be the same as `text.folding(options: .diacriticInsensitive, locale: nil)` - String(text.decomposedStringWithCanonicalMapping.unicodeScalars.filter { scalar in - !(0x0300 <= scalar.value && scalar.value <= 0x036F) - }) + String( + text.decomposedStringWithCanonicalMapping.unicodeScalars.filter { scalar in + !(0x0300 <= scalar.value && scalar.value <= 0x036F) + }) } } @@ -241,10 +242,10 @@ class PrecompiledNormalizer: Normalizer { for scalar in text.unicodeScalars { switch scalar.value { - case 0x0001...0x0008, 0x000B, 0x000E...0x001F, 0x007F, 0x008F, 0x009F: + case 0x0001 ... 0x0008, 0x000B, 0x000E ... 0x001F, 0x007F, 0x008F, 0x009F: // Non-printing control characters output.append("") - case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581, + case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B ... 
0x200F, 0x2028, 0x2029, 0x2581, 0xFEFF, 0xFFFD: // Separators output.append(" ") diff --git a/Sources/Tokenizers/PostProcessor.swift b/Sources/Tokenizers/PostProcessor.swift index 0b26415..c8175b7 100644 --- a/Sources/Tokenizers/PostProcessor.swift +++ b/Sources/Tokenizers/PostProcessor.swift @@ -1,6 +1,6 @@ // // PostProcessor.swift -// +// // // Created by Pedro Cuenca on 17/7/23. // @@ -35,12 +35,12 @@ struct PostProcessorFactory { guard let typeName = config.type?.stringValue else { return nil } let type = PostProcessorType(rawValue: typeName) switch type { - case .TemplateProcessing : return TemplateProcessing(config: config) - case .ByteLevel : return ByteLevelPostProcessor(config: config) - case .RobertaProcessing : return RobertaProcessing(config: config) - case .BertProcessing : return BertProcessing(config: config) - case .Sequence : return SequenceProcessing(config: config) - default : fatalError("Unsupported PostProcessor type: \(typeName)") + case .TemplateProcessing: return TemplateProcessing(config: config) + case .ByteLevel: return ByteLevelPostProcessor(config: config) + case .RobertaProcessing: return RobertaProcessing(config: config) + case .BertProcessing: return BertProcessing(config: config) + case .Sequence: return SequenceProcessing(config: config) + default: fatalError("Unsupported PostProcessor type: \(typeName)") } } } @@ -48,15 +48,15 @@ struct PostProcessorFactory { class TemplateProcessing: PostProcessor { let single: [Config] let pair: [Config] - + required public init(config: Config) { guard let single = config.single?.arrayValue else { fatalError("Missing `single` processor configuration") } guard let pair = config.pair?.arrayValue else { fatalError("Missing `pair` processor configuration") } - + self.single = single self.pair = pair } - + func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { let config = tokensPair == nil ? single : pair @@ -80,7 +80,9 @@ class TemplateProcessing: PostProcessor { class ByteLevelPostProcessor: PostProcessor { required public init(config: Config) {} - func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { tokens } + func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { + tokens + } } class RobertaProcessing: PostProcessor { @@ -99,7 +101,7 @@ class RobertaProcessing: PostProcessor { self.trimOffset = config.trimOffset?.boolValue ?? true self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true } - + func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] { var outTokens = tokens var tokensPair = tokensPair @@ -107,7 +109,7 @@ class RobertaProcessing: PostProcessor { if addPrefixSpace { outTokens = outTokens.map({ trimExtraSpaces(token: $0) }) tokensPair = tokensPair?.map({ trimExtraSpaces(token: $0) }) - } else { + } else { outTokens = outTokens.map({ $0.trimmingCharacters(in: .whitespaces) }) tokensPair = tokensPair?.map({ $0.trimmingCharacters(in: .whitespaces) }) } @@ -130,7 +132,7 @@ class RobertaProcessing: PostProcessor { let suffixOffset = findSuffixIndex(text: token) let prefixIndex = token.index(token.startIndex, offsetBy: prefixOffset) let suffixIndex = token.index(token.startIndex, offsetBy: token.count - suffixOffset) - return String(token[prefixIndex.. 
Int { @@ -183,7 +185,8 @@ class SequenceProcessing: PostProcessor { var currentTokensPair = tokensPair for processor in processors { - let processed = processor.postProcess(tokens: currentTokens, tokensPair: currentTokensPair, addSpecialTokens: addSpecialTokens) + let processed = processor.postProcess( + tokens: currentTokens, tokensPair: currentTokensPair, addSpecialTokens: addSpecialTokens) currentTokens = processed currentTokensPair = nil // After the first processor, we no longer have a separate pair } diff --git a/Sources/Tokenizers/PreTokenizer.swift b/Sources/Tokenizers/PreTokenizer.swift index d7eb972..8e5c416 100644 --- a/Sources/Tokenizers/PreTokenizer.swift +++ b/Sources/Tokenizers/PreTokenizer.swift @@ -1,6 +1,6 @@ // // PreTokenizer.swift -// +// // // Created by Pedro Cuenca on 18/7/23. // @@ -31,7 +31,7 @@ extension PreTokenizer { func callAsFunction(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] { return preTokenize(texts: texts, options: options) } - + func callAsFunction(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { return preTokenize(text: text, options: options) } @@ -57,7 +57,7 @@ struct PreTokenizerFactory { guard let typeName = config.type?.stringValue else { return nil } let type = PreTokenizerType(rawValue: typeName) switch type { - case .Sequence : return PreTokenizerSequence(config: config) + case .Sequence: return PreTokenizerSequence(config: config) case .ByteLevel: return ByteLevelPreTokenizer(config: config) case .Punctuation: return PunctuationPreTokenizer(config: config) case .Digits: return DigitsPreTokenizer(config: config) @@ -85,12 +85,12 @@ class BertPreTokenizer: PreTokenizer { class PreTokenizerSequence: PreTokenizer { let preTokenizers: [PreTokenizer] - + required init(config: Config) { guard let configs = config.pretokenizers?.arrayValue else { fatalError("No pretokenizers in Sequence") } preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) } } - + func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { preTokenizers.reduce([text]) { current, preTokenizer in preTokenizer(texts: current, options: options) @@ -114,40 +114,40 @@ class WhitespacePreTokenizer: PreTokenizer { class MetaspacePreTokenizer: PreTokenizer { /// Whether to add a prefix space to the first token let addPrefixSpace: Bool - + /// Replacement character let replacement: String - + /// Optional string representation of the replacement character. let stringReplacement: String - + enum PrependScheme: String { case first case never case always - + static var defaultScheme: PrependScheme { .always } static func from(rawValue value: String?) -> PrependScheme { guard let value = value else { return defaultScheme } return PrependScheme(rawValue: value) ?? defaultScheme } } - + /// The metaspace prepend scheme, see https://github.com/huggingface/tokenizers/pull/1357 let prependScheme: PrependScheme - + required init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false replacement = config.replacement?.stringValue ?? " " stringReplacement = config.strRep?.stringValue ?? 
replacement prependScheme = PrependScheme.from(rawValue: config.prependScheme?.stringValue) } - + // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114 // https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153 func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { let normalized = text.replacingOccurrences(of: " ", with: stringReplacement) - + // We add a prefix space if: // (1) The addPrefixSpace option is enabled and the normalized // token does not already start with the replacement character. @@ -165,7 +165,7 @@ class MetaspacePreTokenizer: PreTokenizer { prepend = stringReplacement } } - + // Split in `MergedWithNext` mode, although usually the input to this function is already pre-tokenized // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L127 return (prepend + normalized).split(by: replacement, behavior: .mergedWithNext) @@ -177,13 +177,13 @@ class ByteLevelPreTokenizer: PreTokenizer { let trimOffsets: Bool let useRegex: Bool let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# - + required init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false trimOffsets = config.trimOffsets?.boolValue ?? true useRegex = config.useRegex?.boolValue ?? true } - + func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { // Split on whitespace and punctuation let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text] @@ -267,31 +267,36 @@ extension StringSplitPattern { } } -public extension String { - func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range] { +extension String { + public func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range] { var result: [Range] = [] var start = startIndex - while let range = range(of: string, options: options, range: start.. [String] { + + public func split( + by string: String, options: CompareOptions = .regularExpression, includeSeparators: Bool = false, + omittingEmptySubsequences: Bool = true + ) -> [String] { var result: [String] = [] var start = startIndex - while let range = range(of: string, options: options, range: start.. [String] { + public func split(by captureRegex: NSRegularExpression) -> [String] { // Find the matching capture groups - let selfRange = NSRange(startIndex.. [String] { +extension String { + public func split(by string: String, options: CompareOptions = .regularExpression, behavior: SplitDelimiterBehavior) + -> [String] + { func mergedWithNext(ranges: [Range]) -> [Range] { var merged: [Range] = [] var currentStart = startIndex for range in ranges { if range.lowerBound == startIndex { continue } - let mergedRange = currentStart..]) -> [Range] { var merged: [Range] = [] var currentStart = startIndex for range in ranges { - let mergedRange = currentStart.. [TokenLatticeNode] { - for offset in 0...count { + for offset in 0 ... 
count { guard beginNodes[offset].count > 0 else { return [] } - + for rnode in beginNodes[offset] { rnode.prev = nil var bestScore: Float = 0 @@ -75,27 +75,27 @@ extension TokenLattice { bestScore = score } } - + if bestNode != nil { rnode.prev = bestNode rnode.backtraceScore = bestScore } } } - + let root = beginNodes[count][0] guard let prev = root.prev else { return [] } // TODO: the reference implementations have a few more clones here: verify var result: [TokenLatticeNode] = [] - var node = prev //.clone() + var node = prev //.clone() while node.prev != nil { result.append(node.clone()) - node = node.prev! //.clone() + node = node.prev! //.clone() } return result.reversed() } - + /// Returns the substring of the sentence to be tokenized associated to the specified node /// /// - Parameters: @@ -105,16 +105,16 @@ extension TokenLattice { func piece(_ node: TokenLatticeNode) -> any StringProtocol { let start = sentence.index(sentence.startIndex, offsetBy: node.startOffset) let end = sentence.index(start, offsetBy: node.length) - return sentence[start.. TokenLatticeNode { - TokenLatticeNode(tokenId: tokenId, startOffset: startOffset, length: length, score: score, prev: prev, backtraceScore: backtraceScore) + TokenLatticeNode( + tokenId: tokenId, startOffset: startOffset, length: length, score: score, prev: prev, + backtraceScore: backtraceScore) } } diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index 22cf015..e45f293 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -5,8 +5,8 @@ // Created by Pedro Cuenca on 6/5/23. // -import Hub import Foundation +import Hub import Jinja enum TokenizerError: Error { @@ -41,48 +41,50 @@ public protocol TokenizingModel { var fuseUnknownTokens: Bool { get } } -public extension TokenizingModel { - func callAsFunction(_ text: String) -> [String] { +extension TokenizingModel { + public func callAsFunction(_ text: String) -> [String] { tokenize(text: text) } - func convertTokensToIds(_ tokens: [String]) -> [Int?] { + public func convertTokensToIds(_ tokens: [String]) -> [Int?] { return tokens.map { convertTokenToId($0) } } - func convertIdsToTokens(_ ids: [Int]) -> [String?] { + public func convertIdsToTokens(_ ids: [Int]) -> [String?] 
{ return ids.map { convertIdToToken($0) } } } /// A tokenizer model that is set up with Hub configuration data public protocol PreTrainedTokenizerModel: TokenizingModel { - init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws + init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws } struct TokenizerModel { - static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [ - "BertTokenizer" : BertTokenizer.self, + static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [ + "BertTokenizer": BertTokenizer.self, "DistilbertTokenizer": BertTokenizer.self, "DistilBertTokenizer": BertTokenizer.self, - "CodeGenTokenizer" : CodeGenTokenizer.self, - "CodeLlamaTokenizer" : CodeLlamaTokenizer.self, - "FalconTokenizer" : FalconTokenizer.self, - "GemmaTokenizer" : GemmaTokenizer.self, - "GPT2Tokenizer" : GPT2Tokenizer.self, - "LlamaTokenizer" : LlamaTokenizer.self, - "T5Tokenizer" : T5Tokenizer.self, - "WhisperTokenizer" : WhisperTokenizer.self, - "CohereTokenizer" : CohereTokenizer.self, - "Qwen2Tokenizer" : Qwen2Tokenizer.self, - "PreTrainedTokenizer": BPETokenizer.self + "CodeGenTokenizer": CodeGenTokenizer.self, + "CodeLlamaTokenizer": CodeLlamaTokenizer.self, + "FalconTokenizer": FalconTokenizer.self, + "GemmaTokenizer": GemmaTokenizer.self, + "GPT2Tokenizer": GPT2Tokenizer.self, + "LlamaTokenizer": LlamaTokenizer.self, + "T5Tokenizer": T5Tokenizer.self, + "WhisperTokenizer": WhisperTokenizer.self, + "CohereTokenizer": CohereTokenizer.self, + "Qwen2Tokenizer": Qwen2Tokenizer.self, + "PreTrainedTokenizer": BPETokenizer.self, ] static func unknownToken(from tokenizerConfig: Config) -> String? { return tokenizerConfig.unkToken?.content?.stringValue ?? tokenizerConfig.unkToken?.stringValue } - public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws -> TokenizingModel { + public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws + -> TokenizingModel + { guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else { throw TokenizerError.missingTokenizerClassInConfig } @@ -93,7 +95,8 @@ struct TokenizerModel { throw TokenizerError.unsupportedTokenizer(tokenizerName) } - return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) + return try tokenizerClass.init( + tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) } } @@ -149,20 +152,20 @@ public protocol Tokenizer { ) throws -> [Int] } -public extension Tokenizer { - func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] { +extension Tokenizer { + public func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] { encode(text: text, addSpecialTokens: addSpecialTokens) } - - func decode(tokens: [Int]) -> String { + + public func decode(tokens: [Int]) -> String { decode(tokens: tokens, skipSpecialTokens: false) } - func convertTokensToIds(_ tokens: [String]) -> [Int?] { + public func convertTokensToIds(_ tokens: [String]) -> [Int?] { return tokens.map { convertTokenToId($0) } } - func convertIdsToTokens(_ ids: [Int]) -> [String?] { + public func convertIdsToTokens(_ ids: [Int]) -> [String?] 
{ return ids.map { convertIdToToken($0) } } } @@ -175,7 +178,7 @@ let specialTokenAttributes: [String] = [ "pad_token", "cls_token", "mask_token", - "additional_special_tokens" + "additional_special_tokens", ] public class PreTrainedTokenizer: Tokenizer { @@ -202,11 +205,13 @@ public class PreTrainedTokenizer: Tokenizer { private let cleanUpTokenizationSpaces: Bool required public init(tokenizerConfig: Config, tokenizerData: Config) throws { - var addedTokens: [String : Int] = [:] - var specialTokens: [String : Int] = [:] + var addedTokens: [String: Int] = [:] + var specialTokens: [String: Int] = [:] for addedToken in tokenizerData.addedTokens?.arrayValue ?? [] { guard let id = addedToken.id?.intValue else { continue /* malformed: token with no id */ } - guard let content = addedToken.content?.stringValue else { continue /* malformed: token with no content */ } + guard let content = addedToken.content?.stringValue else { + continue /* malformed: token with no content */ + } addedTokens[content] = id if addedToken.special?.boolValue ?? false { @@ -216,14 +221,15 @@ public class PreTrainedTokenizer: Tokenizer { // Convert to tuples for easier access, then sort by length (descending) to avoid early partial matches // (https://github.com/xenova/transformers.js/commit/c305c3824f628f1f02806a6310bd3b18b0f7f8f5) - let unwrappedAddedTokens : [(content: String, prefix: Bool, suffix: Bool)] = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in - guard let content = addedToken.content?.stringValue else { return nil } - let prefix = addedToken.lstrip?.boolValue ?? false - let suffix = addedToken.rstrip?.boolValue ?? false - return (content: content, prefix: prefix, suffix: suffix) - }.sorted { - $0.content.count > $1.content.count - } + let unwrappedAddedTokens: [(content: String, prefix: Bool, suffix: Bool)] = + (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in + guard let content = addedToken.content?.stringValue else { return nil } + let prefix = addedToken.lstrip?.boolValue ?? false + let suffix = addedToken.rstrip?.boolValue ?? false + return (content: content, prefix: prefix, suffix: suffix) + }.sorted { + $0.content.count > $1.content.count + } // then concatenate into regular expression let addedTokensRegexString = unwrappedAddedTokens.map { @@ -245,7 +251,8 @@ public class PreTrainedTokenizer: Tokenizer { self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? 
true self.tokenizerConfig = tokenizerConfig - model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) + model = try TokenizerModel.from( + tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) } func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] { @@ -272,7 +279,8 @@ public class PreTrainedTokenizer: Tokenizer { func cleanUp(text: String) -> String { guard cleanUpTokenizationSpaces else { return text } - return text + return + text .replacingOccurrences(of: " .", with: ".") .replacingOccurrences(of: " ?", with: "?") .replacingOccurrences(of: " !", with: "!") @@ -328,7 +336,8 @@ public class PreTrainedTokenizer: Tokenizer { let tokenStrings: [String] if skipSpecialTokens { let specialTokenIDs = Set(specialTokens.values) - tokenStrings = tokens + tokenStrings = + tokens .filter { !specialTokenIDs.contains($0) } .compactMap { model.convertIdToToken($0) } } else { @@ -380,18 +389,20 @@ public class PreTrainedTokenizer: Tokenizer { } else if let valueFromConfig = tokenizerConfig.chatTemplate { if let arrayValue = valueFromConfig.arrayValue { // If the config specifies a list of chat templates, convert them to a dictionary - let templateDict = Dictionary(uniqueKeysWithValues: arrayValue.compactMap { item in - guard let name = item.name?.stringValue, let template = item.template?.stringValue else { - return nil - } - return (name, template) - }) + let templateDict = [String: String]( + uniqueKeysWithValues: arrayValue.compactMap { item in + guard let name = item.name?.stringValue, let template = item.template?.stringValue else { + return nil + } + return (name, template) + }) if let chatTemplate, case .name(let name) = chatTemplate { // Select chat template from config by name if let matchingDictEntry = templateDict[name] { selectedChatTemplate = matchingDictEntry } else { - throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config") + throw TokenizerError.chatTemplate( + "No chat template named \"\(name)\" was found in the tokenizer config") } } else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { // Use tool use chat template from config @@ -413,13 +424,13 @@ public class PreTrainedTokenizer: Tokenizer { let template = try Template(selectedChatTemplate) var context: [String: Any] = [ "messages": messages, - "add_generation_prompt": addGenerationPrompt - // TODO: Add `tools` entry when support is added in Jinja - // "tools": tools + "add_generation_prompt": addGenerationPrompt, + // TODO: Add `tools` entry when support is added in Jinja + // "tools": tools ] // TODO: maybe keep NSString here - for (key, value) in tokenizerConfig.dictionary as [String : Any] { + for (key, value) in tokenizerConfig.dictionary as [String: Any] { if specialTokenAttributes.contains(key), !(value is NSNull) { context[key] = value } @@ -446,7 +457,7 @@ public struct AutoTokenizer {} struct PreTrainedTokenizerClasses { /// Class overrides for custom behaviour /// Not to be confused with the TokenizerModel classes defined in TokenizerModel - static let tokenizerClasses: [String : PreTrainedTokenizer.Type] = [ + static let tokenizerClasses: [String: PreTrainedTokenizer.Type] = [ "LlamaTokenizer": LlamaPreTrainedTokenizer.self ] } @@ -496,18 +507,17 @@ extension AutoTokenizer { // MARK: - Tokenizer model classes -class GPT2Tokenizer : BPETokenizer {} -class FalconTokenizer : BPETokenizer {} -class LlamaTokenizer : 
BPETokenizer {} -class CodeGenTokenizer : BPETokenizer {} -class WhisperTokenizer : BPETokenizer {} -class GemmaTokenizer : BPETokenizer {} +class GPT2Tokenizer: BPETokenizer {} +class FalconTokenizer: BPETokenizer {} +class LlamaTokenizer: BPETokenizer {} +class CodeGenTokenizer: BPETokenizer {} +class WhisperTokenizer: BPETokenizer {} +class GemmaTokenizer: BPETokenizer {} class CodeLlamaTokenizer: BPETokenizer {} -class CohereTokenizer : BPETokenizer {} -class Qwen2Tokenizer : BPETokenizer {} - -class T5Tokenizer : UnigramTokenizer {} +class CohereTokenizer: BPETokenizer {} +class Qwen2Tokenizer: BPETokenizer {} +class T5Tokenizer: UnigramTokenizer {} // MARK: - PreTrainedTokenizer classes @@ -522,7 +532,10 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer { var configDictionary = tokenizerData.dictionary if !isLegacy { configDictionary.removeValue(forKey: "normalizer") - configDictionary["pre_tokenizer"] = ["type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, "prepend_scheme": "first"] + configDictionary["pre_tokenizer"] = [ + "type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, + "prepend_scheme": "first", + ] } let updatedData = Config(configDictionary) diff --git a/Sources/Tokenizers/Trie.swift b/Sources/Tokenizers/Trie.swift index 6c7f79c..9d95904 100644 --- a/Sources/Tokenizers/Trie.swift +++ b/Sources/Tokenizers/Trie.swift @@ -10,16 +10,16 @@ import Foundation public struct Trie { public typealias Node = TrieNode - + var root: Node - + public init(root: Node? = nil) { self.root = root ?? Node() } } -public extension Trie { - func insert(_ element: any Sequence) { +extension Trie { + public func insert(_ element: any Sequence) { var node = root for item in element { if let child = node.children[item] { @@ -32,14 +32,14 @@ public extension Trie { } node.isLeaf = true } - - func append(contentsOf container: any Sequence>) { + + public func append(contentsOf container: any Sequence>) { for t in container { insert(t) } } - + /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) /// Returns an array - func commonPrefixSearch(_ text: any Sequence) -> [[T]] { + public func commonPrefixSearch(_ text: any Sequence) -> [[T]] { var node = root var seqs: [[T]] = [] var seq: [T] = [] @@ -53,17 +53,17 @@ public extension Trie { } return seqs } - + /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) /// Returns an iterator - func commonPrefixSearchIterator(_ text: any Sequence) -> LeavesWithCommonPrefixIterator { + public func commonPrefixSearchIterator(_ text: any Sequence) -> LeavesWithCommonPrefixIterator { return LeavesWithCommonPrefixIterator(node: root, text: text) } } -public extension Trie { +extension Trie { // Only used for testing, could migrate to collection - func get(_ element: any Sequence) -> Node? { + public func get(_ element: any Sequence) -> Node? { var node = root for item in element { guard let child = node.children[item] else { return nil } @@ -79,12 +79,12 @@ public class TrieNode { var children: [T: TrieNode] = [:] } -public struct LeavesWithCommonPrefixIterator : Sequence, IteratorProtocol { +public struct LeavesWithCommonPrefixIterator: Sequence, IteratorProtocol { var node: TrieNode var text: any Sequence var seq: [T] = [] lazy var iterator = text.makeIterator() as any IteratorProtocol - + public mutating func next() -> [T]? 
{ while true { guard let item = iterator.next() else { return nil } diff --git a/Sources/Tokenizers/UnigramTokenizer.swift b/Sources/Tokenizers/UnigramTokenizer.swift index 58a7217..801fb88 100644 --- a/Sources/Tokenizers/UnigramTokenizer.swift +++ b/Sources/Tokenizers/UnigramTokenizer.swift @@ -15,13 +15,13 @@ class UnigramTokenizer: PreTrainedTokenizerModel { var score: Float } let vocab: [SentencePieceToken] - + let unknownPiece: SentencePieceToken var unknownTokenScore: Float { unknownPiece.score } - + public let unknownTokenId: Int? public var unknownToken: String? { unknownPiece.token } - + let minScore: Float let tokensToIds: [NSString: Int] @@ -34,15 +34,16 @@ class UnigramTokenizer: PreTrainedTokenizerModel { let fuseUnknownTokens: Bool = true private let trie: Trie - - required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { + + required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { guard let configVocab = tokenizerData.model?.vocab?.value as? [[Any]] else { throw TokenizerError.missingVocab } - + vocab = try configVocab.map { piece in guard let token = piece.first as? String, - let scoreValue = piece.last else { + let scoreValue = piece.last + else { throw TokenizerError.malformedVocab } @@ -54,20 +55,20 @@ class UnigramTokenizer: PreTrainedTokenizerModel { } else { throw TokenizerError.malformedVocab } - + return SentencePieceToken(token: token, score: score) } - + minScore = vocab.reduce(999) { partial, token in min(partial, token.score) } - + guard let unknownTokenId = tokenizerData.model?.unkId?.intValue else { throw TokenizerError.malformedVocab } self.unknownTokenId = unknownTokenId self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10) - + tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token as NSString }.enumerated().map { ($1, $0) }) - bosTokenId = tokensToIds[bosToken! as NSString] // May be nil + bosTokenId = tokensToIds[bosToken! as NSString] // May be nil eosToken = tokenizerConfig.eosToken?.stringValue eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString] @@ -79,21 +80,21 @@ class UnigramTokenizer: PreTrainedTokenizerModel { func convertTokenToId(_ token: String) -> Int? { return tokensToIds[token as NSString] ?? self.unknownTokenId } - + func convertIdToToken(_ id: Int) -> String? { return vocab[id].token } - + func tokenize(text: String) -> [String] { var lattice = TokenLattice(sentence: text, bosTokenId: bosTokenId ?? 0, eosTokenId: eosTokenId ?? 0) - + // Populate nodes let sentence = lattice.sentence var beginPos = 0 while beginPos < sentence.count { let mblen = 1 var hasSingleNode = false - + let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos) for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) { guard let tokenId = tokensToIds[token as NSString] else { fatalError("Token not in vocab: \(token)") } @@ -104,7 +105,8 @@ class UnigramTokenizer: PreTrainedTokenizerModel { } } if !hasSingleNode { - lattice.insert(startOffset: beginPos, length: mblen, score: unknownTokenScore, tokenId: unknownTokenId ?? 0) + lattice.insert( + startOffset: beginPos, length: mblen, score: unknownTokenScore, tokenId: unknownTokenId ?? 
0) } beginPos += mblen } diff --git a/Sources/Tokenizers/Utils.swift b/Sources/Tokenizers/Utils.swift index 9efacc2..bc43f59 100644 --- a/Sources/Tokenizers/Utils.swift +++ b/Sources/Tokenizers/Utils.swift @@ -17,7 +17,7 @@ struct Utils { print("[\(label)] \(diff)ms") return result } - + /// Time a block in seconds and return (output, time) static func time(_ block: () -> T) -> (T, Double) { let startTime = CFAbsoluteTimeGetCurrent() @@ -25,24 +25,24 @@ struct Utils { let diff = CFAbsoluteTimeGetCurrent() - startTime return (result, diff) } - + /// Return unix timestamp in ms static func dateNow() -> Int64 { // Use `Int` when we don't support 32-bits devices/OSes anymore. // Int crashes on iPhone 5c. return Int64(Date().timeIntervalSince1970 * 1000) } - + /// Clamp a val to [min, max] static func clamp(_ val: T, _ vmin: T, _ vmax: T) -> T { return min(max(vmin, val), vmax) } - + /// Fake func that can throw. static func fakeThrowable(_ input: T) throws -> T { return input } - + /// Substring static func substr(_ s: String, _ r: Range) -> String? { let stringCount = s.count @@ -51,11 +51,11 @@ struct Utils { } let startIndex = s.index(s.startIndex, offsetBy: r.lowerBound) let endIndex = s.index(startIndex, offsetBy: r.upperBound - r.lowerBound) - return String(s[startIndex..(_ dict: Dictionary) -> Dictionary { + static func invert(_ dict: [K: V]) -> [V: K] { var inverted: [V: K] = [:] for (k, v) in dict { inverted[v] = k @@ -66,18 +66,13 @@ struct Utils { /// Checks if a character is considered Chinese /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) static func isChineseChar(_ c: UnicodeScalar) -> Bool { - (c.value >= 0x4E00 && c.value <= 0x9FFF) || - (c.value >= 0x3400 && c.value <= 0x4DBF) || - (c.value >= 0x20000 && c.value <= 0x2A6DF) || - (c.value >= 0x2A700 && c.value <= 0x2B73F) || - (c.value >= 0x2B740 && c.value <= 0x2B81F) || - (c.value >= 0x2B820 && c.value <= 0x2CEAF) || - (c.value >= 0xF900 && c.value <= 0xFAFF) || - (c.value >= 0x2F800 && c.value <= 0x2FA1F) + (c.value >= 0x4E00 && c.value <= 0x9FFF) || (c.value >= 0x3400 && c.value <= 0x4DBF) + || (c.value >= 0x20000 && c.value <= 0x2A6DF) || (c.value >= 0x2A700 && c.value <= 0x2B73F) + || (c.value >= 0x2B740 && c.value <= 0x2B81F) || (c.value >= 0x2B820 && c.value <= 0x2CEAF) + || (c.value >= 0xF900 && c.value <= 0xFAFF) || (c.value >= 0x2F800 && c.value <= 0x2FA1F) } } enum Constants { static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"# } - diff --git a/Sources/TransformersCLI/main.swift b/Sources/TransformersCLI/main.swift index a040946..7a50207 100644 --- a/Sources/TransformersCLI/main.swift +++ b/Sources/TransformersCLI/main.swift @@ -1,9 +1,8 @@ import ArgumentParser import CoreML import Foundation - -import Models import Generation +import Models @available(iOS 16.2, macOS 13.1, *) struct TransformersCLI: ParsableCommand { @@ -23,7 +22,7 @@ struct TransformersCLI: ParsableCommand { @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") var computeUnits: ComputeUnits = .cpuAndGPU - + func generate(model: LanguageModel, config: GenerationConfig, prompt: String, printOutput: Bool = true) { let semaphore = DispatchSemaphore(value: 0) Task.init { [config] in @@ -56,11 +55,11 @@ struct TransformersCLI: ParsableCommand { func compile(at url: URL) throws -> URL { #if os(watchOS) - fatalError("Model compilation is not supported on watchOS") + fatalError("Model compilation is not supported on watchOS") #else - if 
url.pathExtension == "mlmodelc" { return url } - print("Compiling model \(url)") - return try MLModel.compileModel(at: url) + if url.pathExtension == "mlmodelc" { return url } + print("Compiling model \(url)") + return try MLModel.compileModel(at: url) #endif } @@ -69,15 +68,15 @@ struct TransformersCLI: ParsableCommand { let compiledURL = try compile(at: url) print("Loading model \(compiledURL)") let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits) - + // Using greedy generation for now var config = model.defaultGenerationConfig config.doSample = false config.maxNewTokens = maxLength - + print("Warming up...") generate(model: model, config: config, prompt: prompt, printOutput: false) - + print("Generating") generate(model: model, config: config, prompt: prompt) } diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 9871ba6..319dbf7 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -4,9 +4,10 @@ // Created by Pedro Cuenca on 20231230. // -@testable import Hub import XCTest +@testable import Hub + class HubApiTests: XCTestCase { override func setUp() { // Put setup code here. This method is called before the invocation of each test method in the class. @@ -30,7 +31,8 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalWithGlob() async { do { try await { - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.json") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.json") XCTAssertEqual( Set(filenames), Set([ @@ -44,7 +46,8 @@ class HubApiTests: XCTestCase { // Glob patterns are case sensitive try await { - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.JSON") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.JSON") XCTAssertEqual( filenames, [] @@ -58,7 +61,8 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalFromDirectories() async { do { // Contents of all directories matching a pattern - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.mlpackage/*") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.mlpackage/*") XCTAssertEqual( Set(filenames), Set([ @@ -78,7 +82,8 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalWithMultiplePatterns() async { do { let patterns = ["config.json", "tokenizer.json", "tokenizer_*.json"] - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: patterns) + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", matching: patterns) XCTAssertEqual( Set(filenames), Set(["config.json", "tokenizer.json", "tokenizer_config.json"]) @@ -87,12 +92,13 @@ class HubApiTests: XCTestCase { XCTFail("\(error)") } } - + func testGetFileMetadata() async throws { do { - let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json") + let url = URL( + string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json") let metadata = try await Hub.getFileMetadata(fileURL: url!) 
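For reference, the Hub metadata API exercised by these tests can be used stand-alone. A minimal sketch follows, assuming only the `Hub.getFileMetadata(fileURL:)` entry point and the `commitHash`/`etag` fields that the assertions below rely on; the URL mirrors the one used in the surrounding test and everything else is illustrative:

    import Foundation
    import Hub

    // Minimal sketch: resolve metadata for a single file hosted on the Hub.
    // Run inside an async context. `commitHash` and `etag` are the optional
    // fields the tests assert on; other fields are not assumed here.
    let configURL = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
    let meta = try await Hub.getFileMetadata(fileURL: configURL)
    print("commit: \(meta.commitHash ?? "unknown"), etag: \(meta.etag ?? "unknown")")
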
- + XCTAssertNotNil(metadata.commitHash) XCTAssertNotNil(metadata.etag) XCTAssertEqual(metadata.location, url?.absoluteString) @@ -101,12 +107,13 @@ class HubApiTests: XCTestCase { XCTFail("\(error)") } } - + func testGetFileMetadataBlobPath() async throws { do { - let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json") + let url = URL( + string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json") let metadata = try await Hub.getFileMetadata(fileURL: url!) - + XCTAssertNotNil(metadata.commitHash) XCTAssertTrue(metadata.etag != nil && metadata.etag!.hasPrefix("d6ceb9")) XCTAssertEqual(metadata.location, url?.absoluteString) @@ -115,13 +122,13 @@ class HubApiTests: XCTestCase { XCTFail("\(error)") } } - + func testGetFileMetadataWithRevision() async throws { do { let revision = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2" let url = URL(string: "https://huggingface.co/julien-c/dummy-unknown/resolve/\(revision)/config.json") let metadata = try await Hub.getFileMetadata(fileURL: url!) - + XCTAssertEqual(metadata.commitHash, revision) XCTAssertNotNil(metadata.etag) XCTAssertGreaterThan(metadata.etag!.count, 0) @@ -134,7 +141,9 @@ class HubApiTests: XCTestCase { func testGetFileMetadataWithBlobSearch() async throws { let repo = "coreml-projects/Llama-2-7b-chat-coreml" - let metadataFromBlob = try await Hub.getFileMetadata(from: repo, matching: "*.json").sorted { $0.location < $1.location } + let metadataFromBlob = try await Hub.getFileMetadata(from: repo, matching: "*.json").sorted { + $0.location < $1.location + } let files = try await Hub.getFilenames(from: repo, matching: "*.json").sorted() for (metadata, file) in zip(metadataFromBlob, files) { XCTAssertNotNil(metadata.commitHash) @@ -167,7 +176,9 @@ class SnapshotDownloadTests: XCTestCase { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") - if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) { + if let enumerator = FileManager.default.enumerator( + at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) + { for case let fileURL as URL in enumerator { do { let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey]) @@ -211,7 +222,9 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadInBackground() async throws { let hubApi = HubApi(downloadBase: downloadDestination, useBackgroundSession: true) var lastProgress: Progress? 
= nil - let downloadedTo = try await hubApi.snapshot(from: repo, matching: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json") { progress in + let downloadedTo = try await hubApi.snapshot( + from: repo, matching: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" + ) { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") lastProgress = progress @@ -224,7 +237,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedFilenames), Set([ - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" ]) ) } diff --git a/Tests/HubTests/HubTests.swift b/Tests/HubTests/HubTests.swift index 1d7bc86..8eb055c 100644 --- a/Tests/HubTests/HubTests.swift +++ b/Tests/HubTests/HubTests.swift @@ -5,8 +5,8 @@ // import XCTest -@testable import Hub +@testable import Hub class HubTests: XCTestCase { let downloadDestination: URL = { @@ -30,28 +30,28 @@ class HubTests: XCTestCase { do { let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi) let config = try await configLoader.modelConfig - + // Test leaf value (Int) guard let eos = config.eos_token_id?.intValue else { XCTFail("nil leaf value (Int)") return } XCTAssertEqual(eos, 1) - + // Test leaf value (String) guard let modelType = config.model_type?.stringValue else { XCTFail("nil leaf value (String)") return } XCTAssertEqual(modelType, "t5") - + // Test leaf value (Array) guard let architectures = config.architectures?.value as? [String] else { XCTFail("nil array") return } XCTAssertEqual(architectures, ["T5ForConditionalGeneration"]) - + // Test nested wrapper guard let taskParams = config.task_specific_params else { XCTFail("nil nested wrapper") @@ -68,7 +68,7 @@ class HubTests: XCTestCase { XCTFail("Cannot download test configuration from the Hub: \(error)") } } - + func testConfigCamelCase() async { do { let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi) @@ -80,14 +80,14 @@ class HubTests: XCTestCase { return } XCTAssertEqual(eos, 1) - + // Test leaf value (String) guard let modelType = config.modelType?.stringValue else { XCTFail("nil leaf value (String)") return } XCTAssertEqual(modelType, "t5") - + guard let summarizationMaxLength = config.taskSpecificParams?.summarization?.maxLength?.intValue else { XCTFail("cannot traverse nested containers") return diff --git a/Tests/NormalizerTests/NormalizerTests.swift b/Tests/NormalizerTests/NormalizerTests.swift index fea423a..997a167 100644 --- a/Tests/NormalizerTests/NormalizerTests.swift +++ b/Tests/NormalizerTests/NormalizerTests.swift @@ -122,11 +122,11 @@ class NormalizerTests: XCTestCase { func testStripAccents() { let testCases: [(String, String)] = [ - ("département", "departement"), + ("département", "departement") ] //TODO: test combinations with/without lowercase - let config = Config(["stripAccents":true]) + let config = Config(["stripAccents": true]) let normalizer = BertNormalizer(config: config) for (arg, expect) in testCases { XCTAssertEqual(normalizer.normalize(text: arg), expect) @@ -147,7 +147,7 @@ class NormalizerTests: XCTestCase { ] for (arg, expect) in testCases { - let config = Config(["stripAccents":false]) + let config = Config(["stripAccents": false]) let normalizer = BertNormalizer(config: config) XCTAssertEqual(normalizer.normalize(text: arg), expect) } diff --git 
a/Tests/PostProcessorTests/PostProcessorTests.swift b/Tests/PostProcessorTests/PostProcessorTests.swift index 347bc38..6283c0c 100644 --- a/Tests/PostProcessorTests/PostProcessorTests.swift +++ b/Tests/PostProcessorTests/PostProcessorTests.swift @@ -1,73 +1,82 @@ import XCTest -@testable import Tokenizers + @testable import Hub +@testable import Tokenizers class PostProcessorTests: XCTestCase { func testRobertaProcessing() { - let testCases: [(Config, [String], [String]?, [String])] = [ + let testCases: [(Config, [String], [String]?, [String])] = [ // Should keep spaces; uneven spaces; ignore `addPrefixSpace`. ( - Config(["cls": (0, "[HEAD]") as (UInt, String), - "sep": (0, "[END]") as (UInt, String), - "trimOffset": false, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[HEAD]") as (UInt, String), + "sep": (0, "[END]") as (UInt, String), + "trimOffset": false, + "addPrefixSpace": true, + ]), [" The", " sun", "sets ", " in ", " the ", "west"], nil, ["[HEAD]", " The", " sun", "sets ", " in ", " the ", "west", "[END]"] ), // Should leave only one space around each token. ( - Config(["cls": (0, "[START]") as (UInt, String), - "sep": (0, "[BREAK]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[START]") as (UInt, String), + "sep": (0, "[BREAK]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] ), // Should ignore empty tokens pair. ( - Config(["cls": (0, "[START]") as (UInt, String), - "sep": (0, "[BREAK]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[START]") as (UInt, String), + "sep": (0, "[BREAK]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], [], ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] ), // Should trim all whitespace. ( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": false, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": false, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[CLS]", "The", "sun", "sets", "in", "the", "west", "[SEP]"] ), // Should add tokens. 
( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], [".", "The", " cat ", " is ", " sitting ", " on", "the ", "mat"], - ["[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]", - "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", - "mat", "[SEP]"] + [ + "[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]", + "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", + "mat", "[SEP]", + ] ), ( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" 你 ", " 好 ", ","], [" 凯 ", " 蒂 ", "!"], ["[CLS]", " 你 ", " 好 ", ",", "[SEP]", "[SEP]", " 凯 ", " 蒂 ", "!", "[SEP]"] diff --git a/Tests/PreTokenizerTests/PreTokenizerTests.swift b/Tests/PreTokenizerTests/PreTokenizerTests.swift index 34e29c5..688a003 100644 --- a/Tests/PreTokenizerTests/PreTokenizerTests.swift +++ b/Tests/PreTokenizerTests/PreTokenizerTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 23/11/2023. // -import XCTest import Hub +import XCTest + @testable import Tokenizers class PreTokenizerTests: XCTestCase { @@ -136,7 +137,13 @@ class PreTokenizerTests: XCTestCase { [" ", " ", " ", "Hey,", " ", " ", " ", " ", "friend,", " ", " ", " ", " ", "what's", " ", "up?", " ", " "] ) - let preTokenizer3 = SplitPreTokenizer(config: Config(["pattern": ["Regex": "(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"], "invert": true])) + let preTokenizer3 = SplitPreTokenizer( + config: Config([ + "pattern": [ + "Regex": + "(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + ], "invert": true, + ])) XCTAssertEqual( preTokenizer3.preTokenize(text: "Hello"), ["Hello"] @@ -151,19 +158,21 @@ class PreTokenizerTests: XCTestCase { ["Hey", " friend", "!", " ", " How", " are", " you", "?!?"] ) } - + // https://github.com/huggingface/tokenizers/pull/1357 func testMetaspacePreTokenizer() { // Prepend "always" - let preTokenizer = MetaspacePreTokenizer(config: Config([ - "add_prefix_space": true, - "replacement": "▁", - "prepend_scheme": "always" - ])) - + let preTokenizer = MetaspacePreTokenizer( + config: Config([ + "add_prefix_space": true, + "replacement": "▁", + "prepend_scheme": "always", + ])) + // TODO: different sections on let text = "Hey my friend how▁are you" - let tokens = text + let tokens = + text .split(by: "", includeSeparators: true) .flatMap { preTokenizer.preTokenize(text: $0) } diff --git a/Tests/TensorUtilsTests/LogitsWarperTests.swift b/Tests/TensorUtilsTests/LogitsWarperTests.swift index 0260967..395e58b 100644 --- a/Tests/TensorUtilsTests/LogitsWarperTests.swift +++ b/Tests/TensorUtilsTests/LogitsWarperTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 09/12/2023. 
// -import XCTest import CoreML +import XCTest + @testable import TensorUtils final class LogitsWarperTests: XCTestCase { @@ -88,7 +89,7 @@ final class LogitsWarperTests: XCTestCase { } func testRepetitionPenaltyWarper() { - let indices = Array(0..<10) + let indices = Array(0 ..< 10) let logits = indices.map({ Float($0) }) let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) @@ -99,19 +100,20 @@ final class LogitsWarperTests: XCTestCase { XCTAssertEqual(result2.indices, indices) let logits2 = indices.map({ Float($0) / 3.75 }) XCTAssertEqual(result2.logits, logits2, accuracy: accuracy) - + let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) XCTAssertEqual(result3.indices, [0, 1, 2]) XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4) - - let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) + + let result4 = RepetitionPenaltyWarper(penalty: 1.11)( + [2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) XCTAssertEqual(result4.indices, [2, 3, 4]) XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4) let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966]) XCTAssertEqual(result5.indices, [0, 1, 2]) XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4) - + let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755]) XCTAssertEqual(result6.indices, [3, 1, 2]) XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4) diff --git a/Tests/TensorUtilsTests/TensorUtilsTests.swift b/Tests/TensorUtilsTests/TensorUtilsTests.swift index 6355165..6071a76 100644 --- a/Tests/TensorUtilsTests/TensorUtilsTests.swift +++ b/Tests/TensorUtilsTests/TensorUtilsTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 25/11/2023. 
// -import XCTest import CoreML +import XCTest + @testable import TensorUtils final class TensorUtilsTests: XCTestCase { @@ -44,8 +45,8 @@ final class TensorUtilsTests: XCTestCase { } func testSoftmax() { - XCTAssertEqual(Math.softmax([]), []) - + XCTAssertEqual(Math.softmax([]), []) + let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0]) XCTAssertEqual(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy) XCTAssertEqual(result1.reduce(0, +), 1.0, accuracy: accuracy) diff --git a/Tests/TensorUtilsTests/WeightsTests.swift b/Tests/TensorUtilsTests/WeightsTests.swift index 5d2e478..2b0e37a 100644 --- a/Tests/TensorUtilsTests/WeightsTests.swift +++ b/Tests/TensorUtilsTests/WeightsTests.swift @@ -1,11 +1,13 @@ -@testable import TensorUtils -@testable import Hub import XCTest +@testable import Hub +@testable import TensorUtils + class WeightsTests: XCTestCase { let downloadDestination: URL = { - FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests") + FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending( + component: "huggingface-tests") }() var hubApi: HubApi { HubApi(downloadBase: downloadDestination) } @@ -14,7 +16,8 @@ class WeightsTests: XCTestCase { let repo = "google/bert_uncased_L-2_H-128_A-2" let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"]) - let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey]) + let files = try FileManager.default.contentsOfDirectory( + at: modelDir, includingPropertiesForKeys: [.isReadableKey]) XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" })) XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" })) @@ -25,7 +28,7 @@ class WeightsTests: XCTestCase { XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.shape.count, 1) XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.dataType, .float32) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.count, 3906816) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.count, 3_906_816) XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.shape.count, 2) XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[0, 0]].floatValue, -0.0041, accuracy: 1e-3) diff --git a/Tests/TokenizersTests/AddedTokensTests.swift b/Tests/TokenizersTests/AddedTokensTests.swift index c82e45f..785b28e 100644 --- a/Tests/TokenizersTests/AddedTokensTests.swift +++ b/Tests/TokenizersTests/AddedTokensTests.swift @@ -5,9 +5,9 @@ // Created by Pedro Cuenca on 20240426. // -import XCTest -import Tokenizers import Hub +import Tokenizers +import XCTest class AddedTokensTests: XCTestCase { func testPhiAddedTokens() async throws { diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift index d30ae99..59d1363 100644 --- a/Tests/TokenizersTests/BertTokenizerTests.swift +++ b/Tests/TokenizersTests/BertTokenizerTests.swift @@ -7,9 +7,8 @@ // import XCTest -@testable import Tokenizers - +@testable import Tokenizers class BertTokenizerTests: XCTestCase { @@ -20,7 +19,7 @@ class BertTokenizerTests: XCTestCase { override func tearDown() { // Put teardown code here. This method is called after the invocation of each test method in the class. 
} - + lazy var bertTokenizer: BertTokenizer = { let vocab = { let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")! @@ -32,93 +31,93 @@ class BertTokenizerTests: XCTestCase { } return vocab }() - + return BertTokenizer(vocab: vocab, merges: nil) }() func testBasicTokenizer() { let basicTokenizer = BasicTokenizer() - + let text = "Brave gaillard, d'où [UNK] êtes vous?" let tokens = ["brave", "gaillard", ",", "d", "\'", "ou", "[UNK]", "etes", "vous", "?"] - + XCTAssertEqual( basicTokenizer.tokenize(text: text), tokens ) /// Verify that `XCTAssertEqual` does what deep equality checks on arrays of strings. XCTAssertEqual(["foo", "bar"], ["foo", "bar"]) } - + /// For each Squad question tokenized by python, check that we get the same output through the `BasicTokenizer` func testFullBasicTokenizer() { let url = Bundle.module.url(forResource: "basic_tokenized_questions", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let sampleTokens = try! decoder.decode([[String]].self, from: json) - + let basicTokenizer = BasicTokenizer() - + XCTAssertEqual(sampleTokens.count, Squad.examples.count) - + for (i, example) in Squad.examples.enumerated() { let output = basicTokenizer.tokenize(text: example.question) XCTAssertEqual(output, sampleTokens[i]) } } - + /// For each Squad question tokenized by python, check that we get the same output through the whole `BertTokenizer` func testFullBertTokenizer() { let url = Bundle.module.url(forResource: "tokenized_questions", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let sampleTokens = try! decoder.decode([[Int]].self, from: json) - + let tokenizer = bertTokenizer - + XCTAssertEqual(sampleTokens.count, Squad.examples.count) - + for (i, example) in Squad.examples.enumerated() { let output = tokenizer.tokenizeToIds(text: example.question) XCTAssertEqual(output, sampleTokens[i]) } } - + func testMixedChineseEnglishTokenization() { let tokenizer = bertTokenizer let text = "你好,世界!Hello, world!" 
let expectedTokens = ["[UNK]", "[UNK]", ",", "世", "[UNK]", "!", "hello", ",", "world", "!"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testPureChineseTokenization() { let tokenizer = bertTokenizer let text = "明日,大家上山看日出。" - let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出","。"] + let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出", "。"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testChineseWithNumeralsTokenization() { let tokenizer = bertTokenizer let text = "2020年奥运会在东京举行。" let expectedTokens = ["2020", "年", "[UNK]", "[UNK]", "会", "[UNK]", "[UNK]", "京", "[UNK]", "行", "。"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testChineseWithSpecialTokens() { let tokenizer = bertTokenizer let text = "[CLS] 机器学习是未来。 [SEP]" let expectedTokens = ["[CLS]", "[UNK]", "[UNK]", "学", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "。", "[SEP]"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testPerformanceExample() { let tokenizer = bertTokenizer @@ -128,49 +127,50 @@ class BertTokenizerTests: XCTestCase { _ = tokenizer.tokenizeToIds(text: "Brave gaillard, d'où [UNK] êtes vous?") } } - + func testWordpieceDetokenizer() { struct QuestionTokens: Codable { let original: String let basic: [String] let wordpiece: [String] } - + let url = Bundle.module.url(forResource: "question_tokens", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let questionTokens = try! decoder.decode([QuestionTokens].self, from: json) let tokenizer = bertTokenizer - + for question in questionTokens { - XCTAssertEqual(question.basic.joined(separator: " "), tokenizer.convertWordpieceToBasicTokenList(question.wordpiece)) + XCTAssertEqual( + question.basic.joined(separator: " "), tokenizer.convertWordpieceToBasicTokenList(question.wordpiece)) } } - + func testEncoderDecoder() { let text = """ - Wake up (Wake up) - Grab a brush and put a little makeup - Hide your scars to fade away the shakeup (Hide the scars to fade away the shakeup) - Why'd you leave the keys upon the table? - Here you go, create another fable, you wanted to - Grab a brush and put a little makeup, you wanted to - Hide the scars to fade away the shakeup, you wanted to - Why'd you leave the keys upon the table? You wanted to - """ - + Wake up (Wake up) + Grab a brush and put a little makeup + Hide your scars to fade away the shakeup (Hide the scars to fade away the shakeup) + Why'd you leave the keys upon the table? + Here you go, create another fable, you wanted to + Grab a brush and put a little makeup, you wanted to + Hide the scars to fade away the shakeup, you wanted to + Why'd you leave the keys upon the table? You wanted to + """ + // Not sure if there's a way to achieve a non-destructive round-trip let decoded = """ - wake up ( wake up ) - grab a brush and put a little makeup - hide your scars to fade away the shakeup ( hide the scars to fade away the shakeup ) - why \' d you leave the keys upon the table ? - here you go , create another fable , you wanted to - grab a brush and put a little makeup , you wanted to - hide the scars to fade away the shakeup , you wanted to - why \' d you leave the keys upon the table ? 
you wanted to - """ - + wake up ( wake up ) + grab a brush and put a little makeup + hide your scars to fade away the shakeup ( hide the scars to fade away the shakeup ) + why \' d you leave the keys upon the table ? + here you go , create another fable , you wanted to + grab a brush and put a little makeup , you wanted to + hide the scars to fade away the shakeup , you wanted to + why \' d you leave the keys upon the table ? you wanted to + """ + let tokenizer = bertTokenizer for (line, expected) in zip(text.split(separator: "\n"), decoded.split(separator: "\n")) { let encoded = tokenizer.encode(text: String(line)) diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index 3ee7aa1..22819f2 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -5,14 +5,16 @@ // Created by Anthony DePasquale on 2/10/24. // -import XCTest import Tokenizers +import XCTest class ChatTemplateTests: XCTestCase { - let messages = [[ - "role": "user", - "content": "Describe the Swift programming language.", - ]] + let messages = [ + [ + "role": "user", + "content": "Describe the Swift programming language.", + ] + ] func testTemplateFromConfig() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") @@ -37,8 +39,10 @@ class ChatTemplateTests: XCTestCase { func testTemplateFromArgumentWithEnum() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") // Purposely not using the correct template for this model to verify that the template from the config is not being used - let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" - let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .literal(mistral7BDefaultTemplate)) + let mistral7BDefaultTemplate = + "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let encoded = try tokenizer.applyChatTemplate( + messages: messages, chatTemplate: .literal(mistral7BDefaultTemplate)) let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] let decoded = tokenizer.decode(tokens: encoded) let decodedTarget = " [INST] Describe the Swift programming language. 
[/INST]" @@ -49,7 +53,8 @@ class ChatTemplateTests: XCTestCase { func testTemplateFromArgumentWithString() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") // Purposely not using the correct template for this model to verify that the template from the config is not being used - let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let mistral7BDefaultTemplate = + "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] let decoded = tokenizer.decode(tokens: encoded) diff --git a/Tests/TokenizersTests/DecoderTests.swift b/Tests/TokenizersTests/DecoderTests.swift index b2ecf33..b971070 100644 --- a/Tests/TokenizersTests/DecoderTests.swift +++ b/Tests/TokenizersTests/DecoderTests.swift @@ -4,18 +4,20 @@ // Created by Pedro Cuenca on 20231123. // -import XCTest import Hub +import XCTest + @testable import Tokenizers class DecoderTests: XCTestCase { // https://github.com/huggingface/tokenizers/pull/1357 func testMetaspaceDecoder() { - let decoder = MetaspaceDecoder(config: Config([ - "add_prefix_space": true, - "replacement": "▁", - ])) - + let decoder = MetaspaceDecoder( + config: Config([ + "add_prefix_space": true, + "replacement": "▁", + ])) + let tokens = ["▁Hey", "▁my", "▁friend", "▁", "▁", "▁how", "▁are", "▁you"] let decoded = decoder.decode(tokens: tokens) @@ -24,7 +26,7 @@ class DecoderTests: XCTestCase { ["Hey", " my", " friend", " ", " ", " how", " are", " you"] ) } - + func testWordPieceDecoder() { let config = Config(["prefix": "##", "cleanup": true]) let decoder = WordPieceDecoder(config: config) diff --git a/Tests/TokenizersTests/FactoryTests.swift b/Tests/TokenizersTests/FactoryTests.swift index 88c1c57..314fc1e 100644 --- a/Tests/TokenizersTests/FactoryTests.swift +++ b/Tests/TokenizersTests/FactoryTests.swift @@ -1,13 +1,13 @@ // // FactoryTests.swift -// +// // // Created by Pedro Cuenca on 4/8/23. 
// -import XCTest -import Tokenizers import Hub +import Tokenizers +import XCTest class TestWithCustomHubDownloadLocation: XCTestCase { let downloadDestination: URL = { @@ -32,32 +32,33 @@ class TestWithCustomHubDownloadLocation: XCTestCase { class FactoryTests: TestWithCustomHubDownloadLocation { func testFromPretrained() async throws { - let tokenizer = try await AutoTokenizer.from(pretrained: "coreml-projects/Llama-2-7b-chat-coreml", hubApi: hubApi) + let tokenizer = try await AutoTokenizer.from( + pretrained: "coreml-projects/Llama-2-7b-chat-coreml", hubApi: hubApi) let inputIds = tokenizer("Today she took a train to the West") XCTAssertEqual(inputIds, [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122]) } - + func testWhisper() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "openai/whisper-large-v2", hubApi: hubApi) let inputIds = tokenizer("Today she took a train to the West") XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257]) } - + func testFromModelFolder() async throws { let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"] let repo = Hub.Repo(id: "coreml-projects/Llama-2-7b-chat-coreml") let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload) - + let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi) let inputIds = tokenizer("Today she took a train to the West") XCTAssertEqual(inputIds, [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122]) } - + func testWhisperFromModelFolder() async throws { let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"] let repo = Hub.Repo(id: "openai/whisper-large-v2") let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload) - + let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi) let inputIds = tokenizer("Today she took a train to the West") XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257]) diff --git a/Tests/TokenizersTests/SplitTests.swift b/Tests/TokenizersTests/SplitTests.swift index c14b212..e24ac35 100644 --- a/Tests/TokenizersTests/SplitTests.swift +++ b/Tests/TokenizersTests/SplitTests.swift @@ -5,65 +5,65 @@ // Created by Pedro Cuenca on 20240120. 
// -import XCTest import Tokenizers +import XCTest class SplitTests: XCTestCase { - func testSplitBehaviorMergedWithPrevious() { + func testSplitBehaviorMergedWithPrevious() { XCTAssertEqual( "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["the-", "final-", "-", "countdown"] ) - + XCTAssertEqual( "the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["the-", "final-", "-", "countdown-"] ) - + XCTAssertEqual( "the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["the-", "final-", "-", "countdown-", "-"] ) - + XCTAssertEqual( "-the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["-", "the-", "final-", "-", "countdown-", "-"] ) - + XCTAssertEqual( "--the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["-", "-", "the-", "final-", "-", "countdown-", "-"] ) } - + func testSplitBehaviorMergedWithNext() { XCTAssertEqual( "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext), ["the", "-final", "-", "-countdown"] ) - + XCTAssertEqual( "-the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext), ["-the", "-final", "-", "-countdown"] ) - + XCTAssertEqual( "--the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext), ["-", "-the", "-final", "-", "-countdown"] ) - + XCTAssertEqual( "--the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext), ["-", "-the", "-final", "-", "-countdown", "-"] ) } - + func testSplitBehaviorOther() { XCTAssertEqual( "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .isolated), ["the", "-", "final", "-", "-", "countdown"] ) - + XCTAssertEqual( "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .removed), ["the", "final", "countdown"] diff --git a/Tests/TokenizersTests/SquadDataset.swift b/Tests/TokenizersTests/SquadDataset.swift index 98781b6..2f9edbb 100644 --- a/Tests/TokenizersTests/SquadDataset.swift +++ b/Tests/TokenizersTests/SquadDataset.swift @@ -8,7 +8,6 @@ import Foundation - /// Our internal type, also used in unit tests struct SquadExample { let qaId: String @@ -49,12 +48,14 @@ struct Squad { let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let squadDataset = try! decoder.decode(SquadDataset.self, from: json) - + var examples: [SquadExample] = [] for datum in squadDataset.data { for paragraph in datum.paragraphs { for qa in paragraph.qas { - let example = SquadExample(qaId: qa.id, context: paragraph.context, question: qa.question, answerText: qa.answers[0].text, startPos: qa.answers[0].answer_start, endPos: -1) // todo: remove -1 + let example = SquadExample( + qaId: qa.id, context: paragraph.context, question: qa.question, answerText: qa.answers[0].text, + startPos: qa.answers[0].answer_start, endPos: -1) // todo: remove -1 examples.append(example) } } diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift index 800ea7f..1f06a52 100644 --- a/Tests/TokenizersTests/TokenizerTests.swift +++ b/Tests/TokenizersTests/TokenizerTests.swift @@ -6,10 +6,11 @@ // Copyright © 2023 Hugging Face. All rights reserved. 
// -import XCTest import Hub -@testable import Tokenizers +import XCTest + @testable import Models +@testable import Tokenizers class GPT2TokenizerTests: TokenizerTests { override class var hubModelName: String? { "distilgpt2" } @@ -97,7 +98,8 @@ class GemmaTokenizerTests: TokenizerTests { class GemmaUnicodeTests: XCTestCase { func testGemmaVocab() async throws { - guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer") as? PreTrainedTokenizer else { + guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer") as? PreTrainedTokenizer + else { XCTFail() return } @@ -116,13 +118,18 @@ class PhiSimpleTests: XCTestCase { XCTAssertEqual(tokenizer.encode(text: "hello"), [15339]) XCTAssertEqual(tokenizer.encode(text: "hello world"), [15339, 1917]) - XCTAssertEqual(tokenizer.encode(text: "<|im_start|>user<|im_sep|>Who are you?<|im_end|><|im_start|>assistant<|im_sep|>"), [100264, 882, 100266, 15546, 527, 499, 30, 100265, 100264, 78191, 100266]) + XCTAssertEqual( + tokenizer.encode(text: "<|im_start|>user<|im_sep|>Who are you?<|im_end|><|im_start|>assistant<|im_sep|>"), + [100264, 882, 100266, 15546, 527, 499, 30, 100265, 100264, 78191, 100266]) } } class BertDiacriticsTests: XCTestCase { func testBertCased() async throws { - guard let tokenizer = try await AutoTokenizer.from(pretrained: "distilbert/distilbert-base-multilingual-cased") as? PreTrainedTokenizer else { + guard + let tokenizer = try await AutoTokenizer.from(pretrained: "distilbert/distilbert-base-multilingual-cased") + as? PreTrainedTokenizer + else { XCTFail() return } @@ -132,7 +139,10 @@ class BertDiacriticsTests: XCTestCase { } func testBertCasedResaved() async throws { - guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/distilbert-base-multilingual-cased-tokenizer") as? PreTrainedTokenizer else { + guard + let tokenizer = try await AutoTokenizer.from( + pretrained: "pcuenq/distilbert-base-multilingual-cased-tokenizer") as? PreTrainedTokenizer + else { XCTFail() return } @@ -141,7 +151,10 @@ class BertDiacriticsTests: XCTestCase { } func testBertUncased() async throws { - guard let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") as? PreTrainedTokenizer else { + guard + let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") + as? PreTrainedTokenizer + else { XCTFail() return } @@ -153,13 +166,21 @@ class BertDiacriticsTests: XCTestCase { XCTAssertEqual(tokenizer.tokenize(text: "Car"), ["car"]) XCTAssertEqual(tokenizer.tokenize(text: "€4"), ["€", "##4"]) - XCTAssertEqual(tokenizer.tokenize(text: "test $1 R2 #3 €4 £5 ¥6 ₣7 ₹8 ₱9 test"), ["test", "$", "1", "r", "##2", "#", "3", "€", "##4", "£5", "¥", "##6", "[UNK]", "₹", "##8", "₱", "##9", "test"]) + XCTAssertEqual( + tokenizer.tokenize(text: "test $1 R2 #3 €4 £5 ¥6 ₣7 ₹8 ₱9 test"), + [ + "test", "$", "1", "r", "##2", "#", "3", "€", "##4", "£5", "¥", "##6", "[UNK]", "₹", "##8", "₱", "##9", + "test", + ]) } } class BertSpacesTests: XCTestCase { func testEncodeDecode() async throws { - guard let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") as? PreTrainedTokenizer else { + guard + let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") + as? PreTrainedTokenizer + else { XCTFail() return } @@ -175,7 +196,6 @@ class BertSpacesTests: XCTestCase { } } - struct EncodedTokenizerSamplesDataset: Decodable { let text: String // Bad naming, not just for bpe. 
@@ -185,8 +205,7 @@ struct EncodedTokenizerSamplesDataset: Decodable { let decoded_text: String } - -typealias EdgeCasesDataset = [String : [EdgeCase]] +typealias EdgeCasesDataset = [String: [EdgeCase]] struct EdgeCase: Decodable { let input: String @@ -201,20 +220,19 @@ struct EncodedData: Decodable { let attention_mask: [Int] } - class TokenizerTester { let encodedSamplesFilename: String let unknownTokenId: Int? - + private var configuration: LanguageModelConfigurationFromHub? = nil private var edgeCases: [EdgeCase]? = nil private var _tokenizer: Tokenizer? = nil - + init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) { configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi) self.encodedSamplesFilename = encodedSamplesFilename self.unknownTokenId = unknownTokenId - + // Read the edge cases dataset edgeCases = { let url = Bundle.module.url(forResource: "tokenizer_tests", withExtension: "json")! @@ -225,7 +243,7 @@ class TokenizerTester { return cases[hubModelName] }() } - + lazy var dataset: EncodedTokenizerSamplesDataset = { let url = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")! let json = try! Data(contentsOf: url) @@ -233,13 +251,14 @@ class TokenizerTester { let dataset = try! decoder.decode(EncodedTokenizerSamplesDataset.self, from: json) return dataset }() - - + var tokenizer: Tokenizer? { get async { guard _tokenizer == nil else { return _tokenizer! } do { - guard let tokenizerConfig = try await configuration!.tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" } + guard let tokenizerConfig = try await configuration!.tokenizerConfig else { + throw "Cannot retrieve Tokenizer configuration" + } let tokenizerData = try await configuration!.tokenizerData _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } catch { @@ -248,7 +267,7 @@ class TokenizerTester { return _tokenizer } } - + var tokenizerModel: TokenizingModel? { get async { // The model is not usually accessible; maybe it should @@ -256,7 +275,7 @@ class TokenizerTester { return (tokenizer as! PreTrainedTokenizer).model } } - + func testTokenize() async { let tokenized = await tokenizer?.tokenize(text: dataset.text) XCTAssertEqual( @@ -264,7 +283,7 @@ class TokenizerTester { dataset.bpe_tokens ) } - + func testEncode() async { let encoded = await tokenizer?.encode(text: dataset.text) XCTAssertEqual( @@ -272,7 +291,7 @@ class TokenizerTester { dataset.token_ids ) } - + func testDecode() async { let decoded = await tokenizer?.decode(tokens: dataset.token_ids) XCTAssertEqual( @@ -280,7 +299,7 @@ class TokenizerTester { dataset.decoded_text ) } - + /// Test encode and decode for a few edge cases func testEdgeCases() async { guard let edgeCases = edgeCases else { @@ -304,7 +323,7 @@ class TokenizerTester { ) } } - + func testUnknownToken() async { guard let model = await tokenizerModel else { return } XCTAssertEqual(model.unknownTokenId, unknownTokenId) @@ -326,10 +345,10 @@ class TokenizerTester { class TokenizerTests: XCTestCase { // Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem static var _tester: TokenizerTester? = nil - + class var hubModelName: String? { nil } class var encodedSamplesFilename: String? { nil } - + // Known id retrieved from Python, to verify it was parsed correctly class var unknownTokenId: Int? 
{ nil } @@ -364,25 +383,25 @@ class TokenizerTests: XCTestCase { await tester.testTokenize() } } - + func testEncode() async { if let tester = Self._tester { await tester.testEncode() } } - + func testDecode() async { if let tester = Self._tester { await tester.testDecode() } } - + func testEdgeCases() async { if let tester = Self._tester { await tester.testEdgeCases() } } - + func testUnknownToken() async { if let tester = Self._tester { await tester.testUnknownToken() diff --git a/Tests/TokenizersTests/TrieTests.swift b/Tests/TokenizersTests/TrieTests.swift index 15c54f8..051776d 100644 --- a/Tests/TokenizersTests/TrieTests.swift +++ b/Tests/TokenizersTests/TrieTests.swift @@ -6,6 +6,7 @@ // import XCTest + @testable import Tokenizers class TrieTests: XCTestCase { @@ -16,23 +17,23 @@ class TrieTests: XCTestCase { trie.insert("carp") trie.insert("car") XCTAssertEqual(trie.root.children.count, 1) - + let c = trie.get("c") XCTAssertNotNil(c) - XCTAssertEqual(c!.children.count, 1) // "a" - + XCTAssertEqual(c!.children.count, 1) // "a" + let ca = trie.get("ca") XCTAssertNotNil(ca) - XCTAssertEqual(ca!.children.count, 2) // "r", "t" - + XCTAssertEqual(ca!.children.count, 2) // "r", "t" + let car = trie.get("car") XCTAssertNotNil(car) XCTAssertTrue(car!.isLeaf) XCTAssertFalse(ca!.isLeaf) - + XCTAssertNil(trie.get("card")) } - + func testTrieCommonPrefixSearch() { // https://guillaume-be.github.io/2020-05-30/sentence_piece let trie = Trie() @@ -44,7 +45,7 @@ class TrieTests: XCTestCase { let leaves = trie.commonPrefixSearch("carpooling").map { String($0) } XCTAssertEqual(leaves, ["car", "carp"]) } - + func testTrieCommonPrefixSearchIterator() { // https://guillaume-be.github.io/2020-05-30/sentence_piece let trie = Trie()