diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..53bf631 --- /dev/null +++ b/.swift-format @@ -0,0 +1,13 @@ +{ + "version": 1, + "lineLength": 120, + "indentation": { + "spaces": 4 + }, + "maximumBlankLines": 1, + "respectsExistingLineBreaks": true, + "lineBreakBeforeControlFlowKeywords": true, + "lineBreakBeforeEachArgument": true, + "multiElementCollectionTrailingCommas": true, + "spacesAroundRangeFormationOperators": true +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..32ed937 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing to Swift Transformers + +## Code Styling and Linting + +Code formatting is enforced with Apple's `swift-format` utility. +To install it and run it on all files in the project, use the following commands: + +```bash +brew install swift-format +swift-format . -i -r +``` + +The style is controlled by the `.swift-format` JSON file in the root of the repository. +There is no single standard for Swift formatting: even Apple's own `swift-format` tool and Xcode differ in their formatting rules and available settings. diff --git a/Package.swift b/Package.swift index de98aa6..09efe3b 100644 --- a/Package.swift +++ b/Package.swift @@ -19,18 +19,27 @@ let package = Package( name: "TransformersCLI", dependencies: [ "Models", "Generation", "Tokenizers", - .product(name: "ArgumentParser", package: "swift-argument-parser")]), - .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]), + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), + .executableTarget( + name: "HubCLI", + dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")] + ), .target(name: "Hub", resources: [.process("FallbackConfigs")]), .target(name: "Tokenizers", dependencies: ["Hub"]), .target(name: "TensorUtils"), .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]), .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]), - .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), + .testTarget( + name: "TokenizersTests", + dependencies: ["Tokenizers", "Models", "Hub"], + resources: [.process("Resources"), .process("Vocabs")] + ), .testTarget(name: "HubTests", dependencies: ["Hub"]), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]) + .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]), ] ) diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 9464925..f473e9d 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -1,13 +1,13 @@ // // Generation.swift -// +// // // Created by Pedro Cuenca on 7/5/23. // -import Tokenizers import CoreML import TensorUtils +import Tokenizers public enum GenerationMode { case contrastiveSearch @@ -29,13 +29,29 @@ public typealias PredictionStringCallback = (String) -> Void // TODO: callbacks (for streaming) public protocol Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?)
async -> GenerationOutput - - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String + func greedySearch( + config: GenerationConfig, + tokens: InputTokens, + model: NextTokenModel, + callback: PredictionTokensCallback? + ) async -> GenerationOutput + + func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizer, + callback: PredictionStringCallback? + ) async -> String } -public extension Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { +extension Generation { + public func greedySearch( + config: GenerationConfig, + tokens: InputTokens, + model: NextTokenModel, + callback: PredictionTokensCallback? = nil + ) async -> GenerationOutput { // Iterate until we find the eos token or reach the max length // TODO: additional stopping criteria var outputTokens = tokens @@ -48,9 +64,14 @@ public extension Generation { } return outputTokens } - + /// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552 - func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { + public func sample( + config: GenerationConfig, + tokens: InputTokens, + model: NextTokenModel, + callback: PredictionTokensCallback? = nil + ) async -> GenerationOutput { // Iterate until we find the eos token or reach the max length // TODO: additional stopping criteria var outputTokens = tokens @@ -68,7 +89,13 @@ public extension Generation { return outputTokens } - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String { + public func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizer, + callback: PredictionStringCallback? = nil + ) async -> String { let tokens = tokenizer.encode(text: prompt) var generationConfig = config generationConfig.maxLength = config.maxNewTokens + tokens.count @@ -86,7 +113,7 @@ public extension Generation { default: fatalError("Generation mode \(generationConfig.generationMode) not implemented yet") } - + return tokenizer.decode(tokens: output) } diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift index a9eee7b..763673e 100644 --- a/Sources/Generation/GenerationConfig.swift +++ b/Sources/Generation/GenerationConfig.swift @@ -1,6 +1,6 @@ // // GenerationConfig.swift -// +// // // Created by Pedro Cuenca on 7/5/23. // @@ -19,12 +19,23 @@ public struct GenerationConfig { public var topK = 50 public var topP = 1.0 public var repetitionPenalty = 1.0 - + public var padTokenId: Int? = nil public var bosTokenId: Int? = nil public var eosTokenId: Int? = nil - - public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) { + + public init( + maxLength: Int = 20, + maxNewTokens: Int, + doSample: Bool = false, + numBeams: Int = 1, + numBeamGroups: Int = 1, + penaltyAlpha: Double? 
= nil, + temperature: Double = 1.0, + topK: Int = 50, + topP: Double = 1.0, + repetitionPenalty: Double = 1.0 + ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens self.doSample = doSample @@ -38,19 +49,19 @@ public struct GenerationConfig { } } -public extension GenerationConfig { - var generationMode: GenerationMode { +extension GenerationConfig { + public var generationMode: GenerationMode { // Exclude this case from the pattern matching below if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! > 0 { return .contrastiveSearch } - + switch (numBeams, numBeamGroups, doSample) { - case (1, 1, false) : return .greedy - case (1, 1, true) : return .sample + case (1, 1, false): return .greedy + case (1, 1, true): return .sample case (2..., 1, false): return .beam - case (2..., 2..., _) : return .groupBeam - default : return .unsupported + case (2..., 2..., _): return .groupBeam + default: return .unsupported } } } diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 1f3b8c6..d7d44de 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -6,8 +6,8 @@ // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE // -import Foundation import Combine +import Foundation class Downloader: NSObject, ObservableObject { private(set) var destination: URL @@ -86,16 +86,16 @@ class Downloader: NSObject, ObservableObject { stateSubscriber = downloadState.sink { state in switch state { case .completed: semaphore.signal() - case .failed: semaphore.signal() - default: break + case .failed: semaphore.signal() + default: break } } semaphore.wait() switch downloadState.value { case .completed(let url): return url - case .failed(let error): throw error - default: throw DownloadError.unexpectedError + case .failed(let error): throw error + default: throw DownloadError.unexpectedError } } @@ -105,7 +105,13 @@ class Downloader: NSObject, ObservableObject { } extension Downloader: URLSessionDownloadDelegate { - func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) { + func urlSession( + _: URLSession, + downloadTask: URLSessionDownloadTask, + didWriteData _: Int64, + totalBytesWritten: Int64, + totalBytesExpectedToWrite: Int64 + ) { downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)) } @@ -114,7 +120,8 @@ extension Downloader: URLSessionDownloadDelegate { // If the downloaded file already exists on the filesystem, overwrite it try FileManager.default.moveDownloadedFile(from: location, to: self.destination) downloadState.value = .completed(destination) - } catch { + } + catch { downloadState.value = .failed(error) } } @@ -122,10 +129,10 @@ extension Downloader: URLSessionDownloadDelegate { func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { if let error = error { downloadState.value = .failed(error) -// } else if let response = task.response as? HTTPURLResponse { -// print("HTTP response status code: \(response.statusCode)") -// let headers = response.allHeaderFields -// print("HTTP response headers: \(headers)") + // } else if let response = task.response as? 
HTTPURLResponse { + // print("HTTP response status code: \(response.statusCode)") + // let headers = response.allHeaderFields + // print("HTTP response headers: \(headers)") } } } diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 8deeee8..ce386b6 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -1,6 +1,6 @@ // // Hub.swift -// +// // // Created by Pedro Cuenca on 18/5/23. // @@ -9,24 +9,24 @@ import Foundation public struct Hub {} -public extension Hub { - enum HubClientError: Error { +extension Hub { + public enum HubClientError: Error { case parse case authorizationRequired case unexpectedError case httpStatusCode(Int) } - - enum RepoType: String { + + public enum RepoType: String { case models case datasets case spaces } - - struct Repo { + + public struct Repo { let id: String let type: RepoType - + public init(id: String, type: RepoType = .models) { self.id = id self.type = type @@ -45,17 +45,18 @@ public struct Config { } func camelCase(_ string: String) -> String { - return string + return + string .split(separator: "_") .enumerated() .map { $0.offset == 0 ? $0.element.lowercased() : $0.element.capitalized } .joined() } - + func uncamelCase(_ string: String) -> String { let scalars = string.unicodeScalars var result = "" - + var previousCharacterIsLowercase = false for scalar in scalars { if CharacterSet.uppercaseLetters.contains(scalar) { @@ -65,21 +66,22 @@ public struct Config { let lowercaseChar = Character(scalar).lowercased() result += lowercaseChar previousCharacterIsLowercase = false - } else { + } + else { result += String(scalar) previousCharacterIsLowercase = true } } - + return result } - public subscript(dynamicMember member: String) -> Config? { let key = dictionary[member] != nil ? member : uncamelCase(member) if let value = dictionary[key] as? [String: Any] { return Config(value) - } else if let value = dictionary[key] { + } + else if let value = dictionary[key] { return Config(["value": value]) } return nil @@ -88,17 +90,17 @@ public struct Config { public var value: Any? { return dictionary["value"] } - + public var intValue: Int? { value as? Int } public var boolValue: Bool? { value as? Bool } public var stringValue: String? { value as? String } - + // Instead of doing this we could provide custom classes and decode to them public var arrayValue: [Config]? { guard let list = value as? [Any] else { return nil } - return list.map { Config($0 as! [String : Any]) } + return list.map { Config($0 as! [String: Any]) } } - + /// Tuple of token identifier and string value public var tokenValue: (UInt, String)? { value as? 
(UInt, String) } } @@ -120,7 +122,7 @@ public class LanguageModelConfigurationFromHub { return try await self.loadConfig(modelName: modelName, hubApi: hubApi) } } - + public init( modelFolder: URL, hubApi: HubApi = .shared @@ -145,7 +147,10 @@ public class LanguageModelConfigurationFromHub { // If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it if let fallbackConfig = Self.fallbackTokenizerConfig(for: modelType) { - let configuration = fallbackConfig.dictionary.merging(hubConfig.dictionary, uniquingKeysWith: { current, _ in current }) + let configuration = fallbackConfig.dictionary.merging( + hubConfig.dictionary, + uniquingKeysWith: { current, _ in current } + ) return Config(configuration) } @@ -183,7 +188,7 @@ public class LanguageModelConfigurationFromHub { return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi) } - + func loadConfig( modelFolder: URL, hubApi: HubApi = .shared @@ -192,7 +197,7 @@ public class LanguageModelConfigurationFromHub { let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json")) let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json")) let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json")) - + let configs = Configurations( modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, @@ -202,13 +207,16 @@ public class LanguageModelConfigurationFromHub { } static func fallbackTokenizerConfig(for modelType: String) -> Config? { - guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else { return nil } + guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else { + return nil + } do { let data = try Data(contentsOf: url) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [String: Any] else { return nil } return Config(dictionary) - } catch { + } + catch { return nil } } diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 0d789da..61f53ca 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -15,91 +15,97 @@ public struct HubApi { public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) { + + public init( + downloadBase: URL? = nil, + hfToken: String? = nil, + endpoint: String = "https://huggingface.co", + useBackgroundSession: Bool = false + ) { self.hfToken = hfToken if let downloadBase { self.downloadBase = downloadBase - } else { + } + else { let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first! 
self.downloadBase = documents.appending(component: "huggingface") } self.endpoint = endpoint self.useBackgroundSession = useBackgroundSession } - + public static let shared = HubApi() } /// File retrieval -public extension HubApi { +extension HubApi { /// Model data for parsed filenames - struct Sibling: Codable { + public struct Sibling: Codable { let rfilename: String } - - struct SiblingsResponse: Codable { + + public struct SiblingsResponse: Codable { let siblings: [Sibling] } - + /// Throws error if the response code is not 20X - func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { + public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) if let hfToken = hfToken { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") } let (data, response) = try await URLSession.shared.data(for: request) guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } - + switch response.statusCode { - case 200..<300: break - case 400..<500: throw Hub.HubClientError.authorizationRequired + case 200 ..< 300: break + case 400 ..< 500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } return (data, response) } - - func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] { + + public func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] { // Read repo info and only parse "siblings" let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")! let (data, _) = try await httpGet(for: url) let response = try JSONDecoder().decode(SiblingsResponse.self, from: data) let filenames = response.siblings.map { $0.rfilename } guard globs.count > 0 else { return filenames } - + var selected: Set<String> = [] for glob in globs { selected = selected.union(filenames.matching(glob: glob)) } return Array(selected) } - - func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + + public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { return try await getFilenames(from: Repo(id: repoId), matching: globs) } - - func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + + public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { return try await getFilenames(from: repo, matching: [glob]) } - - func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + + public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { return try await getFilenames(from: Repo(id: repoId), matching: [glob]) } } /// Configuration loading helpers -public extension HubApi { +extension HubApi { /// Assumes the file has already been downloaded. /// `filename` is relative to the download base. - func configuration(from filename: String, in repo: Repo) throws -> Config { + public func configuration(from filename: String, in repo: Repo) throws -> Config { let fileURL = localRepoLocation(repo).appending(path: filename) return try configuration(fileURL: fileURL) } - + /// Assumes the file is already present at local url.
/// `fileURL` is a complete local file path for the given model - func configuration(fileURL: URL) throws -> Config { + public func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse } @@ -108,10 +114,10 @@ public extension HubApi { } /// Whoami -public extension HubApi { - func whoami() async throws -> Config { +extension HubApi { + public func whoami() async throws -> Config { guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired } - + let url = URL(string: "\(endpoint)/api/whoami-v2")! let (data, _) = try await httpGet(for: url) @@ -122,12 +128,12 @@ public extension HubApi { } /// Snaphsot download -public extension HubApi { - func localRepoLocation(_ repo: Repo) -> URL { +extension HubApi { + public func localRepoLocation(_ repo: Repo) -> URL { downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id) } - - struct HubFileDownloader { + + public struct HubFileDownloader { let repo: Repo let repoDestination: URL let relativeFilename: String @@ -142,22 +148,26 @@ public extension HubApi { url = url.appending(component: repo.type.rawValue) } url = url.appending(path: repo.id) - url = url.appending(path: "resolve/main") // TODO: revisions + url = url.appending(path: "resolve/main") // TODO: revisions url = url.appending(path: relativeFilename) return url } - + var destination: URL { repoDestination.appending(path: relativeFilename) } - + var downloaded: Bool { FileManager.default.fileExists(atPath: destination.path) } - + func prepareDestination() throws { let directoryURL = destination.deletingLastPathComponent() - try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) + try FileManager.default.createDirectory( + at: directoryURL, + withIntermediateDirectories: true, + attributes: nil + ) } // Note we go from Combine in Downloader to callback-based progress reporting @@ -182,7 +192,11 @@ public extension HubApi { } @discardableResult - func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repo: Repo, + matching globs: [String] = [], + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { let filenames = try await getFilenames(from: repo, matching: globs) let progress = Progress(totalUnitCount: Int64(filenames.count)) let repoDestination = localRepoLocation(repo) @@ -205,64 +219,100 @@ public extension HubApi { progressHandler(progress) return repoDestination } - + @discardableResult - func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repoId: String, + matching globs: [String] = [], + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } - + @discardableResult - func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repo: Repo, + matching glob: String, + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: repo, matching: [glob], 
progressHandler: progressHandler) } - + @discardableResult - func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot( + from repoId: String, + matching glob: String, + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler) } } /// Stateless wrappers that use `HubApi` instances -public extension Hub { - static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { +extension Hub { + public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { return try await HubApi.shared.getFilenames(from: repo, matching: globs) } - - static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + + public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) } - - static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + + public static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { return try await HubApi.shared.getFilenames(from: repo, matching: glob) } - - static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + + public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) } - - static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + + public static func snapshot( + from repo: Repo, + matching globs: [String] = [], + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) } - - static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) + + public static func snapshot( + from repoId: String, + matching globs: [String] = [], + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { + return try await HubApi.shared.snapshot( + from: Repo(id: repoId), + matching: globs, + progressHandler: progressHandler + ) } - - static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + + public static func snapshot( + from repo: Repo, + matching glob: String, + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async throws -> URL { return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) } - - static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) + + public static func snapshot( + from repoId: String, + matching glob: String, + progressHandler: @escaping (Progress) -> Void = { _ in } + ) async 
throws -> URL { + return try await HubApi.shared.snapshot( + from: Repo(id: repoId), + matching: glob, + progressHandler: progressHandler + ) } - - static func whoami(token: String) async throws -> Config { + + public static func whoami(token: String) async throws -> Config { return try await HubApi(hfToken: token).whoami() } } -public extension [String] { - func matching(glob: String) -> [String] { +extension [String] { + public func matching(glob: String) -> [String] { filter { fnmatch(glob, $0, 0) == 0 } } } diff --git a/Sources/HubCLI/HubCLI.swift b/Sources/HubCLI/HubCLI.swift index fb0cc72..ac77609 100644 --- a/Sources/HubCLI/HubCLI.swift +++ b/Sources/HubCLI/HubCLI.swift @@ -1,6 +1,5 @@ import ArgumentParser import Foundation - import Hub let defaultTokenLocation = NSString("~/.cache/huggingface/token").expandingTildeInPath @@ -33,7 +32,7 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { case model case dataset case space - + var asHubApiRepoType: HubApi.RepoType { switch self { case .model: return .models @@ -42,7 +41,7 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { } } } - + @Argument(help: "Repo ID") var repo: String @@ -51,17 +50,20 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { @Option(help: "Glob patterns for files to include") var include: [String] = [] - + @Option(help: "Hugging Face token. If empty, will attempt to read from the filesystem at \(defaultTokenLocation)") var token: String? = nil - + func run() async throws { let hubApi = HubApi(hfToken: hfToken) let repo = Hub.Repo(id: repo, type: repoType.asHubApiRepoType) let downloadedTo = try await hubApi.snapshot(from: repo, matching: include) { progress in DispatchQueue.main.async { let totalPercent = 100 * progress.fractionCompleted - print("\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", terminator: "\r") + print( + "\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", + terminator: "\r" + ) fflush(stdout) } } @@ -71,19 +73,20 @@ struct Download: AsyncParsableCommand, SubcommandWithToken { struct Whoami: AsyncParsableCommand, SubcommandWithToken { static let configuration = CommandConfiguration(abstract: "whoami") - + @Option(help: "Hugging Face token. If empty, will attempt to read from the filesystem at \(defaultTokenLocation)") var token: String? = nil - + func run() async throws { let hubApi = HubApi(hfToken: hfToken) let userInfo = try await hubApi.whoami() if let name = userInfo.name?.stringValue, - let fullname = userInfo.fullname?.stringValue, - let email = userInfo.email?.stringValue + let fullname = userInfo.fullname?.stringValue, + let email = userInfo.email?.stringValue { print("\(name) (\(fullname) <\(email)>)") - } else { + } + else { print("Cannot retrieve user info") } } diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index e46a23c..0e87e00 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -1,44 +1,44 @@ // // LanguageModel.swift -// +// // // Created by Pedro Cuenca on 7/5/23. // import CoreML -import Tokenizers import Generation import Hub +import Tokenizers public class LanguageModel { public let model: MLModel - + public let minContextLength: Int public let maxContextLength: Int - + let input_ids = "input_ids" let attention_mask = "attention_mask" - + struct Configurations { var modelConfig: Config var tokenizerConfig: Config? 
var tokenizerData: Config } - + private var configuration: LanguageModelConfigurationFromHub? = nil private var _tokenizer: Tokenizer? = nil public required init(model: MLModel) { self.model = model - + // We assume inputs named "input_ids" with shape (1, seq_length) // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"] - + guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else { fatalError("Cannot obtain shape information") } - + switch shapeConstraint.type { case .enumerated: // TODO: support a set of fixed shapes (keeping the first one here) @@ -55,13 +55,13 @@ public class LanguageModel { minContextLength = 128 maxContextLength = 128 } - + self.configuration = LanguageModelConfigurationFromHub(modelName: modelName) } } -public extension LanguageModel { - static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel { +extension LanguageModel { + public static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel { let config = MLModelConfiguration() config.computeUnits = computeUnits let model = try MLModel(contentsOf: url, configuration: config) @@ -69,63 +69,67 @@ public extension LanguageModel { } } -public extension LanguageModel { - var description: String { +extension LanguageModel { + public var description: String { if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String, - !description.isEmpty { + !description.isEmpty + { return description } return model.configuration.modelDisplayName ?? "" } - + /// `name_or_path` in the Python world - var modelName: String { - if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String : String], - let name = userFields["co.huggingface.exporters.name"] { + public var modelName: String { + if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String], + let name = userFields["co.huggingface.exporters.name"] + { return name } // This is usually the basename of the file, that's our best bet if no metadata exists - guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") } + guard let modelName = model.configuration.modelDisplayName else { + fatalError("Models must have a name that identifies them") + } return modelName } - - var inputIdsDescription: MLFeatureDescription { + + public var inputIdsDescription: MLFeatureDescription { model.modelDescription.inputDescriptionsByName[input_ids]! 
} - - var inputIdsName: String { + + public var inputIdsName: String { inputIdsDescription.name } - + /// The expected shape of the models latent sample input - var inputIdsShape: [Int] { + public var inputIdsShape: [Int] { inputIdsDescription.multiArrayConstraint!.shape.map { $0.intValue } } - - var requiresAttention: Bool { + + public var requiresAttention: Bool { model.modelDescription.inputDescriptionsByName[attention_mask] != nil } - + // MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice - func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { + public func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { // TODO: exceptions - + // Maybe pad or truncate let maxTokens = min(tokens.count, maxContextLength) - let padLength = maxTokens >= minContextLength ? 0 : minContextLength-maxTokens - let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength) + let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens + let inputTokens = Array(tokens[0 ..< maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength) + let inputIds = MLShapedArray<Int32>(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape) var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)] if requiresAttention { let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength) - let attentionMask = MLShapedArray<Int32>(scalars: mask.map{ Int32($0) }, shape: inputIdsShape) + let attentionMask = MLShapedArray<Int32>(scalars: mask.map { Int32($0) }, shape: inputIdsShape) inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask) } let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary) - + let output = try! model.prediction(from: input) - + // TODO: maybe try to support models with "token_scores" too (after the softmax) assert(output.featureNames.first! == "logits") @@ -136,61 +140,63 @@ } /// async properties downloaded from the configuration -public extension LanguageModel { - var modelConfig: Config { +extension LanguageModel { + public var modelConfig: Config { get async throws { try await configuration!.modelConfig } } - - var tokenizerConfig: Config? { + + public var tokenizerConfig: Config? { get async throws { try await configuration!.tokenizerConfig } } - - var tokenizerData: Config { + + public var tokenizerData: Config { get async throws { try await configuration!.tokenizerData } } - - var modelType: String? { + + public var modelType: String? { get async throws { try await modelConfig.modelType?.stringValue } } - - var textGenerationParameters: Config? { + + public var textGenerationParameters: Config? { get async throws { try await modelConfig.taskSpecificParams?.textGeneration } } - - var defaultDoSample: Bool { + + public var defaultDoSample: Bool { get async throws { try await textGenerationParameters?.doSample?.boolValue ?? true } } - var bosTokenId: Int? { + public var bosTokenId: Int? { get async throws { let modelConfig = try await modelConfig return modelConfig.bosTokenId?.intValue } } - - var eosTokenId: Int? { + + public var eosTokenId: Int? { get async throws { let modelConfig = try await modelConfig return modelConfig.eosTokenId?.intValue } } - - var tokenizer: Tokenizer { + + public var tokenizer: Tokenizer { get async throws { guard _tokenizer == nil else { return _tokenizer!
} - guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" } + guard let tokenizerConfig = try await tokenizerConfig else { + throw "Cannot retrieve Tokenizer configuration" + } let tokenizerData = try await tokenizerData _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) return _tokenizer! diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift index 08d7d48..79edb59 100644 --- a/Sources/Models/LanguageModelTypes.swift +++ b/Sources/Models/LanguageModelTypes.swift @@ -1,13 +1,13 @@ // // LanguageModelTypes.swift -// +// // // Created by Pedro Cuenca on 8/5/23. // import CoreML -import Tokenizers import Generation +import Tokenizers public protocol LanguageModelProtocol { /// `name_or_path` in the Python world @@ -15,16 +15,16 @@ public protocol LanguageModelProtocol { var tokenizer: Tokenizer { get async throws } var model: MLModel { get } - + init(model: MLModel) - + /// Make prediction callable (this works like __call__ in Python) func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol } -public extension LanguageModelProtocol { - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { +extension LanguageModelProtocol { + public func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { predictNextTokenScores(tokens, config: config) } } @@ -34,9 +34,17 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol { func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String } -public extension TextGenerationModel { +extension TextGenerationModel { @discardableResult - func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String { - try await self.generate(config: config, prompt: prompt, model: self.callAsFunction, tokenizer: self.tokenizer, callback: callback) + public func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) + async throws -> String + { + try await self.generate( + config: config, + prompt: prompt, + model: self.callAsFunction, + tokenizer: self.tokenizer, + callback: callback + ) } } diff --git a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift index 53dc0db..9eb1630 100644 --- a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift +++ b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift @@ -2,7 +2,7 @@ import Foundation public struct TemperatureLogitsWarper: LogitsWarper { public var temperature: Float - + public init(temperature: Float) { self.temperature = temperature } diff --git a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift index a236d84..580b199 100644 --- a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift +++ b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift @@ -1,5 +1,5 @@ -import Foundation import Accelerate +import Foundation /// Top-K. /// Select the k most-probable element indices from `arr` @@ -7,7 +7,7 @@ import Accelerate /// and their probabilities. 
public struct TopKLogitsWarper: LogitsWarper { public var k: Int - + public init(k: Int) { self.k = k } diff --git a/Sources/TensorUtils/MLMultiArray+Utils.swift b/Sources/TensorUtils/MLMultiArray+Utils.swift index ddb2760..c0341db 100644 --- a/Sources/TensorUtils/MLMultiArray+Utils.swift +++ b/Sources/TensorUtils/MLMultiArray+Utils.swift @@ -6,12 +6,12 @@ // Copyright © 2019 Hugging Face. All rights reserved. // -import Foundation import CoreML +import Foundation -public extension MLMultiArray { +extension MLMultiArray { /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { + public static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { var shape = Array(repeating: 1, count: dims) shape[shape.count - 1] = arr.count /// Examples: @@ -25,9 +25,9 @@ } return o } - + /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray { + public static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray { var shape = Array(repeating: 1, count: dims) shape[shape.count - 1] = arr.count /// Examples: @@ -41,31 +41,31 @@ } return o } - + /// This will concatenate all dimensions into one one-dim array. - static func toIntArray(_ o: MLMultiArray) -> [Int] { + public static func toIntArray(_ o: MLMultiArray) -> [Int] { var arr = Array(repeating: 0, count: o.count) let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer)) - for i in 0..<o.count { + for i in 0 ..< o.count { arr[i] = Int(ptr[i]) } return arr } - func toIntArray() -> [Int] { Self.toIntArray(self) } - + + public func toIntArray() -> [Int] { Self.toIntArray(self) } + /// This will concatenate all dimensions into one one-dim array. - static func toDoubleArray(_ o: MLMultiArray) -> [Double] { + public static func toDoubleArray(_ o: MLMultiArray) -> [Double] { var arr: [Double] = Array(repeating: 0, count: o.count) let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer)) - for i in 0..<o.count { + for i in 0 ..< o.count { arr[i] = ptr[i] } return arr } - func toDoubleArray() -> [Double] { Self.toDoubleArray(self) } - + + public func toDoubleArray() -> [Double] { Self.toDoubleArray(self) } + /// Helper to construct a sequentially-indexed multi array, /// useful for debugging and unit tests /// Example in 3 dimensions: /// ``` /// [[[ 0, 1, 2, 3 ], /// [ 4, 5, 6, 7 ], /// [ 8, 9, 10, 11 ]], /// [[ 12, 13, 14, 15 ], @@ -77,29 +77,28 @@ /// [ 16, 17, 18, 19 ], /// [ 20, 21, 22, 23 ]]] /// ``` - static func testTensor(shape: [Int]) -> MLMultiArray { + public static func testTensor(shape: [Int]) -> MLMultiArray { let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double) let ptr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer)) - for i in 0..<arr.count { + for i in 0 ..< arr.count { ptr[i] = Double(i) } return arr } - static func slice(_ o: MLMultiArray, indexing: [Indexing]) -> MLMultiArray { + public static func slice(_ o: MLMultiArray, indexing: [Indexing]) -> MLMultiArray { assert( indexing.count == o.shape.count ) @@ -118,12 +117,12 @@ selectDims: selectDims ) } - + /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on) /// and a dictionary of `dim` to `index`. /// /// You must select all other dimensions than the slice dimension (cf. the assert).
- static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { + public static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { assert( selectDims.count + 1 == o.shape.count ) @@ -131,16 +130,17 @@ shape[sliceDim] = o.shape[sliceDim] /// print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)") let arr = try! MLMultiArray(shape: shape, dataType: .double) - + /// let srcPtr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer)) /// TODO: use srcPtr instead of array subscripting. let dstPtr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer)) - for i in 0.. String { return String(repeating: " ", count: x) } - + // This function is called recursively for every dimension. // Add an entry for this dimension to the end of the array. var indices = indices + [0] - - let d = indices.count - 1 // the current dimension - let N = shape[d].intValue // how many elements in this dimension + + let d = indices.count - 1 // the current dimension + let N = shape[d].intValue // how many elements in this dimension var s = "[" - if indices.count < shape.count { // not last dimension yet? - for i in 0.. -public extension MLShapedArray<Float> { - var floats: [Float] { +extension MLShapedArray<Float> { + public var floats: [Float] { guard self.strides.first == 1, self.strides.count == 1 else { // For some reason this path is slow. // If strides is not 1, we can write a Metal kernel to copy the values properly. return self.scalars } - + // Fast path: memcpy let mlArray = MLMultiArray(self) return mlArray.floats ?? self.scalars } } -public extension MLShapedArraySlice<Float> { - var floats: [Float] { +extension MLShapedArraySlice<Float> { + public var floats: [Float] { guard self.strides.first == 1, self.strides.count == 1 else { // For some reason this path is slow. // If strides is not 1, we can write a Metal kernel to copy the values properly. @@ -35,10 +35,10 @@ } } -public extension MLMultiArray { - var floats: [Float]? { +extension MLMultiArray { + public var floats: [Float]? { guard self.dataType == .float32 else { return nil } - + var result: [Float] = Array(repeating: 0, count: self.count) return self.withUnsafeBytes { ptr in guard let source = ptr.baseAddress else { return nil } diff --git a/Sources/TensorUtils/Math.swift b/Sources/TensorUtils/Math.swift index 4050ac1..12ad579 100644 --- a/Sources/TensorUtils/Math.swift +++ b/Sources/TensorUtils/Math.swift @@ -6,9 +6,9 @@ // Copyright © 2019 Hugging Face. All rights reserved. // -import Foundation import Accelerate import CoreML +import Foundation /// /// From M.I. Hollemans /// https://github.com/hollance/CoreMLHelpers /// public struct Math { - + /** Returns the index and value of the largest element in the array. - + - Parameters: - ptr: Pointer to the first element in memory. - count: How many elements to look at. @@ -31,7 +31,7 @@ vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + /** Returns the index and value of the largest element in the array.
- Parameters: @@ -45,14 +45,14 @@ vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + public static func argmax32(_ ptr: UnsafePointer<Float>, count: Int, stride: Int = 1) -> (Int, Float) { var maxValue: Float = 0 var maxIndex: vDSP_Length = 0 vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) return (Int(maxIndex), maxValue) } - + /// MLMultiArray helper. /// Works in our specific use case. public static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) { @@ -60,7 +60,7 @@ let ptr = UnsafeMutablePointer<Double>(OpaquePointer(multiArray.dataPointer)) return Math.argmax(ptr, count: multiArray.count) } - + /// MLMultiArray helper. /// Works in our specific use case. public static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) { @@ -88,7 +88,7 @@ let i = randomNumber(probabilities: probs) return indexes[i] } - + /** Computes the "softmax" function over an array. Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/ @@ -103,31 +103,31 @@ public static func softmax(_ x: [Float]) -> [Float] { var x = x let len = vDSP_Length(x.count) - + // Find the maximum value in the input array. var max: Float = 0 vDSP_maxv(x, 1, &max, len) - + // Subtract the maximum from all the elements in the array. // Now the highest value in the array is 0. max = -max vDSP_vsadd(x, 1, &max, &x, 1, len) - + // Exponentiate all the elements in the array. var count = Int32(x.count) vvexpf(&x, x, &count) - + // Compute the sum of all exponentiated values. var sum: Float = 0 vDSP_sve(x, 1, &sum, len) - + // Divide each element by the sum. This normalizes the array contents // so that they all add up to 1. vDSP_vsdiv(x, 1, &sum, &x, 1, len) - + return x } - + /// Multinomial sampling /// /// From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution @@ -152,16 +152,16 @@ // MLShapedArray versions -public extension Math { - static func argmax(_ shapedArray: MLShapedArray<Float>) -> (Int, Float) { +extension Math { + public static func argmax(_ shapedArray: MLShapedArray<Float>) -> (Int, Float) { shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") return Math.argmax32(ptr.baseAddress!, count: shapedArray.count, stride: strides.first!) } } - + // TODO: handle Double, etc. - static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { + public static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") let floatsPtr = ptr.baseAddress as! UnsafePointer<Float> diff --git a/Sources/Tokenizers/BPETokenizer.swift b/Sources/Tokenizers/BPETokenizer.swift index bed4785..948e5cf 100644 --- a/Sources/Tokenizers/BPETokenizer.swift +++ b/Sources/Tokenizers/BPETokenizer.swift @@ -20,7 +20,7 @@ struct BytePair: Hashable { self.a = tuple[0] self.b = tuple[1] } - + static func == (lhs: BytePair, rhs: BytePair) -> Bool { return lhs.a == rhs.a && lhs.b == rhs.b } @@ -30,12 +30,11 @@ } } - class BPETokenizer: PreTrainedTokenizerModel { - let bpeRanks: Dictionary<BytePair, Int> + let bpeRanks: [BytePair: Int] private let tokensToIds: [String: Int] private let idsToTokens: [Int: String] - + public let bosToken: String?
public let bosTokenId: Int? public let eosToken: String? @@ -43,27 +42,30 @@ public let unknownToken: String? public let unknownTokenId: Int? - required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { - guard let merges = tokenizerData.model?.merges?.value as? [String] else { fatalError("BPETokenizer requires merges") } + required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { + guard let merges = tokenizerData.model?.merges?.value as? [String] else { + fatalError("BPETokenizer requires merges") + } guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else { throw TokenizerError.missingVocab } - var bpeRanks: Dictionary<BytePair, Int> = [:] + var bpeRanks: [BytePair: Int] = [:] for (i, item) in merges.enumerated() { let tuple = item.unicodeScalars.split(separator: " ", omittingEmptySubsequences: false).map { String($0) } let bp = BytePair(tuple: tuple) bpeRanks[bp] = i } self.bpeRanks = bpeRanks - + self.tokensToIds = vocab.merging(addedTokens) { $1 } self.idsToTokens = Utils.invert(self.tokensToIds) - + // Populate tokens if let unknownToken = TokenizerModel.unknownToken(from: tokenizerConfig) { self.unknownToken = unknownToken self.unknownTokenId = self.tokensToIds[unknownToken] - } else { + } + else { self.unknownToken = nil self.unknownTokenId = nil } @@ -78,7 +80,7 @@ func convertTokenToId(_ token: String) -> Int? { return tokensToIds[token] ?? self.unknownTokenId } - + func convertIdToToken(_ id: Int) -> String? { return idsToTokens[id] } @@ -90,7 +92,7 @@ return Array(token.utf8).map { byteEncoder[$0]! }.joined() } } - + func hexaEncode(text: String) -> [String] { let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# let tokens = text.ranges(of: RE).map { String(text[$0]) } @@ -98,27 +100,27 @@ return Array(token.utf8).map { String(format: "<0x%02X>", $0) } } } - + private func getPairs(word: [String]) -> Set<BytePair> { var s = Set<BytePair>() - for i in 0..<word.count-1 { + for i in 0 ..< word.count - 1 { let bp = BytePair( word[i], word[i+1] ) s.insert(bp) } return s } - + func bpe(token: String) -> String { if token.count <= 1 { return token } - + var word = Array(token).map { String($0) } var pairs = Array(getPairs(word: word)) - + while true { let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil } if bigrams.count == 0 { @@ -132,18 +134,20 @@ var newWord: [String] = [] var i = 0 while i < word.count { - if let j = word[i..
} } - + /// Un-tokenization: func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String { var tokenList: [String] = [] var individualToken: String = "" - + for token in wordpieceTokenList { if token.starts(with: "##") { individualToken += String(token.suffix(token.count - 2)) - } else { + } + else { if individualToken.count > 0 { tokenList.append(individualToken) } - + individualToken = token } } - + tokenList.append(individualToken) - + return tokenList.joined(separator: " ") } - + private func tokenizeChineseCharsIfNeed(_ text: String) -> String { guard tokenizeChineseChars else { return text } - + return text.map { c in if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) { " \(c) " - } else { + } + else { "\(c)" } }.joined() } } - extension BertTokenizer: PreTrainedTokenizerModel { public var unknownToken: String? { wordpieceTokenizer.unkToken } public var unknownTokenId: Int? { vocab[unknownToken!] } func encode(text: String) -> [Int] { tokenizeToIds(text: text) } - + func decode(tokens: [Int]) -> String { let tokens = unTokenize(tokens: tokens) return convertWordpieceToBasicTokenList(tokens) } - + public func convertTokenToId(_ token: String) -> Int? { return vocab[token] ?? unknownTokenId } - + public func convertIdToToken(_ id: Int) -> String? { return ids_to_tokens[id] } } - class BasicTokenizer { let neverSplit = [ - "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]" + "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", ] - + func tokenize(text: String) -> [String] { let splitTokens = text.folding(options: .diacriticInsensitive, locale: nil) .components(separatedBy: NSCharacterSet.whitespaces) @@ -165,11 +172,13 @@ for c in token.lowercased() { if c.isLetter || c.isNumber || c == "°" { currentTok += String(c) - } else if currentTok.count > 0 { + } + else if currentTok.count > 0 { toks.append(currentTok) toks.append(String(c)) currentTok = "" - } else { + } + else { toks.append(String(c)) } } @@ -182,16 +191,15 @@ } } - class WordpieceTokenizer { let unkToken = "[UNK]" private let maxInputCharsPerWord = 100 private let vocab: [String: Int] - + init(vocab: [String: Int]) { self.vocab = vocab } - + /// `word`: A single token. /// Warning: this differs from the `pytorch-transformers` implementation. /// This should have already been passed through `BasicTokenizer`. @@ -207,7 +215,7 @@ var end = word.count var cur_substr: String? = nil while start < end { - var substr = Utils.substr(word, start..<end)! + var substr = Utils.substr(word, start ..< end)! if start > 0 { substr = "##\(substr)" } @@ -226,7 +234,8 @@ } if isBad { outputTokens.append(unkToken) - } else { + } + else { outputTokens.append(contentsOf: subTokens) } return outputTokens diff --git a/Sources/Tokenizers/ByteEncoder.swift b/Sources/Tokenizers/ByteEncoder.swift index e636ca6..1d0532a 100644 --- a/Sources/Tokenizers/ByteEncoder.swift +++ b/Sources/Tokenizers/ByteEncoder.swift @@ -8,7 +8,7 @@ import Foundation -let byteEncoder: Dictionary<UTF8.CodeUnit, String> = [ +let byteEncoder: [UTF8.CodeUnit: String] = [ 33: "!", 34: "\"", 35: "#", diff --git a/Sources/Tokenizers/Decoder.swift b/Sources/Tokenizers/Decoder.swift index c98c4b9..7dfe9a0 100644 --- a/Sources/Tokenizers/Decoder.swift +++ b/Sources/Tokenizers/Decoder.swift @@ -1,6 +1,6 @@ // // Decoder.swift -// +// // // Created by Pedro Cuenca on 17/7/23.
// @@ -11,7 +11,7 @@ import Hub public protocol Decoder { func decode(tokens: [String]) -> [String] func callAsFunction(tokens: [String]) -> [String] - + init(config: Config) } @@ -23,7 +23,7 @@ enum DecoderType: String { case Sequence -// case WordPiece + // case WordPiece case ByteLevel case Replace case ByteFallback @@ -40,26 +40,26 @@ struct DecoderFactory { guard let typeName = config.type?.stringValue else { return nil } let type = DecoderType(rawValue: typeName) switch type { - case .Sequence : return DecoderSequence(config: config) - case .ByteLevel : return ByteLevelDecoder(config: config, addedTokens: addedTokens) - case .Replace : return ReplaceDecoder(config: config) + case .Sequence: return DecoderSequence(config: config) + case .ByteLevel: return ByteLevelDecoder(config: config, addedTokens: addedTokens) + case .Replace: return ReplaceDecoder(config: config) case .ByteFallback: return ByteFallbackDecoder(config: config) - case .Fuse : return FuseDecoder(config: config) - case .Strip : return StripDecoder(config: config) - case .Metaspace : return MetaspaceDecoder(config: config) - default : fatalError("Unsupported Decoder type: \(typeName)") + case .Fuse: return FuseDecoder(config: config) + case .Strip: return StripDecoder(config: config) + case .Metaspace: return MetaspaceDecoder(config: config) + default: fatalError("Unsupported Decoder type: \(typeName)") } } } class DecoderSequence: Decoder { let decoders: [Decoder] - + required public init(config: Config) { guard let configs = config.decoders?.arrayValue else { fatalError("No decoders in Sequence") } decoders = configs.compactMap { DecoderFactory.fromConfig(config: $0) } } - + func decode(tokens: [String]) -> [String] { decoders.reduce(tokens) { current, decoder in decoder(tokens: current) @@ -69,26 +69,26 @@ class ByteLevelDecoder: Decoder { let addedTokens: Set<String> - + required public init(config: Config) { self.addedTokens = [] } - + init(config: Config, addedTokens: Set<String>?) { self.addedTokens = addedTokens ?? [] } - + func decode(tokens: [String]) -> [String] { var subTexts: [String] = [] var currentSubText: [String] = [] - + func convertTokensToString(_ tokens: [String]) -> String { let text = tokens.joined(separator: "") - + let utfCodepoints = text.map { byteDecoder[String($0)]! } return String(decoding: utfCodepoints, as: UTF8.self) } - + for token in tokens { if addedTokens.contains(token) { if !currentSubText.isEmpty { @@ -96,26 +96,27 @@ currentSubText = [] } subTexts.append(token) - } else { + } + else { currentSubText.append(token) } } - + if !currentSubText.isEmpty { subTexts.append(convertTokensToString(currentSubText)) } - + return subTexts } } class ReplaceDecoder: Decoder { let pattern: StringReplacePattern? - + required public init(config: Config) { self.pattern = StringReplacePattern.from(config: config) } - + func decode(tokens: [String]) -> [String] { guard let pattern = pattern else { return tokens } return tokens.map { pattern.replace($0) } @@ -124,7 +125,7 @@ class ByteFallbackDecoder: Decoder { required public init(config: Config) {} - + func decode(tokens: [String]) -> [String] { var newTokens: [String] = [] var byteTokens: [Int] = [] @@ -135,13 +136,14 @@ } let startIndex = token.index(token.startIndex, offsetBy: 3) let endIndex = token.index(token.startIndex, offsetBy: 5) - return Int(token[startIndex..<endIndex], radix: 16) + return Int(token[startIndex ..< endIndex], radix: 16)
[String] { [tokens.joined(separator: "")] } @@ -167,16 +169,22 @@ class StripDecoder: Decoder { let content: String let start: Int let stop: Int - + required public init(config: Config) { - guard let content = config.content?.stringValue else { fatalError("Incorrect StripDecoder configuration: can't parse `content`.") } - guard let start = config.start?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `start`.") } - guard let stop = config.stop?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") } + guard let content = config.content?.stringValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `content`.") + } + guard let start = config.start?.intValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `start`.") + } + guard let stop = config.stop?.intValue else { + fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") + } self.content = content self.start = start self.stop = stop } - + func decode(tokens: [String]) -> [String] { tokens.map { token in token.trimmingFromStart(upto: start).trimmingFromEnd(upto: stop) @@ -187,7 +195,7 @@ class StripDecoder: Decoder { class MetaspaceDecoder: Decoder { let addPrefixSpace: Bool let replacement: String - + required public init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false replacement = config.replacement?.stringValue ?? "_" @@ -205,8 +213,8 @@ class MetaspaceDecoder: Decoder { } // We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once) -public extension String { - func trimmingFromStart(character: Character = " ", upto: Int) -> String { +extension String { + public func trimmingFromStart(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 while trimmed < upto && result.first == character { @@ -216,7 +224,7 @@ public extension String { return result } - func trimmingFromEnd(character: Character = " ", upto: Int) -> String { + public func trimmingFromEnd(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 while trimmed < upto && result.last == character { diff --git a/Sources/Tokenizers/Normalizer.swift b/Sources/Tokenizers/Normalizer.swift index 7f1353f..910897e 100644 --- a/Sources/Tokenizers/Normalizer.swift +++ b/Sources/Tokenizers/Normalizer.swift @@ -1,6 +1,6 @@ // // Normalizer.swift -// +// // // Created by Pedro Cuenca on 17/7/23. 
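The `trimmingFromStart(character:upto:)` and `trimmingFromEnd(character:upto:)` helpers above change only access-control style (`public extension` becomes `extension` with `public func`); behavior is identical. A quick usage sketch of what `StripDecoder` relies on:

```swift
// Usage sketch for the String trimming helpers made public above.
// `upto` caps how many matching characters are removed from each end.
let padded = "   hello  "
padded.trimmingFromStart(upto: 2)  // " hello  " (at most 2 leading spaces removed)
padded.trimmingFromEnd(upto: 5)    // "   hello" (only 2 trailing spaces exist)
```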
// @@ -11,7 +11,7 @@ import Hub public protocol Normalizer { func normalize(text: String) -> String func callAsFunction(text: String) -> String - + init(config: Config) } @@ -43,29 +43,29 @@ struct NormalizerFactory { let type = NormalizerType(rawValue: typeName) switch type { case .Sequence: return NormalizerSequence(config: config) - case .Prepend : return PrependNormalizer(config: config) - case .Replace : return ReplaceNormalizer(config: config) - case .Lowercase : return LowercaseNormalizer(config: config) - case .NFD : return NFDNormalizer(config: config) - case .NFC : return NFCNormalizer(config: config) - case .NFKD : return NFKDNormalizer(config: config) - case .NFKC : return NFKCNormalizer(config: config) - case .Bert : return BertNormalizer(config: config) - case .Precompiled : return PrecompiledNormalizer(config: config) - case .StripAccents : return StripAccentsNormalizer(config: config) - default : fatalError("Unsupported Normalizer type: \(typeName)") + case .Prepend: return PrependNormalizer(config: config) + case .Replace: return ReplaceNormalizer(config: config) + case .Lowercase: return LowercaseNormalizer(config: config) + case .NFD: return NFDNormalizer(config: config) + case .NFC: return NFCNormalizer(config: config) + case .NFKD: return NFKDNormalizer(config: config) + case .NFKC: return NFKCNormalizer(config: config) + case .Bert: return BertNormalizer(config: config) + case .Precompiled: return PrecompiledNormalizer(config: config) + case .StripAccents: return StripAccentsNormalizer(config: config) + default: fatalError("Unsupported Normalizer type: \(typeName)") } } } class NormalizerSequence: Normalizer { let normalizers: [Normalizer] - + required public init(config: Config) { guard let configs = config.normalizers?.arrayValue else { fatalError("No normalizers in Sequence") } normalizers = configs.compactMap { NormalizerFactory.fromConfig(config: $0) } } - + public func normalize(text: String) -> String { normalizers.reduce(text) { current, normalizer in normalizer(text: current) @@ -75,11 +75,11 @@ class NormalizerSequence: Normalizer { class PrependNormalizer: Normalizer { let prepend: String - + required public init(config: Config) { prepend = config.prepend?.stringValue ?? "" } - + public func normalize(text: String) -> String { return prepend + text } @@ -87,11 +87,11 @@ class PrependNormalizer: Normalizer { class ReplaceNormalizer: Normalizer { let pattern: StringReplacePattern? 
- + required public init(config: Config) { self.pattern = StringReplacePattern.from(config: config) } - + public func normalize(text: String) -> String { guard let pattern = pattern else { return text } return pattern.replace(text) @@ -106,7 +106,7 @@ class LowercaseNormalizer: Normalizer { } } -class NFDNormalizer: Normalizer { +class NFDNormalizer: Normalizer { required public init(config: Config) {} public func normalize(text: String) -> String { @@ -122,7 +122,7 @@ class NFCNormalizer: Normalizer { } } -class NFKDNormalizer: Normalizer { +class NFKDNormalizer: Normalizer { required init(config: Config) {} func normalize(text: String) -> String { @@ -172,17 +172,16 @@ class BertNormalizer: Normalizer { private func cleanText(text: String) -> String { text.map { c in guard let scalar = c.unicodeScalars.first, - scalar.value != 0x0, - scalar.value != 0xFFFD, - !isControl(scalar) + scalar.value != 0x0, + scalar.value != 0xFFFD, + !isControl(scalar) else { return "\(c)" } // Replace whitespace: \t, \n, \r - if scalar.value == 0x009 || - scalar.value == 0x00A || - scalar.value == 0x000D { + if scalar.value == 0x009 || scalar.value == 0x00A || scalar.value == 0x000D { return " " - } else { + } + else { return "\(c)" } } @@ -193,7 +192,8 @@ class BertNormalizer: Normalizer { if c.value == 0x009 || c.value == 0x00A || c.value == 0x000D { // Except \t, \n, \r that will be spaces. return false - } else { + } + else { // https://unicode.org/reports/tr44/#GC_Values_Table // Other Cc | Cf | Cs | Co | Cn return isOther(c.properties.generalCategory) @@ -201,19 +201,16 @@ class BertNormalizer: Normalizer { } private func isOther(_ c: Unicode.GeneralCategory) -> Bool { - c == .control || - c == .format || - c == .surrogate || - c == .privateUse || - c == .unassigned + c == .control || c == .format || c == .surrogate || c == .privateUse || c == .unassigned } private func handleChineseChars(text: String) -> String { text.map { c in if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) { " \(c) " - } else { - "\(c)" + } + else { + "\(c)" } } .joined() @@ -221,9 +218,11 @@ class BertNormalizer: Normalizer { private func stripAccents(text: String) -> String { text.decomposedStringWithCanonicalMapping - .filter { $0.unicodeScalars.allSatisfy { scalar in - !(0x0300 <= scalar.value && scalar.value <= 0x036F) - }} + .filter { + $0.unicodeScalars.allSatisfy { scalar in + !(0x0300 <= scalar.value && scalar.value <= 0x036F) + } + } } } @@ -242,10 +241,10 @@ class PrecompiledNormalizer: Normalizer { for scalar in text.unicodeScalars { switch scalar.value { - case 0x0001...0x0008, 0x000B, 0x000E...0x001F, 0x007F, 0x008F, 0x009F: + case 0x0001 ... 0x0008, 0x000B, 0x000E ... 0x001F, 0x007F, 0x008F, 0x009F: // Non-printing control characters output.append("") - case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581, 0xFEFF, 0xFFFD: + case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B ... 
0x200F, 0x2028, 0x2029, 0x2581, 0xFEFF, 0xFFFD: // Separators output.append(" ") case 0xFF5E: @@ -257,11 +256,13 @@ class PrecompiledNormalizer: Normalizer { } if hasFullwidthTilde { - return output + return + output .split(by: "\u{FF5E}") .map({ $0.precomposedStringWithCompatibilityMapping }) .joined(separator: "\u{FF5E}") - } else { + } + else { return output.precomposedStringWithCompatibilityMapping } } @@ -285,7 +286,12 @@ extension StringReplacePattern { switch self { case .regexp(let regexp, let replacement): let range = NSRange(text.startIndex..., in: text) - let replaced = regexp.stringByReplacingMatches(in: text, options: [], range: range, withTemplate: replacement) + let replaced = regexp.stringByReplacingMatches( + in: text, + options: [], + range: range, + withTemplate: replacement + ) return replaced case .string(let toReplace, let replacement): return text.replacingOccurrences(of: toReplace, with: replacement) diff --git a/Sources/Tokenizers/PostProcessor.swift b/Sources/Tokenizers/PostProcessor.swift index 7292006..cc41ad2 100644 --- a/Sources/Tokenizers/PostProcessor.swift +++ b/Sources/Tokenizers/PostProcessor.swift @@ -1,6 +1,6 @@ // // PostProcessor.swift -// +// // // Created by Pedro Cuenca on 17/7/23. // @@ -11,7 +11,7 @@ import Hub public protocol PostProcessor { func postProcess(tokens: [String], tokensPair: [String]?) -> [String] func callAsFunction(tokens: [String], tokensPair: [String]?) -> [String] - + init(config: Config) } @@ -34,9 +34,9 @@ struct PostProcessorFactory { let type = PostProcessorType(rawValue: typeName) switch type { case .TemplateProcessing: return TemplateProcessing(config: config) - case .ByteLevel : return ByteLevelPostProcessor(config: config) - case .RobertaProcessing : return RobertaProcessing(config: config) - default : fatalError("Unsupported PostProcessor type: \(typeName)") + case .ByteLevel: return ByteLevelPostProcessor(config: config) + case .RobertaProcessing: return RobertaProcessing(config: config) + default: fatalError("Unsupported PostProcessor type: \(typeName)") } } } @@ -44,26 +44,28 @@ struct PostProcessorFactory { class TemplateProcessing: PostProcessor { let single: [Config] let pair: [Config] - + required public init(config: Config) { guard let single = config.single?.arrayValue else { fatalError("Missing `single` processor configuration") } guard let pair = config.pair?.arrayValue else { fatalError("Missing `pair` processor configuration") } - + self.single = single self.pair = pair } - + func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { let config = tokensPair == nil ? single : pair - + var toReturn: [String] = [] for item in config { if let specialToken = item.SpecialToken { toReturn.append(specialToken.id!.stringValue!) - } else if let sequence = item.Sequence { + } + else if let sequence = item.Sequence { if sequence.id?.stringValue == "A" { toReturn += tokens - } else if sequence.id?.stringValue == "B" { + } + else if sequence.id?.stringValue == "B" { toReturn += tokensPair! } } @@ -93,7 +95,7 @@ class RobertaProcessing: PostProcessor { self.trimOffset = config.trimOffset?.boolValue ?? true self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true } - + func postProcess(tokens: [String], tokensPair: [String]?) 
-> [String] { var outTokens = tokens var tokensPair = tokensPair @@ -101,7 +103,8 @@ class RobertaProcessing: PostProcessor { if addPrefixSpace { outTokens = outTokens.map({ trimExtraSpaces(token: $0) }) tokensPair = tokensPair?.map({ trimExtraSpaces(token: $0) }) - } else { + } + else { outTokens = outTokens.map({ $0.trimmingCharacters(in: .whitespaces) }) tokensPair = tokensPair?.map({ $0.trimmingCharacters(in: .whitespaces) }) } @@ -124,7 +127,7 @@ class RobertaProcessing: PostProcessor { let suffixOffset = findSuffixIndex(text: token) let prefixIndex = token.index(token.startIndex, offsetBy: prefixOffset) let suffixIndex = token.index(token.startIndex, offsetBy: token.count - suffixOffset) - return String(token[prefixIndex..<suffixIndex]) + return String(token[prefixIndex ..< suffixIndex]) } private func findPrefixIndex(text: String) -> Int { diff --git a/Sources/Tokenizers/PreTokenizer.swift b/Sources/Tokenizers/PreTokenizer.swift index 5e57be3..001e1af 100644 --- a/Sources/Tokenizers/PreTokenizer.swift +++ b/Sources/Tokenizers/PreTokenizer.swift @@ -1,6 +1,6 @@ // // PreTokenizer.swift -// +// // // Created by Pedro Cuenca on 18/7/23. // @@ -25,11 +25,11 @@ extension PreTokenizer { func callAsFunction(texts: [String]) -> [String] { return preTokenize(texts: texts) } - + func callAsFunction(text: String) -> [String] { return preTokenize(text: text) } - + } enum PreTokenizerType: String { @@ -51,7 +51,7 @@ struct PreTokenizerFactory { guard let typeName = config.type?.stringValue else { return nil } let type = PreTokenizerType(rawValue: typeName) switch type { - case .Sequence : return PreTokenizerSequence(config: config) + case .Sequence: return PreTokenizerSequence(config: config) case .ByteLevel: return ByteLevelPreTokenizer(config: config) case .Punctuation: return PunctuationPreTokenizer(config: config) case .Digits: return DigitsPreTokenizer(config: config) @@ -65,12 +65,12 @@ class PreTokenizerSequence: PreTokenizer { let preTokenizers: [PreTokenizer] - + required init(config: Config) { guard let configs = config.pretokenizers?.arrayValue else { fatalError("No pretokenizers in Sequence") } preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) } } - + func preTokenize(text: String) -> [String] { preTokenizers.reduce([text]) { current, preTokenizer in preTokenizer(texts: current) @@ -94,40 +94,40 @@ class WhitespacePreTokenizer: PreTokenizer { class MetaspacePreTokenizer: PreTokenizer { /// Whether to add a prefix space to the first token let addPrefixSpace: Bool - + /// Replacement character let replacement: String - + /// Optional string representation of the replacement character. let stringReplacement: String - + enum PrependScheme: String { case first case never case always - + static var defaultScheme: PrependScheme { .always } static func from(rawValue value: String?) -> PrependScheme { guard let value = value else { return defaultScheme } return PrependScheme(rawValue: value) ?? defaultScheme } } - + /// The metaspace prepend scheme, see https://github.com/huggingface/tokenizers/pull/1357 let prependScheme: PrependScheme - + required init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false replacement = config.replacement?.stringValue ?? " " stringReplacement = config.strRep?.stringValue ??
replacement prependScheme = PrependScheme.from(rawValue: config.prependScheme?.stringValue) } - + // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114 // https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153 func preTokenize(text: String) -> [String] { let normalized = text.replacingOccurrences(of: " ", with: stringReplacement) - + // We add a prefix space if: // (1) The addPrefixSpace option is enabled and the normalized // token does not already start with the replacement character. @@ -145,7 +145,7 @@ class MetaspacePreTokenizer: PreTokenizer { prepend = stringReplacement } } - + // Split in `MergedWithNext` mode, although usually the input to this function is already pre-tokenized // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L127 return (prepend + normalized).split(by: replacement, behavior: .mergedWithNext) @@ -157,13 +157,13 @@ class ByteLevelPreTokenizer: PreTokenizer { let trimOffsets: Bool let useRegex: Bool let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# - + required init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false trimOffsets = config.trimOffsets?.boolValue ?? true useRegex = config.useRegex?.boolValue ?? true } - + func preTokenize(text: String) -> [String] { // Split on whitespace and punctuation let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text] @@ -252,27 +252,34 @@ extension String { func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range<Index>] { var result: [Range<Index>] = [] var start = startIndex - while let range = range(of: string, options: options, range: start..<endIndex) { + while let range = range(of: string, options: options, range: start ..< endIndex) { result.append(range) start = range.lowerBound < range.upperBound ? range.upperBound : index(range.lowerBound, offsetBy: 1, limitedBy: endIndex) ?? endIndex } return result } - func split(by string: String, options: CompareOptions = .regularExpression, includeSeparators: Bool = false, omittingEmptySubsequences: Bool = true) -> [String] { + + func split( + by string: String, + options: CompareOptions = .regularExpression, + includeSeparators: Bool = false, + omittingEmptySubsequences: Bool = true + ) -> [String] { var result: [String] = [] var start = startIndex - while let range = range(of: string, options: options, range: start..<endIndex) { + while let range = range(of: string, options: options, range: start ..< endIndex) { if !omittingEmptySubsequences || !self[start..<range.lowerBound].isEmpty { result.append(String(self[start..<range.lowerBound])) } if includeSeparators { result.append(String(self[range])) } start = range.upperBound } if !omittingEmptySubsequences || start < endIndex { result.append(String(self[start...])) } return result @@ -287,8 +294,10 @@ public enum SplitDelimiterBehavior { case mergedWithPrevious case mergedWithNext } -public extension String { - func split(by string: String, options: CompareOptions = .regularExpression, behavior: SplitDelimiterBehavior) -> [String] { +extension String { + public func split(by string: String, options: CompareOptions = .regularExpression, behavior: SplitDelimiterBehavior) + -> [String] + { func mergedWithNext(ranges: [Range<String.Index>]) -> [Range<String.Index>] { var merged: [Range<String.Index>] = [] var currentStart = startIndex for range in ranges { if range.lowerBound == startIndex { continue } let mergedRange = currentStart..<range.lowerBound currentStart = range.lowerBound merged.append(mergedRange) } if currentStart < endIndex { merged.append(currentStart..<endIndex) } return merged } func mergedWithPrevious(ranges: [Range<String.Index>]) -> [Range<String.Index>] { var merged: [Range<String.Index>] = [] var currentStart = startIndex for range in ranges { - let mergedRange = currentStart..<range.upperBound + let mergedRange = currentStart ..< range.upperBound merged.append(mergedRange) currentStart = range.upperBound } diff --git a/Sources/Tokenizers/TokenLattice.swift b/Sources/Tokenizers/TokenLattice.swift --- a/Sources/Tokenizers/TokenLattice.swift +++ b/Sources/Tokenizers/TokenLattice.swift @@ -60,10 +60,10 @@ extension TokenLattice { func viterbi() -> [TokenLatticeNode] { - for offset in 0...count { + for offset in 0 ... count { guard beginNodes[offset].count > 0 else { return [] } - + for rnode in beginNodes[offset] { rnode.prev = nil var bestScore: Float = 0 @@ -75,27 +75,27 @@ bestScore = score } } - + if bestNode != nil { rnode.prev = bestNode rnode.backtraceScore = bestScore } } } - + let root = beginNodes[count][0] guard let prev = root.prev else { return [] } // TODO: the reference implementations have a few more clones here: verify var result: [TokenLatticeNode] = [] var node = prev //.clone() while node.prev != nil { result.append(node.clone()) node = node.prev!
//.clone() } return result.reversed() } - + /// Returns the substring of the sentence to be tokenized associated to the specified node /// /// - Parameters: @@ -105,16 +105,16 @@ func piece(_ node: TokenLatticeNode) -> any StringProtocol { let start = sentence.index(sentence.startIndex, offsetBy: node.startOffset) let end = sentence.index(start, offsetBy: node.length) - return sentence[start..<end] + return sentence[start ..< end] } } @@ -129,7 +129,14 @@ extension TokenLatticeNode { func clone() -> TokenLatticeNode { - TokenLatticeNode(tokenId: tokenId, startOffset: startOffset, length: length, score: score, prev: prev, backtraceScore: backtraceScore) + TokenLatticeNode( + tokenId: tokenId, + startOffset: startOffset, + length: length, + score: score, + prev: prev, + backtraceScore: backtraceScore + ) } } diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index fe29846..61f4713 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -5,10 +5,10 @@ // Created by Pedro Cuenca on 6/5/23. // -import Hub import Foundation +import Hub -enum TokenizerError : Error { +enum TokenizerError: Error { case missingConfig case missingTokenizerClassInConfig case unsupportedTokenizer(String) @@ -38,45 +38,47 @@ public protocol TokenizingModel { var unknownTokenId: Int? { get } } -public extension TokenizingModel { - func callAsFunction(_ text: String) -> [String] { +extension TokenizingModel { + public func callAsFunction(_ text: String) -> [String] { tokenize(text: text) } - func convertTokensToIds(_ tokens: [String]) -> [Int?] { + public func convertTokensToIds(_ tokens: [String]) -> [Int?] { return tokens.map { convertTokenToId($0) } } - func convertIdsToTokens(_ ids: [Int]) -> [String?] { + public func convertIdsToTokens(_ ids: [Int]) -> [String?] { return ids.map { convertIdToToken($0) } } } /// A tokenizer model that is set up with Hub configuration data public protocol PreTrainedTokenizerModel: TokenizingModel { - init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws + init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws } struct TokenizerModel { - static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [ - "BertTokenizer" : BertTokenizer.self, - "CodeGenTokenizer" : CodeGenTokenizer.self, - "CodeLlamaTokenizer" : CodeLlamaTokenizer.self, - "FalconTokenizer" : FalconTokenizer.self, - "GemmaTokenizer" : GemmaTokenizer.self, - "GPT2Tokenizer" : GPT2Tokenizer.self, - "LlamaTokenizer" : LlamaTokenizer.self, - "T5Tokenizer" : T5Tokenizer.self, - "WhisperTokenizer" : WhisperTokenizer.self, - "CohereTokenizer" : CohereTokenizer.self, - "PreTrainedTokenizer": BPETokenizer.self + static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [ + "BertTokenizer": BertTokenizer.self, + "CodeGenTokenizer": CodeGenTokenizer.self, + "CodeLlamaTokenizer": CodeLlamaTokenizer.self, + "FalconTokenizer": FalconTokenizer.self, + "GemmaTokenizer": GemmaTokenizer.self, + "GPT2Tokenizer": GPT2Tokenizer.self, + "LlamaTokenizer": LlamaTokenizer.self, + "T5Tokenizer": T5Tokenizer.self, + "WhisperTokenizer": WhisperTokenizer.self, + "CohereTokenizer": CohereTokenizer.self, + "PreTrainedTokenizer": BPETokenizer.self, ] static func unknownToken(from tokenizerConfig: Config) -> String? { return tokenizerConfig.unkToken?.content?.stringValue ??
tokenizerConfig.unkToken?.stringValue } - public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws -> TokenizingModel { + public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws + -> TokenizingModel + { guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else { throw TokenizerError.missingTokenizerClassInConfig } @@ -87,7 +89,11 @@ struct TokenizerModel { throw TokenizerError.unsupportedTokenizer(tokenizerName) } - return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) + return try tokenizerClass.init( + tokenizerConfig: tokenizerConfig, + tokenizerData: tokenizerData, + addedTokens: addedTokens + ) } } @@ -115,16 +121,16 @@ public protocol Tokenizer { var unknownTokenId: Int? { get } } -public extension Tokenizer { - func callAsFunction(_ text: String) -> [Int] { +extension Tokenizer { + public func callAsFunction(_ text: String) -> [Int] { encode(text: text) } - func convertTokensToIds(_ tokens: [String]) -> [Int?] { + public func convertTokensToIds(_ tokens: [String]) -> [Int?] { return tokens.map { convertTokenToId($0) } } - func convertIdsToTokens(_ ids: [Int]) -> [String?] { + public func convertIdsToTokens(_ ids: [Int]) -> [String?] { return ids.map { convertIdToToken($0) } } } @@ -150,11 +156,13 @@ public class PreTrainedTokenizer: Tokenizer { private let cleanUpTokenizationSpaces: Bool required public init(tokenizerConfig: Config, tokenizerData: Config) throws { - var addedTokens: [String : Int] = [:] - var specialTokens: [String : Int] = [:] + var addedTokens: [String: Int] = [:] + var specialTokens: [String: Int] = [:] for addedToken in tokenizerData.addedTokens?.arrayValue ?? [] { guard let id = addedToken.id?.intValue else { continue /* malformed: token with no id */ } - guard let content = addedToken.content?.stringValue else { continue /* malformed: token with no content */ } + guard let content = addedToken.content?.stringValue else { + continue /* malformed: token with no content */ + } addedTokens[content] = id if addedToken.special?.boolValue ?? false { @@ -171,7 +179,11 @@ public class PreTrainedTokenizer: Tokenizer { self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder) self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? 
true - model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) + model = try TokenizerModel.from( + tokenizerConfig: tokenizerConfig, + tokenizerData: tokenizerData, + addedTokens: addedTokens + ) } func preTokenize(_ text: String) -> [String] { @@ -256,7 +268,7 @@ extension AutoTokenizer { return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } - + public static func from( modelFolder: URL, hubApi: HubApi = .shared @@ -264,20 +276,20 @@ extension AutoTokenizer { let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi) guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig } let tokenizerData = try await config.tokenizerData - + return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } } // MARK: - Tokenizer model classes -class GPT2Tokenizer : BPETokenizer {} -class FalconTokenizer : BPETokenizer {} -class LlamaTokenizer : BPETokenizer {} -class CodeGenTokenizer : BPETokenizer {} -class WhisperTokenizer : BPETokenizer {} -class GemmaTokenizer : BPETokenizer {} +class GPT2Tokenizer: BPETokenizer {} +class FalconTokenizer: BPETokenizer {} +class LlamaTokenizer: BPETokenizer {} +class CodeGenTokenizer: BPETokenizer {} +class WhisperTokenizer: BPETokenizer {} +class GemmaTokenizer: BPETokenizer {} class CodeLlamaTokenizer: BPETokenizer {} -class CohereTokenizer : BPETokenizer {} +class CohereTokenizer: BPETokenizer {} -class T5Tokenizer : UnigramTokenizer {} +class T5Tokenizer: UnigramTokenizer {} diff --git a/Sources/Tokenizers/Trie.swift b/Sources/Tokenizers/Trie.swift index 6c7f79c..c480824 100644 --- a/Sources/Tokenizers/Trie.swift +++ b/Sources/Tokenizers/Trie.swift @@ -10,21 +10,22 @@ import Foundation public struct Trie<T: Hashable> { public typealias Node = TrieNode<T> - + var root: Node - + public init(root: Node? = nil) { self.root = root ?? Node() } } -public extension Trie { - func insert(_ element: any Sequence<T>) { +extension Trie { + public func insert(_ element: any Sequence<T>) { var node = root for item in element { if let child = node.children[item] { node = child - } else { + } + else { let child = Node() node.children[item] = child node = child @@ -32,14 +33,14 @@ } node.isLeaf = true } - - func append(contentsOf container: any Sequence<any Sequence<T>>) { + + public func append(contentsOf container: any Sequence<any Sequence<T>>) { for t in container { insert(t) } } - + /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) /// Returns an array - func commonPrefixSearch(_ text: any Sequence<T>) -> [[T]] { + public func commonPrefixSearch(_ text: any Sequence<T>) -> [[T]] { var node = root var seqs: [[T]] = [] var seq: [T] = [] @@ -53,17 +54,17 @@ } return seqs } - + /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) /// Returns an iterator - func commonPrefixSearchIterator(_ text: any Sequence<T>) -> LeavesWithCommonPrefixIterator<T> { + public func commonPrefixSearchIterator(_ text: any Sequence<T>) -> LeavesWithCommonPrefixIterator<T> { return LeavesWithCommonPrefixIterator(node: root, text: text) } } -public extension Trie { +extension Trie { // Only used for testing, could migrate to collection - func get(_ element: any Sequence<T>) -> Node? { + public func get(_ element: any Sequence<T>) -> Node?
{ var node = root for item in element { guard let child = node.children[item] else { return nil } node = child } return node } @@ -79,12 +80,12 @@ public class TrieNode<T: Hashable> { var children: [T: TrieNode] = [:] } -public struct LeavesWithCommonPrefixIterator<T: Hashable> : Sequence, IteratorProtocol { +public struct LeavesWithCommonPrefixIterator<T: Hashable>: Sequence, IteratorProtocol { var node: TrieNode<T> var text: any Sequence<T> var seq: [T] = [] lazy var iterator = text.makeIterator() as any IteratorProtocol<T> - + public mutating func next() -> [T]? { while true { guard let item = iterator.next() else { return nil } diff --git a/Sources/Tokenizers/UnigramTokenizer.swift b/Sources/Tokenizers/UnigramTokenizer.swift index 2fe754d..9be73e9 100644 --- a/Sources/Tokenizers/UnigramTokenizer.swift +++ b/Sources/Tokenizers/UnigramTokenizer.swift @@ -14,72 +14,72 @@ class UnigramTokenizer: PreTrainedTokenizerModel { var score: Float } let vocab: [SentencePieceToken] - + let unknownPiece: SentencePieceToken var unknownTokenScore: Float { unknownPiece.score } - + public let unknownTokenId: Int? public var unknownToken: String? { unknownPiece.token } - + let minScore: Float let tokensToIds: [String: Int] - + let bosToken: String? = "<s>" let bosTokenId: Int? let eosToken: String? let eosTokenId: Int? - + private let trie: Trie<Character> - - required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { + + required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { guard let configVocab = tokenizerData.model?.vocab?.value as? [[Any]] else { throw TokenizerError.missingVocab } - + vocab = try configVocab.map { piece in guard let token = piece.first as? String else { throw TokenizerError.malformedVocab } guard let score = piece.last as? Float else { throw TokenizerError.malformedVocab } return SentencePieceToken(token: token, score: score) } - + minScore = vocab.reduce(999) { partial, token in min(partial, token.score) } - + guard let unknownTokenId = tokenizerData.model?.unkId?.intValue else { throw TokenizerError.malformedVocab } self.unknownTokenId = unknownTokenId self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10) - + tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { ($1, $0) }) - bosTokenId = tokensToIds[bosToken!] // May be nil - + bosTokenId = tokensToIds[bosToken!] // May be nil + eosToken = tokenizerConfig.eosToken?.stringValue eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!] - + trie = Trie() trie.append(contentsOf: vocab.map { $0.token }) - + // TODO: set fuse_unk to true } - + func convertTokenToId(_ token: String) -> Int? { return tokensToIds[token] ?? self.unknownTokenId } - + func convertIdToToken(_ id: Int) -> String? { return vocab[id].token } - + func tokenize(text: String) -> [String] { var lattice = TokenLattice(sentence: text, bosTokenId: bosTokenId ?? 0, eosTokenId: eosTokenId ?? 0) - + // Populate nodes let sentence = lattice.sentence var beginPos = 0 while beginPos < sentence.count { let mblen = 1 var hasSingleNode = false - + let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos) for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) { guard let tokenId = tokensToIds[token] else { fatalError("Token not in vocab: \(token)") } @@ -90,7 +90,12 @@ class UnigramTokenizer: PreTrainedTokenizerModel { } if !hasSingleNode { - lattice.insert(startOffset: beginPos, length: mblen, score: unknownTokenScore, tokenId: unknownTokenId ??
0) + lattice.insert( + startOffset: beginPos, + length: mblen, + score: unknownTokenScore, + tokenId: unknownTokenId ?? 0 + ) } beginPos += mblen } diff --git a/Sources/Tokenizers/Utils.swift b/Sources/Tokenizers/Utils.swift index 78687ce..13e6dab 100644 --- a/Sources/Tokenizers/Utils.swift +++ b/Sources/Tokenizers/Utils.swift @@ -17,7 +17,7 @@ struct Utils { print("[\(label)] \(diff)ms") return result } - + /// Time a block in seconds and return (output, time) static func time<T>(_ block: () -> T) -> (T, Double) { let startTime = CFAbsoluteTimeGetCurrent() @@ -25,24 +25,24 @@ struct Utils { let diff = CFAbsoluteTimeGetCurrent() - startTime return (result, diff) } - + /// Return unix timestamp in ms static func dateNow() -> Int64 { // Use `Int` when we don't support 32-bits devices/OSes anymore. // Int crashes on iPhone 5c. return Int64(Date().timeIntervalSince1970 * 1000) } - + /// Clamp a val to [min, max] static func clamp<T: Comparable>(_ val: T, _ vmin: T, _ vmax: T) -> T { return min(max(vmin, val), vmax) } - + /// Fake func that can throw. static func fakeThrowable<T>(_ input: T) throws -> T { return input } - + /// Substring static func substr(_ s: String, _ r: Range<Int>) -> String? { let stringCount = s.count @@ -51,11 +51,11 @@ } let startIndex = s.index(s.startIndex, offsetBy: r.lowerBound) let endIndex = s.index(startIndex, offsetBy: r.upperBound - r.lowerBound) - return String(s[startIndex..<endIndex]) + return String(s[startIndex ..< endIndex]) } /// Invert a (k, v) dictionary - static func invert<K, V>(_ dict: Dictionary<K, V>) -> Dictionary<V, K> { + static func invert<K, V>(_ dict: [K: V]) -> [V: K] { var inverted: [V: K] = [:] for (k, v) in dict { inverted[v] = k @@ -66,14 +66,9 @@ struct Utils { /// Checks if a character is considered Chinese /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) static func isChineseChar(_ c: UnicodeScalar) -> Bool { - (c.value >= 0x4E00 && c.value <= 0x9FFF) || - (c.value >= 0x3400 && c.value <= 0x4DBF) || - (c.value >= 0x20000 && c.value <= 0x2A6DF) || - (c.value >= 0x2A700 && c.value <= 0x2B73F) || - (c.value >= 0x2B740 && c.value <= 0x2B81F) || - (c.value >= 0x2B820 && c.value <= 0x2CEAF) || - (c.value >= 0xF900 && c.value <= 0xFAFF) || - (c.value >= 0x2F800 && c.value <= 0x2FA1F) + (c.value >= 0x4E00 && c.value <= 0x9FFF) || (c.value >= 0x3400 && c.value <= 0x4DBF) + || (c.value >= 0x20000 && c.value <= 0x2A6DF) || (c.value >= 0x2A700 && c.value <= 0x2B73F) + || (c.value >= 0x2B740 && c.value <= 0x2B81F) || (c.value >= 0x2B820 && c.value <= 0x2CEAF) + || (c.value >= 0xF900 && c.value <= 0xFAFF) || (c.value >= 0x2F800 && c.value <= 0x2FA1F) } } - diff --git a/Sources/TransformersCLI/main.swift b/Sources/TransformersCLI/main.swift index a040946..0421da5 100644 --- a/Sources/TransformersCLI/main.swift +++ b/Sources/TransformersCLI/main.swift @@ -1,9 +1,8 @@ import ArgumentParser import CoreML import Foundation - -import Models import Generation +import Models @available(iOS 16.2, macOS 13.1, *) struct TransformersCLI: ParsableCommand { @@ -23,7 +22,7 @@ struct TransformersCLI: ParsableCommand { @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") var computeUnits: ComputeUnits = .cpuAndGPU - + func generate(model: LanguageModel, config: GenerationConfig, prompt: String, printOutput: Bool = true) { let semaphore = DispatchSemaphore(value: 0) Task.init { [config] in @@ -47,7 +46,8 @@ struct TransformersCLI: ParsableCommand { print("") print("\(tps.formatted("%.2f")) tokens/s, total time: \(completionTime.formatted("%.2f"))s") } - } catch { + } + catch { print("Error \(error)") } } @@ -56,11 +56,11 @@ struct
TransformersCLI: ParsableCommand { func compile(at url: URL) throws -> URL { #if os(watchOS) - fatalError("Model compilation is not supported on watchOS") + fatalError("Model compilation is not supported on watchOS") #else - if url.pathExtension == "mlmodelc" { return url } - print("Compiling model \(url)") - return try MLModel.compileModel(at: url) + if url.pathExtension == "mlmodelc" { return url } + print("Compiling model \(url)") + return try MLModel.compileModel(at: url) #endif } @@ -69,15 +69,15 @@ struct TransformersCLI: ParsableCommand { let compiledURL = try compile(at: url) print("Loading model \(compiledURL)") let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits) - + // Using greedy generation for now var config = model.defaultGenerationConfig config.doSample = false config.maxNewTokens = maxLength - + print("Warming up...") generate(model: model, config: config, prompt: prompt, printOutput: false) - + print("Generating") generate(model: model, config: config, prompt: prompt) } @@ -98,7 +98,8 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable { if #available(iOS 16.2, macOS 13.1, *) { TransformersCLI.main() -} else { +} +else { print("Unsupported OS") } diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 4ab5d04..8a221a5 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -4,9 +4,10 @@ // Created by Pedro Cuenca on 20231230. // -@testable import Hub import XCTest +@testable import Hub + class HubApiTests: XCTestCase { override func setUp() { // Put setup code here. This method is called before the invocation of each test method in the class. @@ -22,7 +23,8 @@ class HubApiTests: XCTestCase { do { let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml") XCTAssertEqual(filenames.count, 13) - } catch { + } + catch { XCTFail("\(error)") } } @@ -30,7 +32,10 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalWithGlob() async { do { try await { - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.json") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", + matching: "*.json" + ) XCTAssertEqual( Set(filenames), Set([ @@ -44,13 +49,17 @@ class HubApiTests: XCTestCase { // Glob patterns are case sensitive try await { - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.JSON") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", + matching: "*.JSON" + ) XCTAssertEqual( filenames, [] ) }() - } catch { + } + catch { XCTFail("\(error)") } } @@ -58,7 +67,10 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalFromDirectories() async { do { // Contents of all directories matching a pattern - let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml", matching: "*.mlpackage/*") + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", + matching: "*.mlpackage/*" + ) XCTAssertEqual( Set(filenames), Set([ @@ -70,7 +82,8 @@ class HubApiTests: XCTestCase { ]) ) - } catch { + } + catch { XCTFail("\(error)") } } @@ -78,12 +91,16 @@ class HubApiTests: XCTestCase { func testFilenameRetrievalWithMultiplePatterns() async { do { let patterns = ["config.json", "tokenizer.json", "tokenizer_*.json"] - let filenames = try await Hub.getFilenames(from: 
"coreml-projects/Llama-2-7b-chat-coreml", matching: patterns) + let filenames = try await Hub.getFilenames( + from: "coreml-projects/Llama-2-7b-chat-coreml", + matching: patterns + ) XCTAssertEqual( Set(filenames), Set(["config.json", "tokenizer.json", "tokenizer_config.json"]) ) - } catch { + } + catch { XCTFail("\(error)") } } @@ -101,7 +118,8 @@ class SnapshotDownloadTests: XCTestCase { override func tearDown() { do { try FileManager.default.removeItem(at: downloadDestination) - } catch { + } + catch { print("Can't remove test download destination \(downloadDestination), error: \(error)") } } @@ -110,14 +128,20 @@ class SnapshotDownloadTests: XCTestCase { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") - if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) { + if let enumerator = FileManager.default.enumerator( + at: url, + includingPropertiesForKeys: [.isRegularFileKey], + options: [.skipsHiddenFiles], + errorHandler: nil + ) { for case let fileURL as URL in enumerator { do { let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey]) if resourceValues.isRegularFile == true { filenames.append(String(fileURL.path.suffix(from: prefix.endIndex))) } - } catch { + } + catch { print("Error reading file resources: \(error)") } } @@ -154,7 +178,10 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadInBackground() async throws { let hubApi = HubApi(downloadBase: downloadDestination, useBackgroundSession: true) var lastProgress: Progress? = nil - let downloadedTo = try await hubApi.snapshot(from: repo, matching: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json") { progress in + let downloadedTo = try await hubApi.snapshot( + from: repo, + matching: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" + ) { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") lastProgress = progress @@ -167,7 +194,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedFilenames), Set([ - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" ]) ) } diff --git a/Tests/HubTests/HubTests.swift b/Tests/HubTests/HubTests.swift index 9a1caea..05a827c 100644 --- a/Tests/HubTests/HubTests.swift +++ b/Tests/HubTests/HubTests.swift @@ -5,8 +5,8 @@ // import XCTest -@testable import Hub +@testable import Hub class HubTests: XCTestCase { let downloadDestination: URL = { @@ -19,7 +19,8 @@ class HubTests: XCTestCase { override func tearDown() { do { try FileManager.default.removeItem(at: downloadDestination) - } catch { + } + catch { print("Can't remove test download destination \(downloadDestination), error: \(error)") } } @@ -30,28 +31,28 @@ class HubTests: XCTestCase { do { let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi) let config = try await configLoader.modelConfig - + // Test leaf value (Int) guard let eos = config.eos_token_id?.intValue else { XCTFail("nil leaf value (Int)") return } XCTAssertEqual(eos, 1) - + // Test leaf value (String) guard let modelType = config.model_type?.stringValue else { XCTFail("nil leaf value (String)") return } XCTAssertEqual(modelType, "t5") - + // Test leaf value (Array) guard let architectures = config.architectures?.value as? 
[String] else { XCTFail("nil array") return } XCTAssertEqual(architectures, ["T5ForConditionalGeneration"]) - + // Test nested wrapper guard let taskParams = config.task_specific_params else { XCTFail("nil nested wrapper") @@ -64,11 +65,12 @@ class HubTests: XCTestCase { return } XCTAssertEqual(summarizationMaxLength, 200) - } catch { + } + catch { XCTFail("Cannot download test configuration from the Hub: \(error)") } } - + func testConfigCamelCase() async { do { let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi) @@ -80,20 +82,21 @@ class HubTests: XCTestCase { return } XCTAssertEqual(eos, 1) - + // Test leaf value (String) guard let modelType = config.modelType?.stringValue else { XCTFail("nil leaf value (String)") return } XCTAssertEqual(modelType, "t5") - + guard let summarizationMaxLength = config.taskSpecificParams?.summarization?.maxLength?.intValue else { XCTFail("cannot traverse nested containers") return } XCTAssertEqual(summarizationMaxLength, 200) - } catch { + } + catch { XCTFail("Cannot download test configuration from the Hub: \(error)") } } diff --git a/Tests/NormalizerTests/NormalizerTests.swift b/Tests/NormalizerTests/NormalizerTests.swift index e409375..00ab4ae 100644 --- a/Tests/NormalizerTests/NormalizerTests.swift +++ b/Tests/NormalizerTests/NormalizerTests.swift @@ -1,6 +1,7 @@ import XCTest -@testable import Tokenizers + @testable import Hub +@testable import Tokenizers class NormalizerTests: XCTestCase { @@ -22,7 +23,7 @@ class NormalizerTests: XCTestCase { let normalizer = LowercaseNormalizer(config: config) XCTAssertEqual(normalizer.normalize(text: arg), expect) } - + let config = Config(["type": NormalizerType.Lowercase.rawValue]) XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? LowercaseNormalizer) } @@ -68,11 +69,11 @@ class NormalizerTests: XCTestCase { let normalizer = NFCNormalizer(config: config) XCTAssertEqual(normalizer.normalize(text: arg), expect) } - + let config = Config(["type": NormalizerType.NFC.rawValue]) XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFCNormalizer) } - + func testNFKDNormalizer() { let testCases: [(String, String)] = [ ("café", "cafe\u{301}"), @@ -118,7 +119,7 @@ class NormalizerTests: XCTestCase { let config = Config(["type": NormalizerType.NFKC.rawValue]) XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFKCNormalizer) } - + func testBertNormalizer() { let testCases: [(String, String)] = [ ("Café", "café"), diff --git a/Tests/PostProcessorTests/PostProcessorTests.swift b/Tests/PostProcessorTests/PostProcessorTests.swift index 347bc38..6283c0c 100644 --- a/Tests/PostProcessorTests/PostProcessorTests.swift +++ b/Tests/PostProcessorTests/PostProcessorTests.swift @@ -1,73 +1,82 @@ import XCTest -@testable import Tokenizers + @testable import Hub +@testable import Tokenizers class PostProcessorTests: XCTestCase { func testRobertaProcessing() { - let testCases: [(Config, [String], [String]?, [String])] = [ + let testCases: [(Config, [String], [String]?, [String])] = [ // Should keep spaces; uneven spaces; ignore `addPrefixSpace`. 
( - Config(["cls": (0, "[HEAD]") as (UInt, String), - "sep": (0, "[END]") as (UInt, String), - "trimOffset": false, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[HEAD]") as (UInt, String), + "sep": (0, "[END]") as (UInt, String), + "trimOffset": false, + "addPrefixSpace": true, + ]), [" The", " sun", "sets ", " in ", " the ", "west"], nil, ["[HEAD]", " The", " sun", "sets ", " in ", " the ", "west", "[END]"] ), // Should leave only one space around each token. ( - Config(["cls": (0, "[START]") as (UInt, String), - "sep": (0, "[BREAK]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[START]") as (UInt, String), + "sep": (0, "[BREAK]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] ), // Should ignore empty tokens pair. ( - Config(["cls": (0, "[START]") as (UInt, String), - "sep": (0, "[BREAK]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[START]") as (UInt, String), + "sep": (0, "[BREAK]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], [], ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] ), // Should trim all whitespace. ( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": false, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": false, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[CLS]", "The", "sun", "sets", "in", "the", "west", "[SEP]"] ), // Should add tokens. ( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" The ", " sun", "sets ", " in ", " the ", "west"], [".", "The", " cat ", " is ", " sitting ", " on", "the ", "mat"], - ["[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]", - "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", - "mat", "[SEP]"] + [ + "[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]", + "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", + "mat", "[SEP]", + ] ), ( - Config(["cls": (0, "[CLS]") as (UInt, String), - "sep": (0, "[SEP]") as (UInt, String), - "trimOffset": true, - "addPrefixSpace": true, - ]), + Config([ + "cls": (0, "[CLS]") as (UInt, String), + "sep": (0, "[SEP]") as (UInt, String), + "trimOffset": true, + "addPrefixSpace": true, + ]), [" 你 ", " 好 ", ","], [" 凯 ", " 蒂 ", "!"], ["[CLS]", " 你 ", " 好 ", ",", "[SEP]", "[SEP]", " 凯 ", " 蒂 ", "!", "[SEP]"] diff --git a/Tests/PreTokenizerTests/PreTokenizerTests.swift b/Tests/PreTokenizerTests/PreTokenizerTests.swift index 3e4e064..3a75ab8 100644 --- a/Tests/PreTokenizerTests/PreTokenizerTests.swift +++ b/Tests/PreTokenizerTests/PreTokenizerTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 23/11/2023. 
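All of the reflowed fixtures above share one shape: build a `Config` literal, construct `RobertaProcessing`, and compare the `postProcess` output. A condensed sketch of the "trim all whitespace" case (config keys copied from the fixture; shortened to two tokens, with the expected output following the fixture's pattern):

```swift
// Condensed from the fixtures above: cls/sep wrap the trimmed tokens.
let config = Config([
    "cls": (0, "[CLS]") as (UInt, String),
    "sep": (0, "[SEP]") as (UInt, String),
    "trimOffset": true,
    "addPrefixSpace": false,
])
let processor = RobertaProcessing(config: config)
let output = processor.postProcess(tokens: [" The ", " sun"], tokensPair: nil)
// output == ["[CLS]", "The", "sun", "[SEP]"]
```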
// -import XCTest import Hub +import XCTest + @testable import Tokenizers class PreTokenizerTests: XCTestCase { @@ -119,7 +120,10 @@ class PreTokenizerTests: XCTestCase { ) XCTAssertEqual( preTokenizer1.preTokenize(text: " Hey, friend, what's up? "), - [" ", " ", " ", "Hey,", " ", " ", " ", " ", "friend,", " ", " ", " ", " ", "what's", " ", "up?", " ", " ", ""] + [ + " ", " ", " ", "Hey,", " ", " ", " ", " ", "friend,", " ", " ", " ", " ", "what's", " ", "up?", " ", + " ", "", + ] ) let preTokenizer2 = SplitPreTokenizer(config: Config(["pattern": ["Regex": "\\s"]])) @@ -133,7 +137,10 @@ class PreTokenizerTests: XCTestCase { ) XCTAssertEqual( preTokenizer2.preTokenize(text: " Hey, friend, what's up? "), - [" ", " ", " ", "Hey,", " ", " ", " ", " ", "friend,", " ", " ", " ", " ", "what's", " ", "up?", " ", " ", ""] + [ + " ", " ", " ", "Hey,", " ", " ", " ", " ", "friend,", " ", " ", " ", " ", "what's", " ", "up?", " ", + " ", "", + ] ) let preTokenizer3 = SplitPreTokenizer(config: Config(["pattern": ["Regex": "\\s"], "invert": true])) @@ -150,19 +157,22 @@ class PreTokenizerTests: XCTestCase { ["Hey,", "friend,", "what's", "up?", ""] ) } - + // https://github.com/huggingface/tokenizers/pull/1357 func testMetaspacePreTokenizer() { // Prepend "always" - let preTokenizer = MetaspacePreTokenizer(config: Config([ - "add_prefix_space": true, - "replacement": "▁", - "prepend_scheme": "always" - ])) - + let preTokenizer = MetaspacePreTokenizer( + config: Config([ + "add_prefix_space": true, + "replacement": "▁", + "prepend_scheme": "always", + ]) + ) + // TODO: different sections on let text = "Hey my friend how▁are you" - let tokens = text + let tokens = + text .split(by: "", includeSeparators: true) .flatMap { preTokenizer.preTokenize(text: $0) } diff --git a/Tests/TensorUtilsTests/LogitsWarperTests.swift b/Tests/TensorUtilsTests/LogitsWarperTests.swift index fa038c1..47722cf 100644 --- a/Tests/TensorUtilsTests/LogitsWarperTests.swift +++ b/Tests/TensorUtilsTests/LogitsWarperTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 09/12/2023. // -import XCTest import CoreML +import XCTest + @testable import TensorUtils final class LogitsWarperTests: XCTestCase { diff --git a/Tests/TensorUtilsTests/TensorUtilsTests.swift b/Tests/TensorUtilsTests/TensorUtilsTests.swift index 6355165..6071a76 100644 --- a/Tests/TensorUtilsTests/TensorUtilsTests.swift +++ b/Tests/TensorUtilsTests/TensorUtilsTests.swift @@ -4,8 +4,9 @@ // Created by Jan Krukowski on 25/11/2023. 
// -import XCTest import CoreML +import XCTest + @testable import TensorUtils final class TensorUtilsTests: XCTestCase { @@ -44,8 +45,8 @@ final class TensorUtilsTests: XCTestCase { } func testSoftmax() { - XCTAssertEqual(Math.softmax([]), []) - + XCTAssertEqual(Math.softmax([]), []) + let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0]) XCTAssertEqual(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy) XCTAssertEqual(result1.reduce(0, +), 1.0, accuracy: accuracy) diff --git a/Tests/TensorUtilsTests/TestUtils.swift b/Tests/TensorUtilsTests/TestUtils.swift index 7765c8c..ae5f5df 100644 --- a/Tests/TensorUtilsTests/TestUtils.swift +++ b/Tests/TensorUtilsTests/TestUtils.swift @@ -16,7 +16,8 @@ func XCTAssertEqual( for (lhs, rhs) in zip(lhsEvaluated, rhsEvaluated) { XCTAssertEqual(lhs, rhs, accuracy: accuracy, file: file, line: line) } - } catch { + } + catch { XCTFail("Unexpected error: \(error)", file: file, line: line) } } diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift index d30ae99..500b95c 100644 --- a/Tests/TokenizersTests/BertTokenizerTests.swift +++ b/Tests/TokenizersTests/BertTokenizerTests.swift @@ -7,9 +7,8 @@ // import XCTest -@testable import Tokenizers - +@testable import Tokenizers class BertTokenizerTests: XCTestCase { @@ -20,7 +19,7 @@ class BertTokenizerTests: XCTestCase { override func tearDown() { // Put teardown code here. This method is called after the invocation of each test method in the class. } - + lazy var bertTokenizer: BertTokenizer = { let vocab = { let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")! @@ -32,93 +31,94 @@ class BertTokenizerTests: XCTestCase { } return vocab }() - + return BertTokenizer(vocab: vocab, merges: nil) }() func testBasicTokenizer() { let basicTokenizer = BasicTokenizer() - + let text = "Brave gaillard, d'où [UNK] êtes vous?" let tokens = ["brave", "gaillard", ",", "d", "\'", "ou", "[UNK]", "etes", "vous", "?"] - + XCTAssertEqual( - basicTokenizer.tokenize(text: text), tokens + basicTokenizer.tokenize(text: text), + tokens ) /// Verify that `XCTAssertEqual` does what deep equality checks on arrays of strings. XCTAssertEqual(["foo", "bar"], ["foo", "bar"]) } - + /// For each Squad question tokenized by python, check that we get the same output through the `BasicTokenizer` func testFullBasicTokenizer() { let url = Bundle.module.url(forResource: "basic_tokenized_questions", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let sampleTokens = try! decoder.decode([[String]].self, from: json) - + let basicTokenizer = BasicTokenizer() - + XCTAssertEqual(sampleTokens.count, Squad.examples.count) - + for (i, example) in Squad.examples.enumerated() { let output = basicTokenizer.tokenize(text: example.question) XCTAssertEqual(output, sampleTokens[i]) } } - + /// For each Squad question tokenized by python, check that we get the same output through the whole `BertTokenizer` func testFullBertTokenizer() { let url = Bundle.module.url(forResource: "tokenized_questions", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let sampleTokens = try! 
decoder.decode([[Int]].self, from: json) - + let tokenizer = bertTokenizer - + XCTAssertEqual(sampleTokens.count, Squad.examples.count) - + for (i, example) in Squad.examples.enumerated() { let output = tokenizer.tokenizeToIds(text: example.question) XCTAssertEqual(output, sampleTokens[i]) } } - + func testMixedChineseEnglishTokenization() { let tokenizer = bertTokenizer let text = "你好,世界!Hello, world!" let expectedTokens = ["[UNK]", "[UNK]", ",", "世", "[UNK]", "!", "hello", ",", "world", "!"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testPureChineseTokenization() { let tokenizer = bertTokenizer let text = "明日,大家上山看日出。" - let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出","。"] + let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出", "。"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testChineseWithNumeralsTokenization() { let tokenizer = bertTokenizer let text = "2020年奥运会在东京举行。" let expectedTokens = ["2020", "年", "[UNK]", "[UNK]", "会", "[UNK]", "[UNK]", "京", "[UNK]", "行", "。"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testChineseWithSpecialTokens() { let tokenizer = bertTokenizer let text = "[CLS] 机器学习是未来。 [SEP]" let expectedTokens = ["[CLS]", "[UNK]", "[UNK]", "学", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "。", "[SEP]"] let tokens = tokenizer.tokenize(text: text) - + XCTAssertEqual(tokens, expectedTokens) } - + func testPerformanceExample() { let tokenizer = bertTokenizer @@ -128,49 +128,52 @@ class BertTokenizerTests: XCTestCase { _ = tokenizer.tokenizeToIds(text: "Brave gaillard, d'où [UNK] êtes vous?") } } - + func testWordpieceDetokenizer() { struct QuestionTokens: Codable { let original: String let basic: [String] let wordpiece: [String] } - + let url = Bundle.module.url(forResource: "question_tokens", withExtension: "json")! let json = try! Data(contentsOf: url) let decoder = JSONDecoder() let questionTokens = try! decoder.decode([QuestionTokens].self, from: json) let tokenizer = bertTokenizer - + for question in questionTokens { - XCTAssertEqual(question.basic.joined(separator: " "), tokenizer.convertWordpieceToBasicTokenList(question.wordpiece)) + XCTAssertEqual( + question.basic.joined(separator: " "), + tokenizer.convertWordpieceToBasicTokenList(question.wordpiece) + ) } } - + func testEncoderDecoder() { let text = """ - Wake up (Wake up) - Grab a brush and put a little makeup - Hide your scars to fade away the shakeup (Hide the scars to fade away the shakeup) - Why'd you leave the keys upon the table? - Here you go, create another fable, you wanted to - Grab a brush and put a little makeup, you wanted to - Hide the scars to fade away the shakeup, you wanted to - Why'd you leave the keys upon the table? You wanted to - """ - + Wake up (Wake up) + Grab a brush and put a little makeup + Hide your scars to fade away the shakeup (Hide the scars to fade away the shakeup) + Why'd you leave the keys upon the table? + Here you go, create another fable, you wanted to + Grab a brush and put a little makeup, you wanted to + Hide the scars to fade away the shakeup, you wanted to + Why'd you leave the keys upon the table? 
You wanted to + """ + // Not sure if there's a way to achieve a non-destructive round-trip let decoded = """ - wake up ( wake up ) - grab a brush and put a little makeup - hide your scars to fade away the shakeup ( hide the scars to fade away the shakeup ) - why \' d you leave the keys upon the table ? - here you go , create another fable , you wanted to - grab a brush and put a little makeup , you wanted to - hide the scars to fade away the shakeup , you wanted to - why \' d you leave the keys upon the table ? you wanted to - """ - + wake up ( wake up ) + grab a brush and put a little makeup + hide your scars to fade away the shakeup ( hide the scars to fade away the shakeup ) + why \' d you leave the keys upon the table ? + here you go , create another fable , you wanted to + grab a brush and put a little makeup , you wanted to + hide the scars to fade away the shakeup , you wanted to + why \' d you leave the keys upon the table ? you wanted to + """ + let tokenizer = bertTokenizer for (line, expected) in zip(text.split(separator: "\n"), decoded.split(separator: "\n")) { let encoded = tokenizer.encode(text: String(line)) diff --git a/Tests/TokenizersTests/DecoderTests.swift b/Tests/TokenizersTests/DecoderTests.swift index efe8f53..15b54f5 100644 --- a/Tests/TokenizersTests/DecoderTests.swift +++ b/Tests/TokenizersTests/DecoderTests.swift @@ -4,18 +4,21 @@ // Created by Pedro Cuenca on 20231123. // -import XCTest import Hub +import XCTest + @testable import Tokenizers class DecoderTests: XCTestCase { // https://github.com/huggingface/tokenizers/pull/1357 func testMetaspaceDecoder() { - let decoder = MetaspaceDecoder(config: Config([ - "add_prefix_space": true, - "replacement": "▁", - ])) - + let decoder = MetaspaceDecoder( + config: Config([ + "add_prefix_space": true, + "replacement": "▁", + ]) + ) + let tokens = ["▁Hey", "▁my", "▁friend", "▁", "▁", "▁how", "▁are", "▁you"] let decoded = decoder.decode(tokens: tokens) diff --git a/Tests/TokenizersTests/FactoryTests.swift b/Tests/TokenizersTests/FactoryTests.swift index 88c1c57..0fffa40 100644 --- a/Tests/TokenizersTests/FactoryTests.swift +++ b/Tests/TokenizersTests/FactoryTests.swift @@ -1,13 +1,13 @@ // // FactoryTests.swift -// +// // // Created by Pedro Cuenca on 4/8/23. 
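As context for the `MetaspaceDecoder` test reflowed above, this is the round trip it exercises (config keys and input tokens copied from the fixture; the expected result is inferred from the decoder semantics, not quoted from the test):

```swift
// From the fixture above: "▁" markers turn back into spaces, and
// add_prefix_space drops the leading space of the first token.
let decoder = MetaspaceDecoder(config: Config([
    "add_prefix_space": true,
    "replacement": "▁",
]))
let decoded = decoder.decode(tokens: ["▁Hey", "▁my", "▁friend", "▁", "▁", "▁how", "▁are", "▁you"])
// decoded.joined() == "Hey my friend   how are you" (expected)
```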
diff --git a/Tests/TokenizersTests/FactoryTests.swift b/Tests/TokenizersTests/FactoryTests.swift
index 88c1c57..0fffa40 100644
--- a/Tests/TokenizersTests/FactoryTests.swift
+++ b/Tests/TokenizersTests/FactoryTests.swift
@@ -1,13 +1,13 @@
 //
 //  FactoryTests.swift
-//  
+//
 //
 //  Created by Pedro Cuenca on 4/8/23.
 //
 
-import XCTest
-import Tokenizers
 import Hub
+import Tokenizers
+import XCTest
 
 class TestWithCustomHubDownloadLocation: XCTestCase {
     let downloadDestination: URL = {
@@ -20,7 +20,8 @@ class TestWithCustomHubDownloadLocation: XCTestCase {
     override func tearDown() {
         do {
             try FileManager.default.removeItem(at: downloadDestination)
-        } catch {
+        }
+        catch {
             print("Can't remove test download destination \(downloadDestination), error: \(error)")
         }
     }
@@ -32,32 +33,35 @@ class TestWithCustomHubDownloadLocation: XCTestCase {
 
 class FactoryTests: TestWithCustomHubDownloadLocation {
     func testFromPretrained() async throws {
-        let tokenizer = try await AutoTokenizer.from(pretrained: "coreml-projects/Llama-2-7b-chat-coreml", hubApi: hubApi)
+        let tokenizer = try await AutoTokenizer.from(
+            pretrained: "coreml-projects/Llama-2-7b-chat-coreml",
+            hubApi: hubApi
+        )
         let inputIds = tokenizer("Today she took a train to the West")
         XCTAssertEqual(inputIds, [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122])
     }
-    
+
     func testWhisper() async throws {
         let tokenizer = try await AutoTokenizer.from(pretrained: "openai/whisper-large-v2", hubApi: hubApi)
         let inputIds = tokenizer("Today she took a train to the West")
         XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257])
     }
-    
+
     func testFromModelFolder() async throws {
         let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
         let repo = Hub.Repo(id: "coreml-projects/Llama-2-7b-chat-coreml")
         let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
-        
+
         let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi)
         let inputIds = tokenizer("Today she took a train to the West")
         XCTAssertEqual(inputIds, [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122])
     }
-    
+
     func testWhisperFromModelFolder() async throws {
         let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
         let repo = Hub.Repo(id: "openai/whisper-large-v2")
         let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
-        
+
         let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi)
         let inputIds = tokenizer("Today she took a train to the West")
         XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257])
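The four tests above cover both public loading paths: by Hub model id, with files downloaded and cached through the injected `hubApi`, and from a local folder that already holds `config.json`, `tokenizer_config.json`, and `tokenizer.json`. Outside a test harness the calls look the same; a sketch, assuming the `hubApi` parameter has a default when omitted (pass an explicit `HubApi`, as the tests do, to control the download location):

```swift
import Tokenizers

// Usage sketch based on the entry points exercised in FactoryTests.
// The model id is the one the tests use; any Hub repo with the usual
// tokenizer files should work the same way.
let tokenizer = try await AutoTokenizer.from(pretrained: "coreml-projects/Llama-2-7b-chat-coreml")
let inputIds = tokenizer("Today she took a train to the West")
// Per the test above: [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122]
```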
diff --git a/Tests/TokenizersTests/SplitTests.swift b/Tests/TokenizersTests/SplitTests.swift
index c14b212..e24ac35 100644
--- a/Tests/TokenizersTests/SplitTests.swift
+++ b/Tests/TokenizersTests/SplitTests.swift
@@ -5,65 +5,65 @@
 //  Created by Pedro Cuenca on 20240120.
 //
 
-import XCTest
 import Tokenizers
+import XCTest
 
 class SplitTests: XCTestCase {
-    func testSplitBehaviorMergedWithPrevious() {    
+    func testSplitBehaviorMergedWithPrevious() {
         XCTAssertEqual(
             "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious),
             ["the-", "final-", "-", "countdown"]
         )
-        
+
         XCTAssertEqual(
             "the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious),
             ["the-", "final-", "-", "countdown-"]
         )
-        
+
         XCTAssertEqual(
             "the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious),
             ["the-", "final-", "-", "countdown-", "-"]
        )
-        
+
         XCTAssertEqual(
             "-the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious),
             ["-", "the-", "final-", "-", "countdown-", "-"]
         )
-        
+
         XCTAssertEqual(
             "--the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious),
             ["-", "-", "the-", "final-", "-", "countdown-", "-"]
         )
     }
-    
+
     func testSplitBehaviorMergedWithNext() {
         XCTAssertEqual(
             "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext),
             ["the", "-final", "-", "-countdown"]
         )
-        
+
         XCTAssertEqual(
             "-the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext),
             ["-the", "-final", "-", "-countdown"]
         )
-        
+
         XCTAssertEqual(
             "--the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext),
             ["-", "-the", "-final", "-", "-countdown"]
         )
-        
+
         XCTAssertEqual(
             "--the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext),
             ["-", "-the", "-final", "-", "-countdown", "-"]
         )
     }
-    
+
     func testSplitBehaviorOther() {
         XCTAssertEqual(
             "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .isolated),
             ["the", "-", "final", "-", "-", "countdown"]
         )
-        
+
         XCTAssertEqual(
             "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .removed),
             ["the", "final", "countdown"]
diff --git a/Tests/TokenizersTests/SquadDataset.swift b/Tests/TokenizersTests/SquadDataset.swift
index 98781b6..2cca880 100644
--- a/Tests/TokenizersTests/SquadDataset.swift
+++ b/Tests/TokenizersTests/SquadDataset.swift
@@ -8,7 +8,6 @@
 
 import Foundation
 
-
 /// Our internal type, also used in unit tests
 struct SquadExample {
     let qaId: String
@@ -49,12 +48,19 @@ struct Squad {
         let json = try! Data(contentsOf: url)
         let decoder = JSONDecoder()
         let squadDataset = try! decoder.decode(SquadDataset.self, from: json)
-        
+
         var examples: [SquadExample] = []
         for datum in squadDataset.data {
             for paragraph in datum.paragraphs {
                 for qa in paragraph.qas {
-                    let example = SquadExample(qaId: qa.id, context: paragraph.context, question: qa.question, answerText: qa.answers[0].text, startPos: qa.answers[0].answer_start, endPos: -1) // todo: remove -1
+                    let example = SquadExample(
+                        qaId: qa.id,
+                        context: paragraph.context,
+                        question: qa.question,
+                        answerText: qa.answers[0].text,
+                        startPos: qa.answers[0].answer_start,
+                        endPos: -1
+                    ) // todo: remove -1
                     examples.append(example)
                 }
             }
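The `SplitTests` earlier in this patch double as documentation for the four delimiter behaviors, which mirror the split semantics of the upstream `tokenizers` library: `.mergedWithPrevious` attaches each delimiter to the fragment before it, `.mergedWithNext` to the fragment after it, `.isolated` keeps delimiters as standalone fragments, and `.removed` drops them. Comparing all four on one input, with results copied verbatim from the tests:

```swift
import Tokenizers

// All four behaviors on the same input; expected values are taken from SplitTests.
let input = "the-final--countdown"
_ = input.split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious)
// ["the-", "final-", "-", "countdown"]
_ = input.split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext)
// ["the", "-final", "-", "-countdown"]
_ = input.split(by: "-", options: .caseInsensitive, behavior: .isolated)
// ["the", "-", "final", "-", "-", "countdown"]
_ = input.split(by: "-", options: .caseInsensitive, behavior: .removed)
// ["the", "final", "countdown"]
```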
diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift
index 4b4b496..2d2479b 100644
--- a/Tests/TokenizersTests/TokenizerTests.swift
+++ b/Tests/TokenizersTests/TokenizerTests.swift
@@ -6,10 +6,11 @@
 //  Copyright © 2023 Hugging Face. All rights reserved.
 //
 
-import XCTest
 import Hub
-@testable import Tokenizers
+import XCTest
+
 @testable import Models
+@testable import Tokenizers
 
 class GPT2TokenizerTests: TokenizerTests {
     override class var hubModelName: String? { "distilgpt2" }
@@ -48,7 +49,6 @@ class T5TokenizerTests: TokenizerTests {
     override class var unknownTokenId: Int? { 2 }
 }
 
-
 struct EncodedTokenizerSamplesDataset: Decodable {
     let text: String
     // Bad naming, not just for bpe.
@@ -58,8 +58,7 @@ struct EncodedTokenizerSamplesDataset: Decodable {
     let decoded_text: String
 }
 
-
-typealias EdgeCasesDataset = [String : [EdgeCase]]
+typealias EdgeCasesDataset = [String: [EdgeCase]]
 
 struct EdgeCase: Decodable {
     let input: String
@@ -74,20 +73,19 @@ struct EncodedData: Decodable {
     let attention_mask: [Int]
 }
 
-
 class TokenizerTester {
     let encodedSamplesFilename: String
     let unknownTokenId: Int?
-    
+
     private var configuration: LanguageModelConfigurationFromHub? = nil
     private var edgeCases: [EdgeCase]? = nil
     private var _tokenizer: Tokenizer? = nil
-    
+
     init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) {
         configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi)
         self.encodedSamplesFilename = encodedSamplesFilename
         self.unknownTokenId = unknownTokenId
-        
+
         // Read the edge cases dataset
         edgeCases = {
             let url = Bundle.module.url(forResource: "tokenizer_tests", withExtension: "json")!
@@ -98,7 +96,7 @@ class TokenizerTester {
             return cases[hubModelName]
         }()
     }
-    
+
     lazy var dataset: EncodedTokenizerSamplesDataset = {
         let url = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")!
         let json = try! Data(contentsOf: url)
@@ -106,22 +104,24 @@ class TokenizerTester {
         let dataset = try! decoder.decode(EncodedTokenizerSamplesDataset.self, from: json)
         return dataset
     }()
-    
-    
+
     var tokenizer: Tokenizer? {
         get async {
             guard _tokenizer == nil else { return _tokenizer! }
             do {
-                guard let tokenizerConfig = try await configuration!.tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
+                guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
+                    throw "Cannot retrieve Tokenizer configuration"
+                }
                 let tokenizerData = try await configuration!.tokenizerData
                 _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
-            } catch {
+            }
+            catch {
                 XCTFail("Cannot load tokenizer: \(error)")
             }
             return _tokenizer
         }
     }
-    
+
     var tokenizerModel: TokenizingModel? {
         get async {
             // The model is not usually accessible; maybe it should
@@ -129,7 +129,7 @@ class TokenizerTester {
             return (tokenizer as! PreTrainedTokenizer).model
         }
     }
-    
+
     func testTokenize() async {
         let tokenized = await tokenizer?.tokenize(text: dataset.text)
         XCTAssertEqual(
@@ -137,7 +137,7 @@ class TokenizerTester {
             dataset.bpe_tokens
         )
     }
-    
+
     func testEncode() async {
         let encoded = await tokenizer?.encode(text: dataset.text)
         XCTAssertEqual(
@@ -145,7 +145,7 @@ class TokenizerTester {
             dataset.token_ids
         )
     }
-    
+
     func testDecode() async {
         let decoded = await tokenizer?.decode(tokens: dataset.token_ids)
         XCTAssertEqual(
@@ -153,7 +153,7 @@ class TokenizerTester {
             dataset.decoded_text
         )
     }
-    
+
     /// Test encode and decode for a few edge cases
     func testEdgeCases() async {
         guard let edgeCases = edgeCases else { return }
@@ -170,7 +170,7 @@ class TokenizerTester {
             )
         }
     }
-    
+
     func testUnknownToken() async {
         guard let model = await tokenizerModel else { return }
         XCTAssertEqual(model.unknownTokenId, unknownTokenId)
@@ -183,7 +183,8 @@ class TokenizerTester {
                 model.unknownToken,
                 model.convertIdToToken(unknownTokenId)
             )
-        } else {
+        }
+        else {
             XCTAssertNil(model.unknownTokenId)
         }
     }
@@ -192,10 +193,10 @@ class TokenizerTests: XCTestCase {
     // Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem
     static var _tester: TokenizerTester? = nil
-    
+
     class var hubModelName: String? { nil }
     class var encodedSamplesFilename: String? { nil }
-    
+
     // Known id retrieved from Python, to verify it was parsed correctly
     class var unknownTokenId: Int? { nil }
 
@@ -220,7 +221,8 @@ class TokenizerTests: XCTestCase {
     override class func tearDown() {
         do {
             try FileManager.default.removeItem(at: downloadDestination)
-        } catch {
+        }
+        catch {
             print("Can't remove test download destination \(downloadDestination), error: \(error)")
         }
     }
@@ -230,25 +232,25 @@ class TokenizerTests: XCTestCase {
             await tester.testTokenize()
         }
     }
-    
+
     func testEncode() async {
         if let tester = Self._tester {
             await tester.testEncode()
         }
     }
-    
+
     func testDecode() async {
         if let tester = Self._tester {
             await tester.testDecode()
         }
     }
-    
+
     func testEdgeCases() async {
         if let tester = Self._tester {
             await tester.testEdgeCases()
         }
     }
-    
+
     func testUnknownToken() async {
         if let tester = Self._tester {
             await tester.testUnknownToken()
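`TokenizerTests` above is a template class: concrete suites such as `GPT2TokenizerTests` override three class properties, and the shared static `TokenizerTester` fetches the tokenizer once per suite and checks tokenize, encode, and decode against reference data generated from the Python implementation. Adding coverage for another model is mostly declarative; in this sketch the model id and samples filename are placeholders, not part of this patch:

```swift
// Hypothetical suite following the pattern of GPT2TokenizerTests above.
// "my-org/my-model" and "my_model_encoded" are illustrative placeholders;
// the samples JSON must be bundled as a test resource.
class MyModelTokenizerTests: TokenizerTests {
    override class var hubModelName: String? { "my-org/my-model" }
    override class var encodedSamplesFilename: String? { "my_model_encoded" }
    override class var unknownTokenId: Int? { 0 }
}
```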
diff --git a/Tests/TokenizersTests/TrieTests.swift b/Tests/TokenizersTests/TrieTests.swift
index 15c54f8..051776d 100644
--- a/Tests/TokenizersTests/TrieTests.swift
+++ b/Tests/TokenizersTests/TrieTests.swift
@@ -6,6 +6,7 @@
 //
 
 import XCTest
+
 @testable import Tokenizers
 
 class TrieTests: XCTestCase {
@@ -16,23 +17,23 @@ class TrieTests: XCTestCase {
         trie.insert("carp")
         trie.insert("car")
         XCTAssertEqual(trie.root.children.count, 1)
-        
+
         let c = trie.get("c")
         XCTAssertNotNil(c)
-        XCTAssertEqual(c!.children.count, 1)    // "a"
-        
+        XCTAssertEqual(c!.children.count, 1) // "a"
+
         let ca = trie.get("ca")
         XCTAssertNotNil(ca)
-        XCTAssertEqual(ca!.children.count, 2)   // "r", "t"
-        
+        XCTAssertEqual(ca!.children.count, 2) // "r", "t"
+
         let car = trie.get("car")
         XCTAssertNotNil(car)
         XCTAssertTrue(car!.isLeaf)
         XCTAssertFalse(ca!.isLeaf)
-        
+
         XCTAssertNil(trie.get("card"))
     }
-    
+
     func testTrieCommonPrefixSearch() {
         // https://guillaume-be.github.io/2020-05-30/sentence_piece
         let trie = Trie()
@@ -44,7 +45,7 @@ class TrieTests: XCTestCase {
         let leaves = trie.commonPrefixSearch("carpooling").map { String($0) }
         XCTAssertEqual(leaves, ["car", "carp"])
     }
-    
+
     func testTrieCommonPrefixSearchIterator() {
         // https://guillaume-be.github.io/2020-05-30/sentence_piece
         let trie = Trie()