diff --git a/.swiftformat b/.swiftformat
new file mode 100644
index 0000000..6a9d678
--- /dev/null
+++ b/.swiftformat
@@ -0,0 +1,87 @@
+--swiftversion 5.9
+--acronyms ID,URL,UUID
+--allman false
+--anonymousforeach convert
+--assetliterals visual-width
+--asynccapturing
+--beforemarks
+--binarygrouping 4,8
+--categorymark "MARK: %c"
+--classthreshold 0
+--closingparen balanced
+--closurevoid remove
+--commas always
+--conflictmarkers reject
+--decimalgrouping ignore
+--elseposition same-line
+--emptybraces spaced
+--enumthreshold 0
+--exponentcase lowercase
+--exponentgrouping disabled
+--extensionacl on-extension
+--extensionlength 0
+--extensionmark "MARK: - %t + %c"
+--fractiongrouping disabled
+--fragment false
+--funcattributes preserve
+--generictypes
+--groupedextension "MARK: %c"
+--guardelse auto
+--header ignore
+--hexgrouping 4,8
+--hexliteralcase uppercase
+--ifdef no-indent
+--importgrouping alpha
+--indent 4
+--indentcase false
+--indentstrings false
+--lifecycle
+--lineaftermarks true
+--linebreaks lf
+--markcategories true
+--markextensions always
+--marktypes always
+--maxwidth none
+--modifierorder
+--nevertrailing
+--nospaceoperators
+--nowrapoperators
+--octalgrouping 4,8
+--onelineforeach ignore
+--operatorfunc spaced
+--organizetypes actor,class,enum,struct
+--patternlet hoist
+--ranges no-space
+--redundanttype infer-locals-only
+--self remove
+--selfrequired
+--semicolons inline
+--shortoptionals always
+--smarttabs enabled
+--someany true
+--stripunusedargs unnamed-only
+--structthreshold 0
+--tabwidth unspecified
+--throwcapturing
+--trailingclosures
+--typeattributes preserve
+--typeblanklines remove
+--typemark "MARK: - %t"
+--varattributes preserve
+--voidtype void
+--wraparguments preserve
+--wrapcollections preserve
+--wrapconditions preserve
+--wrapeffects preserve
+--wrapenumcases always
+--wrapparameters preserve
+--wrapreturntype preserve
+--wrapternary default
+--wraptypealiases preserve
+--xcodeindentation disabled
+--yodaswap always
+--disable blankLineAfterImports,unusedArguments
+--enable docComments
+--disable enumnamespaces
+--trimwhitespace nonblank-lines
+--disable preferKeyPath
diff --git a/Package.swift b/Package.swift
index fc28be1..bc34dc7 100644
--- a/Package.swift
+++ b/Package.swift
@@ -13,14 +13,16 @@ let package = Package(
     ],
     dependencies: [
         .package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
-        .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
+        .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")),
     ],
     targets: [
         .executableTarget(
             name: "TransformersCLI",
             dependencies: [
                 "Models", "Generation", "Tokenizers",
-                .product(name: "ArgumentParser", package: "swift-argument-parser")]),
+                .product(name: "ArgumentParser", package: "swift-argument-parser"),
+            ]
+        ),
         .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
         .target(name: "Hub", resources: [.process("FallbackConfigs")]),
         .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift
index 6cfd8ab..bfce220 100644
--- a/Sources/Generation/Generation.swift
+++ b/Sources/Generation/Generation.swift
@@ -1,13 +1,13 @@
 //
 //  Generation.swift
-//  
+//
 //
 //  Created by Pedro Cuenca on 7/5/23.
 //
 
-import Tokenizers
 import CoreML
 import TensorUtils
+import Tokenizers
 
 public enum GenerationMode {
     case contrastiveSearch
@@ -57,7 +57,7 @@ public extension Generation {
         let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
         while outputTokens.count < config.maxLength {
             let outputs = model(outputTokens, config)
-            /// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
+            // `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
             let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
             let (indexes, processedLogits) = logitsProcessor(logits)
             let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
@@ -92,7 +92,7 @@ public extension Generation {
 
 private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
     var logitsWarpers = [any LogitsWarper]()
-    if config.temperature > 0 && config.temperature != 1 {
+    if config.temperature > 0, config.temperature != 1 {
         logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature)))
     }
     if config.topK > 0 {
diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift
index a9eee7b..0c7a057 100644
--- a/Sources/Generation/GenerationConfig.swift
+++ b/Sources/Generation/GenerationConfig.swift
@@ -1,6 +1,6 @@
 //
 //  GenerationConfig.swift
-//  
+//
 //
 //  Created by Pedro Cuenca on 7/5/23.
 //
@@ -14,15 +14,15 @@ public struct GenerationConfig {
     public var doSample = false
     public var numBeams = 1
     public var numBeamGroups = 1
-    public var penaltyAlpha: Double? = nil
+    public var penaltyAlpha: Double?
     public var temperature = 1.0
     public var topK = 50
     public var topP = 1.0
     public var repetitionPenalty = 1.0
 
-    public var padTokenId: Int? = nil
-    public var bosTokenId: Int? = nil
-    public var eosTokenId: Int? = nil
+    public var padTokenId: Int?
+    public var bosTokenId: Int?
+    public var eosTokenId: Int?
 
     public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {
         self.maxLength = maxLength
@@ -41,18 +41,18 @@ public struct GenerationConfig {
 public extension GenerationConfig {
     var generationMode: GenerationMode {
         // Exclude this case from the pattern matching below
-        if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! > 0 {
+        if topK > 1, !doSample, penaltyAlpha != nil, penaltyAlpha! > 0 {
             return .contrastiveSearch
         }
 
         switch (numBeams, numBeamGroups, doSample) {
-        case (1, 1, false) : return .greedy
-        case (1, 1, true)  : return .sample
+        case (1, 1, false): return .greedy
+        case (1, 1, true): return .sample
         case (2..., 1, false): return .beam
-        case (2..., 2..., _) : return .groupBeam
-        default              : return .unsupported
+        case (2..., 2..., _): return .groupBeam
+        default: return .unsupported
         }
     }
 }
 
-extension GenerationConfig: Decodable {}
+extension GenerationConfig: Decodable { }
diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift
index b3b89ee..f52c596 100644
--- a/Sources/Hub/Downloader.swift
+++ b/Sources/Hub/Downloader.swift
@@ -6,13 +6,13 @@
 //  See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
 //
 
-import Foundation
 import Combine
+import Foundation
 
 class Downloader: NSObject, ObservableObject {
     private(set) var destination: URL
 
-    private let chunkSize = 10 * 1024 * 1024  // 10MB
+    private let chunkSize = 10 * 1024 * 1024 // 10MB
 
     enum DownloadState {
         case notStarted
@@ -53,7 +53,7 @@ class Downloader: NSObject, ObservableObject {
             config.sessionSendsLaunchEvents = true
         }
 
-        self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
+        urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
 
         setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
     }
@@ -106,14 +106,13 @@ class Downloader: NSObject, ObservableObject {
         var requestHeaders = headers ?? [:]
 
         // Populate header auth and range fields
-        if let authToken = authToken {
+        if let authToken {
             requestHeaders["Authorization"] = "Bearer \(authToken)"
         }
         if resumeSize > 0 {
             requestHeaders["Range"] = "bytes=\(resumeSize)-"
         }
 
-
         request.timeoutInterval = timeout
         request.allHTTPHeaderFields = requestHeaders
 
@@ -157,7 +156,7 @@ class Downloader: NSObject, ObservableObject {
         numRetries: Int,
         expectedSize: Int?
     ) async throws {
-        guard let session = self.urlSession else {
+        guard let session = urlSession else {
             throw DownloadError.unexpectedError
         }
 
@@ -194,7 +193,7 @@ class Downloader: NSObject, ObservableObject {
                     buffer.removeAll(keepingCapacity: true)
                     downloadedSize += chunkSize
                     newNumRetries = 5
-                    guard let expectedSize = expectedSize else { continue }
+                    guard let expectedSize else { continue }
                     let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
                     downloadState.value = .downloading(progress)
                 }
@@ -227,7 +226,7 @@ class Downloader: NSObject, ObservableObject {
 
         // Verify the downloaded file size matches the expected size
         let actualSize = try tempFile.seekToEnd()
-        if let expectedSize = expectedSize, expectedSize != actualSize {
+        if let expectedSize, expectedSize != actualSize {
             throw DownloadError.unexpectedError
         }
     }
@@ -239,16 +238,16 @@ class Downloader: NSObject, ObservableObject {
         stateSubscriber = downloadState.sink { state in
             switch state {
             case .completed: semaphore.signal()
-            case .failed:    semaphore.signal()
-            default:         break
+            case .failed: semaphore.signal()
+            default: break
             }
         }
         semaphore.wait()
 
         switch downloadState.value {
-        case .completed(let url): return url
-        case .failed(let error):  throw error
-        default:                  throw DownloadError.unexpectedError
+        case let .completed(url): return url
+        case let .failed(error): throw error
+        default: throw DownloadError.unexpectedError
         }
     }
 
@@ -265,7 +264,7 @@ extension Downloader: URLSessionDownloadDelegate {
     func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
         do {
             // If the downloaded file already exists on the filesystem, overwrite it
-            try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
+            try FileManager.default.moveDownloadedFile(from: location, to: destination)
             downloadState.value = .completed(destination)
         } catch {
             downloadState.value = .failed(error)
@@ -273,7 +272,7 @@ extension Downloader: URLSessionDownloadDelegate {
     }
 
     func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
-        if let error = error {
+        if let error {
             downloadState.value = .failed(error)
 //        } else if let response = task.response as? HTTPURLResponse {
 //            print("HTTP response status code: \(response.statusCode)")
diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift
index a3b0d2f..1c2cd22 100644
--- a/Sources/Hub/Hub.swift
+++ b/Sources/Hub/Hub.swift
@@ -7,7 +7,7 @@
 
 import Foundation
 
-public struct Hub {}
+public struct Hub { }
 
 public extension Hub {
     enum HubClientError: LocalizedError {
@@ -25,28 +25,28 @@ public extension Hub {
 
         public var errorDescription: String? {
             switch self {
-                case .authorizationRequired:
-                    return String(localized: "Authentication required. Please provide a valid Hugging Face token.")
-                case .httpStatusCode(let code):
-                    return String(localized: "HTTP error with status code: \(code)")
-                case .parse:
-                    return String(localized: "Failed to parse server response.")
-                case .unexpectedError:
-                    return String(localized: "An unexpected error occurred.")
-                case .downloadError(let message):
-                    return String(localized: "Download failed: \(message)")
-                case .fileNotFound(let filename):
-                    return String(localized: "File not found: \(filename)")
-                case .networkError(let error):
-                    return String(localized: "Network error: \(error.localizedDescription)")
-                case .resourceNotFound(let resource):
-                    return String(localized: "Resource not found: \(resource)")
-                case .configurationMissing(let file):
-                    return String(localized: "Required configuration file missing: \(file)")
-                case .fileSystemError(let error):
-                    return String(localized: "File system error: \(error.localizedDescription)")
-                case .parseError(let message):
-                    return String(localized: "Parse error: \(message)")
+            case .authorizationRequired:
+                String(localized: "Authentication required. Please provide a valid Hugging Face token.")
+            case let .httpStatusCode(code):
+                String(localized: "HTTP error with status code: \(code)")
+            case .parse:
+                String(localized: "Failed to parse server response.")
+            case .unexpectedError:
+                String(localized: "An unexpected error occurred.")
+            case let .downloadError(message):
+                String(localized: "Download failed: \(message)")
+            case let .fileNotFound(filename):
+                String(localized: "File not found: \(filename)")
+            case let .networkError(error):
+                String(localized: "Network error: \(error.localizedDescription)")
+            case let .resourceNotFound(resource):
+                String(localized: "Resource not found: \(resource)")
+            case let .configurationMissing(file):
+                String(localized: "Required configuration file missing: \(file)")
+            case let .fileSystemError(error):
+                String(localized: "File system error: \(error.localizedDescription)")
+            case let .parseError(message):
+                String(localized: "Parse error: \(message)")
             }
         }
     }
@@ -79,7 +79,7 @@ public struct Config {
     }
 
     func camelCase(_ string: String) -> String {
-        return string
+        string
             .split(separator: "_")
             .enumerated()
             .map { $0.offset == 0 ? $0.element.lowercased() : $0.element.capitalized }
@@ -108,7 +108,6 @@ public struct Config {
         return result
     }
 
-
     public subscript(dynamicMember member: String) -> Config? {
         let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString
         if let value = dictionary[key] as? [NSString: Any] {
@@ -120,17 +119,17 @@ public struct Config {
     }
 
     public var value: Any? {
-        return dictionary["value"]
+        dictionary["value"]
     }
 
     public var intValue: Int? { value as? Int }
     public var boolValue: Bool? { value as? Bool }
     public var stringValue: String? { value as? String }
 
-    // Instead of doing this we could provide custom classes and decode to them
+    /// Instead of doing this we could provide custom classes and decode to them
    public var arrayValue: [Config]? {
        guard let list = value as? [Any] else { return nil }
-        return list.map { Config($0 as! [NSString : Any]) }
+        return list.map { Config($0 as! [NSString: Any]) }
    }
 
    /// Tuple of token identifier and string value
@@ -144,14 +143,14 @@ public class LanguageModelConfigurationFromHub {
         var tokenizerData: Config
     }
 
-    private var configPromise: Task<Configurations, Error>? = nil
+    private var configPromise: Task<Configurations, Error>?
 
     public init(
         modelName: String,
         hubApi: HubApi = .shared
     ) {
-        self.configPromise = Task.init {
-            return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
+        configPromise = Task.init {
+            try await self.loadConfig(modelName: modelName, hubApi: hubApi)
         }
     }
 
@@ -159,8 +158,8 @@ public class LanguageModelConfigurationFromHub {
         modelFolder: URL,
         hubApi: HubApi = .shared
     ) {
-        self.configPromise = Task {
-            return try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
+        configPromise = Task {
+            try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
         }
     }
 
@@ -221,12 +220,12 @@ public class LanguageModelConfigurationFromHub {
             // Convert generic errors to more specific ones
             if let urlError = error as? URLError {
                 switch urlError.code {
-                    case .notConnectedToInternet, .networkConnectionLost:
-                        throw Hub.HubClientError.networkError(urlError)
-                    case .resourceUnavailable:
-                        throw Hub.HubClientError.resourceNotFound(modelName)
-                    default:
-                        throw Hub.HubClientError.networkError(urlError)
+                case .notConnectedToInternet, .networkConnectionLost:
+                    throw Hub.HubClientError.networkError(urlError)
+                case .resourceUnavailable:
+                    throw Hub.HubClientError.resourceNotFound(modelName)
+                default:
+                    throw Hub.HubClientError.networkError(urlError)
                 }
             } else {
                 throw error
@@ -265,7 +264,8 @@ public class LanguageModelConfigurationFromHub {
         let chatTemplateURL = modelFolder.appending(path: "chat_template.json")
         if FileManager.default.fileExists(atPath: chatTemplateURL.path),
            let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateURL),
-           let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue {
+           let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue
+        {
             // Create or update tokenizer config with chat template
             if var configDict = tokenizerConfig?.dictionary {
                 configDict["chat_template"] = chatTemplate
@@ -284,7 +284,7 @@ public class LanguageModelConfigurationFromHub {
             throw error
         } catch {
             if let nsError = error as NSError? {
-                if nsError.domain == NSCocoaErrorDomain && nsError.code == NSFileReadNoSuchFileError {
+                if nsError.domain == NSCocoaErrorDomain, nsError.code == NSFileReadNoSuchFileError {
                     throw Hub.HubClientError.fileSystemError(error)
                 } else if nsError.domain == "NSJSONSerialization" {
                     throw Hub.HubClientError.parseError("Invalid JSON format: \(nsError.localizedDescription)")
diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift
index 4e4c1b3..8102dc6 100644
--- a/Sources/Hub/HubApi.swift
+++ b/Sources/Hub/HubApi.swift
@@ -5,8 +5,8 @@
 //  Created by Pedro Cuenca on 20231230.
 //
 
-import Foundation
 import CryptoKit
+import Foundation
 import Network
 import os
 
@@ -15,7 +15,7 @@ public struct HubApi {
     var hfToken: String?
     var endpoint: String
     var useBackgroundSession: Bool
-    var useOfflineMode: Bool? = nil
+    var useOfflineMode: Bool?
 
     private let networkMonitor = NetworkMonitor()
     public typealias RepoType = Hub.RepoType
@@ -65,12 +65,12 @@ private extension HubApi {
                 }
             },
             { try? String(contentsOf: .homeDirectory.appendingPathComponent(".cache/huggingface/token"), encoding: .utf8) },
-            { try? String(contentsOf: .homeDirectory.appendingPathComponent(".huggingface/token"), encoding: .utf8) }
+            { try? String(contentsOf: .homeDirectory.appendingPathComponent(".huggingface/token"), encoding: .utf8) },
         ]
         return possibleTokens
             .lazy
-            .compactMap({ $0() })
-            .filter({ !$0.isEmpty })
+            .compactMap { $0() }
+            .filter { !$0.isEmpty }
             .first
     }
 }
@@ -89,7 +89,7 @@ public extension HubApi {
     /// Throws error if the response code is not 20X
     func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) {
         var request = URLRequest(url: url)
-        if let hfToken = hfToken {
+        if let hfToken {
             request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
         }
 
@@ -100,14 +100,14 @@ public extension HubApi {
             }
 
             switch httpResponse.statusCode {
-                case 200..<300:
-                    return (data, httpResponse)
-                case 401, 403:
-                    throw Hub.HubClientError.authorizationRequired
-                case 404:
-                    throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
-                default:
-                    throw Hub.HubClientError.httpStatusCode(httpResponse.statusCode)
+            case 200..<300:
+                return (data, httpResponse)
+            case 401, 403:
+                throw Hub.HubClientError.authorizationRequired
+            case 404:
+                throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
+            default:
+                throw Hub.HubClientError.httpStatusCode(httpResponse.statusCode)
             }
         } catch let error as Hub.HubClientError {
             throw error
@@ -121,7 +121,7 @@ public extension HubApi {
     func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
         var request = URLRequest(url: url)
         request.httpMethod = "HEAD"
-        if let hfToken = hfToken {
+        if let hfToken {
             request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
         }
         request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
@@ -157,15 +157,15 @@ public extension HubApi {
     }
 
     func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
-        return try await getFilenames(from: Repo(id: repoId), matching: globs)
+        try await getFilenames(from: Repo(id: repoId), matching: globs)
     }
 
     func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
-        return try await getFilenames(from: repo, matching: [glob])
+        try await getFilenames(from: repo, matching: [glob])
     }
 
     func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
-        return try await getFilenames(from: Repo(id: repoId), matching: [glob])
+        try await getFilenames(from: Repo(id: repoId), matching: [glob])
    }
 }
 
@@ -179,14 +179,14 @@ public extension HubApi {
 
         public var errorDescription: String? {
             switch self {
-                case .invalidMetadataError(let message):
-                    return String(localized: "Invalid metadata: \(message)")
-                case .offlineModeError(let message):
-                    return String(localized: "Offline mode error: \(message)")
-                case .fileIntegrityError(let message):
-                    return String(localized: "File integrity check failed: \(message)")
-                case .fileWriteError(let message):
-                    return String(localized: "Failed to write file: \(message)")
+            case let .invalidMetadataError(message):
+                String(localized: "Invalid metadata: \(message)")
+            case let .offlineModeError(message):
+                String(localized: "Offline mode error: \(message)")
+            case let .fileIntegrityError(message):
+                String(localized: "File integrity check failed: \(message)")
+            case let .fileWriteError(message):
+                String(localized: "Failed to write file: \(message)")
             }
         }
     }
@@ -319,7 +319,6 @@ public extension HubApi {
 
         let digest = hasher.finalize()
         return digest.map { String(format: "%02x", $0) }.joined()
-
     }
 
     /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
@@ -376,9 +375,9 @@ public extension HubApi {
             try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
         }
 
-        // Note we go from Combine in Downloader to callback-based progress reporting
-        // We'll probably need to support Combine as well to play well with Swift UI
-        // (See for example PipelineLoader in swift-coreml-diffusers)
+        /// Note we go from Combine in Downloader to callback-based progress reporting
+        /// We'll probably need to support Combine as well to play well with Swift UI
+        /// (See for example PipelineLoader in swift-coreml-diffusers)
         @discardableResult
         func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
             let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination)
@@ -388,7 +387,7 @@ public extension HubApi {
             let remoteCommitHash = remoteMetadata.commitHash ?? ""
 
             // Local file exists + metadata exists + commit_hash matches => return file
-            if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
+            if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern), downloaded, localMetadata != nil, localCommitHash == remoteCommitHash {
                 return destination
             }
 
@@ -396,7 +395,8 @@ public extension HubApi {
             guard let remoteCommitHash = remoteMetadata.commitHash,
                   let remoteEtag = remoteMetadata.etag,
                   let remoteSize = remoteMetadata.size,
-                  remoteMetadata.location != "" else {
+                  remoteMetadata.location != ""
+            else {
                 throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
             }
 
@@ -427,7 +427,7 @@ public extension HubApi {
             let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize)
             let downloadSubscriber = downloader.downloadState.sink { state in
-                if case .downloading(let progress) = state {
+                if case let .downloading(progress) = state {
                     progressHandler(progress)
                 }
             }
@@ -467,13 +467,13 @@ public extension HubApi {
 
         let localMetadata = try readDownloadMetadata(metadataPath: metadataPath)
 
-        guard let localMetadata = localMetadata else {
+        guard let localMetadata else {
             throw EnvironmentError.offlineModeError(String(localized: "Metadata not available for \(fileUrl.lastPathComponent)"))
         }
         let localEtag = localMetadata.etag
 
         // LFS file so check file integrity
-        if self.isValidHash(hash: localEtag, pattern: self.sha256Pattern) {
+        if isValidHash(hash: localEtag, pattern: sha256Pattern) {
             let fileHash = try computeFileHash(file: fileUrl)
             if fileHash != localEtag {
                 throw EnvironmentError.fileIntegrityError(String(localized: "Hash mismatch for \(fileUrl.lastPathComponent)"))
@@ -509,17 +509,17 @@ public extension HubApi {
 
     @discardableResult
     func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
+        try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
     }
 
     @discardableResult
     func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler)
+        try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler)
     }
 
     @discardableResult
     func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler)
+        try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler)
     }
 }
 
@@ -557,7 +557,7 @@ public extension HubApi {
     }
 
     private func normalizeEtag(_ etag: String?) -> String? {
-        guard let etag = etag else { return nil }
+        guard let etag else { return nil }
         return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\""))
     }
 
@@ -578,24 +578,24 @@ public extension HubApi {
     func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] {
         let files = try await getFilenames(from: repo, matching: globs)
         let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions
-        var selectedMetadata: Array<FileMetadata> = []
+        var selectedMetadata: [FileMetadata] = []
         for file in files {
             let fileURL = url.appending(path: file)
-            selectedMetadata.append(try await getFileMetadata(url: fileURL))
+            try await selectedMetadata.append(getFileMetadata(url: fileURL))
         }
         return selectedMetadata
     }
 
     func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] {
-        return try await getFileMetadata(from: Repo(id: repoId), matching: globs)
+        try await getFileMetadata(from: Repo(id: repoId), matching: globs)
     }
 
     func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] {
-        return try await getFileMetadata(from: repo, matching: [glob])
+        try await getFileMetadata(from: repo, matching: [glob])
     }
 
     func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] {
-        return try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
+        try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
     }
 }
 
@@ -619,11 +619,11 @@ private extension HubApi {
 
     func startMonitoring() {
         monitor.pathUpdateHandler = { [weak self] path in
-            guard let self = self else { return }
+            guard let self else { return }
 
-            self.isConnected = path.status == .satisfied
-            self.isExpensive = path.isExpensive
-            self.isConstrained = path.isConstrained
+            isConnected = path.status == .satisfied
+            isExpensive = path.isExpensive
+            isConstrained = path.isConstrained
         }
 
         monitor.start(queue: queue)
@@ -634,7 +634,7 @@ private extension HubApi {
     }
 
     func shouldUseOfflineMode() -> Bool {
-        return !isConnected || isExpensive || isConstrained
+        !isConnected || isExpensive || isConstrained
    }
 
    deinit {
@@ -646,59 +646,59 @@ private extension HubApi {
 /// Stateless wrappers that use `HubApi` instances
 public extension Hub {
     static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
-        return try await HubApi.shared.getFilenames(from: repo, matching: globs)
+        try await HubApi.shared.getFilenames(from: repo, matching: globs)
     }
 
     static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
-        return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
+        try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
     }
 
     static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
-        return try await HubApi.shared.getFilenames(from: repo, matching: glob)
+        try await HubApi.shared.getFilenames(from: repo, matching: glob)
     }
 
     static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
-        return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
+        try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
     }
 
     static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
+        try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
     }
 
     static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
+        try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
     }
 
     static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
+        try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
     }
 
     static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
-        return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
+        try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
     }
 
     static func whoami(token: String) async throws -> Config {
-        return try await HubApi(hfToken: token).whoami()
+        try await HubApi(hfToken: token).whoami()
     }
 
     static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
-        return try await HubApi.shared.getFileMetadata(url: fileURL)
+        try await HubApi.shared.getFileMetadata(url: fileURL)
     }
 
     static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
-        return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
+        try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
     }
 
     static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
-        return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
+        try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
     }
 
     static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
-        return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
+        try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
     }
 
     static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
-        return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
+        try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
     }
 }
 
@@ -724,7 +724,7 @@ public extension FileManager {
         for case let fileURL as URL in enumerator {
             do {
                 let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey, .isHiddenKey])
-                if resourceValues.isRegularFile == true && resourceValues.isHidden != true {
+                if resourceValues.isRegularFile == true, resourceValues.isHidden != true {
                     fileUrls.append(fileURL)
                 }
             } catch {
@@ -744,13 +744,14 @@ private class RedirectDelegate: NSObject, URLSessionTaskDelegate {
         if (300...399).contains(response.statusCode) {
             // Get the Location header
             if let locationString = response.value(forHTTPHeaderField: "Location"),
-               let locationUrl = URL(string: locationString) {
-
+               let locationUrl = URL(string: locationString)
+            {
                 // Check if it's a relative redirect (no host component)
                 if locationUrl.host == nil {
                     // For relative redirects, construct the new URL using the original request's base
                     if let originalUrl = task.originalRequest?.url,
-                       var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) {
+                       var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
+                    {
                         // Update the path component with the relative path
                         components.path = locationUrl.path
                         components.query = locationUrl.query
diff --git a/Sources/HubCLI/HubCLI.swift b/Sources/HubCLI/HubCLI.swift
index fb0cc72..8aa8e64 100644
--- a/Sources/HubCLI/HubCLI.swift
+++ b/Sources/HubCLI/HubCLI.swift
@@ -15,13 +15,12 @@ struct HubCLI: AsyncParsableCommand {
 }
 
 protocol SubcommandWithToken {
-
     var token: String? { get }
 }
 
 extension SubcommandWithToken {
     var hfToken: String? {
-        if let token = token { return token }
+        if let token { return token }
         return try? String(contentsOfFile: defaultTokenLocation, encoding: .utf8)
     }
 }
@@ -36,9 +35,9 @@ struct Download: AsyncParsableCommand, SubcommandWithToken {
 
         var asHubApiRepoType: HubApi.RepoType {
             switch self {
-            case .model: return .models
-            case .dataset: return .datasets
-            case .space: return .spaces
+            case .model: .models
+            case .dataset: .datasets
+            case .space: .spaces
             }
         }
     }
@@ -91,6 +90,6 @@ struct Whoami: AsyncParsableCommand, SubcommandWithToken {
 
 extension Double {
     func formatted(_ format: String) -> String {
-        return String(format: "\(format)", self)
+        String(format: "\(format)", self)
     }
 }
diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift
index 9e73ad9..d45c4d7 100644
--- a/Sources/Models/LanguageModel.swift
+++ b/Sources/Models/LanguageModel.swift
@@ -1,14 +1,14 @@
 //
 //  LanguageModel.swift
-//  
+//
 //
 //  Created by Pedro Cuenca on 7/5/23.
 //
 
 import CoreML
-import Tokenizers
 import Generation
 import Hub
+import Tokenizers
 
 public class LanguageModel {
     public let model: MLModel
@@ -25,8 +25,8 @@ public class LanguageModel {
         var tokenizerData: Config
     }
 
-    private var configuration: LanguageModelConfigurationFromHub? = nil
-    private var _tokenizer: Tokenizer? = nil
+    private var configuration: LanguageModelConfigurationFromHub?
+    private var _tokenizer: Tokenizer?
 
     public required init(model: MLModel) {
         self.model = model
@@ -56,7 +56,7 @@ public class LanguageModel {
             maxContextLength = 128
         }
 
-        self.configuration = LanguageModelConfigurationFromHub(modelName: modelName)
+        configuration = LanguageModelConfigurationFromHub(modelName: modelName)
     }
 }
 
@@ -72,7 +72,8 @@ public extension LanguageModel {
 public extension LanguageModel {
     var description: String {
         if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String,
-           !description.isEmpty {
+           !description.isEmpty
+        {
             return description
         }
         return model.configuration.modelDisplayName ?? ""
@@ -80,8 +81,9 @@ public extension LanguageModel {
 
     /// `name_or_path` in the Python world
     var modelName: String {
-        if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String : String],
-           let name = userFields["co.huggingface.exporters.name"] {
+        if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
+           let name = userFields["co.huggingface.exporters.name"]
+        {
             return name
         }
         // This is usually the basename of the file, that's our best bet if no metadata exists
@@ -106,20 +108,20 @@ public extension LanguageModel {
         model.modelDescription.inputDescriptionsByName[attention_mask] != nil
     }
 
-    // MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice
+    /// MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice
     func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
         // TODO: exceptions
 
         // Maybe pad or truncate
         let maxTokens = min(tokens.count, maxContextLength)
-        let padLength = maxTokens >= minContextLength ? 0 : minContextLength-maxTokens
+        let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens
         let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength)
 
         let inputIds = MLShapedArray<Int32>(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape)
         var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)]
         if requiresAttention {
             let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength)
-            let attentionMask = MLShapedArray<Int32>(scalars: mask.map{ Int32($0) }, shape: inputIdsShape)
+            let attentionMask = MLShapedArray<Int32>(scalars: mask.map { Int32($0) }, shape: inputIdsShape)
             inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask)
         }
         let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary)
@@ -201,7 +203,7 @@ public extension LanguageModel {
 }
 
 extension LanguageModel: TextGenerationModel {
-    //TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26
+    // TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26
     public var defaultGenerationConfig: GenerationConfig {
         var config = GenerationConfig(maxNewTokens: 30)
         switch modelName.lowercased() {
@@ -219,8 +221,8 @@ public enum TokenizerError: LocalizedError {
 
     public var errorDescription: String? {
         switch self {
-            case .tokenizerConfigNotFound:
-                return String(localized: "Tokenizer configuration could not be found. The model may be missing required tokenizer files.", comment: "Error when tokenizer configuration is missing")
+        case .tokenizerConfigNotFound:
+            String(localized: "Tokenizer configuration could not be found. The model may be missing required tokenizer files.", comment: "Error when tokenizer configuration is missing")
         }
     }
 }
diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift
index 08d7d48..6b5a11f 100644
--- a/Sources/Models/LanguageModelTypes.swift
+++ b/Sources/Models/LanguageModelTypes.swift
@@ -1,13 +1,13 @@
 //
 //  LanguageModelTypes.swift
-//  
+//
 //
 //  Created by Pedro Cuenca on 8/5/23.
 //
 
 import CoreML
-import Tokenizers
 import Generation
+import Tokenizers
 
 public protocol LanguageModelProtocol {
     /// `name_or_path` in the Python world
@@ -37,6 +37,6 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol {
 public extension TextGenerationModel {
     @discardableResult
     func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String {
-        try await self.generate(config: config, prompt: prompt, model: self.callAsFunction, tokenizer: self.tokenizer, callback: callback)
+        try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback)
    }
 }
diff --git a/Sources/TensorUtils/LogitsWarper/LogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/LogitsWarper.swift
index ac92ebf..17fc64e 100644
--- a/Sources/TensorUtils/LogitsWarper/LogitsWarper.swift
+++ b/Sources/TensorUtils/LogitsWarper/LogitsWarper.swift
@@ -6,8 +6,8 @@ public protocol LogitsWarper {
     func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float])
 }
 
-extension LogitsWarper {
-    public func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) {
+public extension LogitsWarper {
+    func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) {
         warp(indices: indices, logits: logits)
     }
 }
diff --git a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift
index 53dc0db..df58b7e 100644
--- a/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift
+++ b/Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift
@@ -8,6 +8,6 @@ public struct TemperatureLogitsWarper: LogitsWarper {
     }
 
     public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
-        return (indices: indices, logits: logits.map { $0 / temperature })
+        (indices: indices, logits: logits.map { $0 / temperature })
     }
 }
diff --git a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift
index a236d84..7d0a123 100644
--- a/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift
+++ b/Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift
@@ -1,5 +1,5 @@
-import Foundation
 import Accelerate
+import Foundation
 
 /// Top-K.
 /// Select the k most-probable element indices from `arr`
diff --git a/Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift b/Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift
index 3796a09..bc08952 100644
--- a/Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift
+++ b/Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift
@@ -30,8 +30,8 @@ public struct TopPLogitsWarper: LogitsWarper {
             break
         }
 
-        let toppIndices = indexLogitProb[0 ... sliceIndex].map { indices[$0.index] }
-        let toppLogits = indexLogitProb[0 ... sliceIndex].map(\.logit)
+        let toppIndices = indexLogitProb[0...sliceIndex].map { indices[$0.index] }
+        let toppLogits = indexLogitProb[0...sliceIndex].map(\.logit)
         return (indices: toppIndices, logits: toppLogits)
     }
 }
diff --git a/Sources/TensorUtils/MLMultiArray+Utils.swift b/Sources/TensorUtils/MLMultiArray+Utils.swift
index ddb2760..a06cfe5 100644
--- a/Sources/TensorUtils/MLMultiArray+Utils.swift
+++ b/Sources/TensorUtils/MLMultiArray+Utils.swift
@@ -6,18 +6,18 @@
 //  Copyright © 2019 Hugging Face. All rights reserved.
 //
 
-import Foundation
 import CoreML
+import Foundation
 
 public extension MLMultiArray {
     /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
     static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
         var shape = Array(repeating: 1, count: dims)
         shape[shape.count - 1] = arr.count
-        /// Examples:
-        /// dims=1 : [arr.count]
-        /// dims=2 : [1, arr.count]
-        ///
+        // Examples:
+        // dims=1 : [arr.count]
+        // dims=2 : [1, arr.count]
+        //
         let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
         let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer))
         for (i, item) in arr.enumerated() {
@@ -30,10 +30,10 @@ public extension MLMultiArray {
     static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray {
         var shape = Array(repeating: 1, count: dims)
         shape[shape.count - 1] = arr.count
-        /// Examples:
-        /// dims=1 : [arr.count]
-        /// dims=2 : [1, arr.count]
-        ///
+        // Examples:
+        // dims=1 : [arr.count]
+        // dims=2 : [1, arr.count]
+        //
         let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .float64)
         let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
         for (i, item) in arr.enumerated() {
@@ -87,7 +87,6 @@ public extension MLMultiArray {
     }
 }
 
-
 public extension MLMultiArray {
     /// Provides a way to index n-dimensionals arrays a la numpy.
     enum Indexing: Equatable {
@@ -108,7 +107,7 @@ public extension MLMultiArray {
         )
         var selectDims: [Int: Int] = [:]
         for (i, idx) in indexing.enumerated() {
-            if case .select(let select) = idx {
+            if case let .select(select) = idx {
                 selectDims[i] = select
             }
         }
@@ -129,11 +128,11 @@ public extension MLMultiArray {
         )
         var shape: [NSNumber] = Array(repeating: 1, count: o.shape.count)
         shape[sliceDim] = o.shape[sliceDim]
-        /// print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)")
+        // print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)")
         let arr = try! MLMultiArray(shape: shape, dataType: .double)
 
-        /// let srcPtr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
-        /// TODO: use srcPtr instead of array subscripting.
+        // let srcPtr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
+        // TODO: use srcPtr instead of array subscripting.
         let dstPtr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer))
         for i in 0..<arr.count {
@@ -147,15 +146,15 @@ public extension MLMultiArray {
     func debug(_ indices: [Int] = []) -> String {
         func indent(_ x: Int) -> String {
-            return String(repeating: " ", count: x)
+            String(repeating: " ", count: x)
         }
 
         // This function is called recursively for every dimension.
         // Add an entry for this dimension to the end of the array.
         var indices = indices + [0]
 
-        let d = indices.count - 1       // the current dimension
-        let N = shape[d].intValue       // how many elements in this dimension
+        let d = indices.count - 1 // the current dimension
+        let N = shape[d].intValue // how many elements in this dimension
 
         var s = "["
-        if indices.count < shape.count {        // not last dimension yet?
+        if indices.count < shape.count { // not last dimension yet?
             for i in 0..<N {
diff --git a/Sources/TensorUtils/MLShapedArray+Utils.swift b/Sources/TensorUtils/MLShapedArray+Utils.swift
--- a/Sources/TensorUtils/MLShapedArray+Utils.swift
+++ b/Sources/TensorUtils/MLShapedArray+Utils.swift
@@ -1,39 +1,39 @@
 public extension MLShapedArray<Float> {
     var floats: [Float] {
-        guard self.strides.first == 1, self.strides.count == 1 else {
+        guard strides.first == 1, strides.count == 1 else {
             // For some reason this path is slow.
             // If strides is not 1, we can write a Metal kernel to copy the values properly.
-            return self.scalars
+            return scalars
         }
 
         // Fast path: memcpy
         let mlArray = MLMultiArray(self)
-        return mlArray.floats ?? self.scalars
+        return mlArray.floats ?? scalars
     }
 }
 
 public extension MLShapedArraySlice<Float> {
     var floats: [Float] {
-        guard self.strides.first == 1, self.strides.count == 1 else {
+        guard strides.first == 1, strides.count == 1 else {
             // For some reason this path is slow.
             // If strides is not 1, we can write a Metal kernel to copy the values properly.
-            return self.scalars
+            return scalars
         }
 
         // Fast path: memcpy
         let mlArray = MLMultiArray(self)
-        return mlArray.floats ?? self.scalars
+        return mlArray.floats ?? scalars
     }
 }
 
 public extension MLMultiArray {
     var floats: [Float]? {
-        guard self.dataType == .float32 else { return nil }
+        guard dataType == .float32 else { return nil }
 
-        var result: [Float] = Array(repeating: 0, count: self.count)
-        return self.withUnsafeBytes { ptr in
+        var result: [Float] = Array(repeating: 0, count: count)
+        return withUnsafeBytes { ptr in
             guard let source = ptr.baseAddress else { return nil }
             result.withUnsafeMutableBytes { resultPtr in
                 let dest = resultPtr.baseAddress!
@@ -48,6 +48,5 @@ public extension MLMultiArray {
             }
             return result
         }
-
     }
 }
diff --git a/Sources/TensorUtils/Math.swift b/Sources/TensorUtils/Math.swift
index 4050ac1..5451093 100644
--- a/Sources/TensorUtils/Math.swift
+++ b/Sources/TensorUtils/Math.swift
@@ -6,9 +6,9 @@
 //  Copyright © 2019 Hugging Face. All rights reserved.
 //
 
-import Foundation
 import Accelerate
 import CoreML
+import Foundation
 
 ///
 /// From M.I. Hollemans
 ///
 /// https://github.com/hollance/CoreMLHelpers
 ///
 public struct Math {
-
     /**
      Returns the index and value of the largest element in the array.
@@ -146,7 +145,7 @@ public struct Math {
             }
         }
         // This point might be reached due to floating point inaccuracies:
-        return (probabilities.count - 1)
+        return probabilities.count - 1
     }
 }
diff --git a/Sources/TensorUtils/Weights.swift b/Sources/TensorUtils/Weights.swift
index b77de6e..3444015 100644
--- a/Sources/TensorUtils/Weights.swift
+++ b/Sources/TensorUtils/Weights.swift
@@ -1,18 +1,16 @@
 import CoreML
 
-
 public struct Weights {
-
     enum WeightsError: LocalizedError {
         case notSupported(message: String)
         case invalidFile
 
         public var errorDescription: String? {
             switch self {
-                case .notSupported(let message):
-                    return String(localized: "The weight format '\(message)' is not supported by this application.", comment: "Error when weight format is not supported")
-                case .invalidFile:
-                    return String(localized: "The weights file is invalid or corrupted.", comment: "Error when weight file is invalid")
+            case let .notSupported(message):
+                String(localized: "The weight format '\(message)' is not supported by this application.", comment: "Error when weight format is not supported")
+            case .invalidFile:
+                String(localized: "The weights file is invalid or corrupted.", comment: "Error when weight file is invalid")
             }
         }
     }
@@ -31,19 +29,17 @@ public struct Weights {
         let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)
         switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) {
-        case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf"))
-        case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx")
+        case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: "gguf")
+        case ([0x93, 0x4E, 0x55, 0x4D], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx")
         default: return try Safetensor.from(data: data)
         }
     }
 }
 
 struct Safetensor {
-
     typealias Error = Weights.WeightsError
 
     struct Header {
-
         struct Offset: Decodable {
             let dataOffsets: [Int]?
             let dtype: String?
@@ -71,7 +67,7 @@ struct Safetensor {
     }
 
     static func from(data: Data) throws -> Weights {
-        let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) })
+        let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes { $0.load(as: Int.self) }
         guard headerSize < data.count else { throw Error.invalidFile }
 
         let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8)))
diff --git a/Sources/Tokenizers/BPETokenizer.swift b/Sources/Tokenizers/BPETokenizer.swift
index 0fbf250..e0fbe31 100644
--- a/Sources/Tokenizers/BPETokenizer.swift
+++ b/Sources/Tokenizers/BPETokenizer.swift
@@ -1,5 +1,5 @@
 //
-//  GPT2Tokenizer.swift
+//  BPETokenizer.swift
 //  CoreMLBert
 //
 //  Created by Julien Chaumond on 18/07/2019.
@@ -16,23 +16,24 @@ struct BytePair: Hashable {
         self.a = a
         self.b = b
     }
+
     init(tuple: [String]) {
-        self.a = tuple[0]
-        self.b = tuple[1]
+        a = tuple[0]
+        b = tuple[1]
     }
 
     static func == (lhs: BytePair, rhs: BytePair) -> Bool {
-        return lhs.a == rhs.a && lhs.b == rhs.b
+        lhs.a == rhs.a && lhs.b == rhs.b
     }
+
     func hash(into hasher: inout Hasher) {
         hasher.combine(a)
         hasher.combine(b)
     }
 }
 
-
 class BPETokenizer: PreTrainedTokenizerModel {
-    let bpeRanks: Dictionary<BytePair, Int>
+    let bpeRanks: [BytePair: Int]
     private let tokensToIds: [NSString: Int]
     private let idsToTokens: [Int: NSString]
 
@@ -48,7 +49,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
     public let fuseUnknownTokens: Bool
 
     static func mergesFromConfig(_ config: Config?) -> [[String]]? {
-        guard let config = config else { return nil }
+        guard let config else { return nil }
 
         // New format (pushed with tokenizers >= 0.20.0): each merge is a list of 2 items
         if let merges = config.value as? [[String]] { return merges }
@@ -60,28 +61,28 @@ class BPETokenizer: PreTrainedTokenizerModel {
         }
     }
 
-    required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
+    required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws {
         guard let merges = Self.mergesFromConfig(tokenizerData.model?.merges) else { fatalError("BPETokenizer requires merges") }
         guard let vocab = tokenizerData.model?.vocab?.dictionary as? [NSString: Int] else {
             throw TokenizerError.missingVocab
         }
 
-        var bpeRanks: Dictionary<BytePair, Int> = [:]
+        var bpeRanks: [BytePair: Int] = [:]
         for (i, merge) in merges.enumerated() {
             let bp = BytePair(tuple: merge)
             bpeRanks[bp] = i
         }
         self.bpeRanks = bpeRanks
 
-        self.tokensToIds = vocab.merging(addedTokens as [NSString : Int]) { $1 }
-        self.idsToTokens = Utils.invert(self.tokensToIds)
+        tokensToIds = vocab.merging(addedTokens as [NSString: Int]) { $1 }
+        idsToTokens = Utils.invert(tokensToIds)
 
         // Populate tokens
         if let unknownToken = TokenizerModel.unknownToken(from: tokenizerConfig) {
             self.unknownToken = unknownToken
-            self.unknownTokenId = self.tokensToIds[unknownToken as NSString]
+            unknownTokenId = tokensToIds[unknownToken as NSString]
         } else {
-            self.unknownToken = nil
-            self.unknownTokenId = nil
+            unknownToken = nil
+            unknownTokenId = nil
         }
 
         eosToken = addedTokenAsString(tokenizerConfig.eosToken)
@@ -94,17 +95,17 @@ class BPETokenizer: PreTrainedTokenizerModel {
     }
 
     func convertTokenToId(_ token: String) -> Int? {
-        return tokensToIds[token as NSString] ?? self.unknownTokenId
+        tokensToIds[token as NSString] ?? unknownTokenId
     }
 
     func convertIdToToken(_ id: Int) -> String? {
-        return idsToTokens[id] as String?
+        idsToTokens[id] as String?
} func byteEncode(text: String) -> [String] { let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# let tokens = text.ranges(of: RE).map { String(text[$0]) } - return tokens.map { (token) -> String in + return tokens.map { token -> String in return Array(token.utf8).compactMap { byteEncoder[$0] }.joined() } } @@ -112,17 +113,17 @@ class BPETokenizer: PreTrainedTokenizerModel { func hexaEncode(text: String) -> [String] { let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# let tokens = text.ranges(of: RE).map { String(text[$0]) } - return tokens.flatMap { (token) -> [String] in + return tokens.flatMap { token -> [String] in return Array(token.utf8).map { String(format: "<0x%02X>", $0) } } } private func getPairs(word: [String]) -> Set { var s = Set() - for i in 0.. Bool in bpeRanks[bp] != nil } + let bigrams = pairs.filter { bp -> Bool in bpeRanks[bp] != nil } if bigrams.count == 0 { break } - let bigram = bigrams.min { (bp1, bp2) -> Bool in + let bigram = bigrams.min { bp1, bp2 -> Bool in return bpeRanks[bp1]! < bpeRanks[bp2]! }! let first = bigram.a @@ -158,8 +159,8 @@ class BPETokenizer: PreTrainedTokenizerModel { break } - if word[i] == first && i < word.count - 1 && word[i+1] == second { - newWord.append(first+second) + if word[i] == first, i < word.count - 1, word[i + 1] == second { + newWord.append(first + second) i += 2 } else { newWord.append(word[i]) @@ -178,13 +179,13 @@ class BPETokenizer: PreTrainedTokenizerModel { func tokenize(text: String) -> [String] { var tokens: [String] = [] - let bpeTokens = self.bpe(token: text).split(separator: " ").map { String($0) } + let bpeTokens = bpe(token: text).split(separator: " ").map { String($0) } for token in bpeTokens { if convertTokenToId(token) != unknownTokenId { tokens.append(token) } else { // TODO: if config.byte_fallback is False, append the unknown token instead - tokens.append(contentsOf: self.hexaEncode(text: token)) + tokens.append(contentsOf: hexaEncode(text: token)) } } return tokens diff --git a/Sources/Tokenizers/BertTokenizer.swift b/Sources/Tokenizers/BertTokenizer.swift index d06b0b1..9a6c830 100644 --- a/Sources/Tokenizers/BertTokenizer.swift +++ b/Sources/Tokenizers/BertTokenizer.swift @@ -31,27 +31,28 @@ public class BertTokenizer { bosToken: String? = nil, eosToken: String? = nil, fuseUnknownTokens: Bool = false, - doLowerCase: Bool = true - ) { + doLowerCase: Bool = true) + { self.vocab = vocab - self.ids_to_tokens = Utils.invert(vocab) - self.basicTokenizer = BasicTokenizer(doLowerCase: doLowerCase) - self.wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab) + ids_to_tokens = Utils.invert(vocab) + basicTokenizer = BasicTokenizer(doLowerCase: doLowerCase) + wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab) self.tokenizeChineseChars = tokenizeChineseChars self.bosToken = bosToken - self.bosTokenId = bosToken == nil ? nil : vocab[bosToken!] + bosTokenId = bosToken == nil ? nil : vocab[bosToken!] self.eosToken = eosToken - self.eosTokenId = eosToken == nil ? nil : vocab[eosToken!] + eosTokenId = eosToken == nil ? nil : vocab[eosToken!] self.fuseUnknownTokens = fuseUnknownTokens } - public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { + public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { guard var vocab = tokenizerData.model?.vocab?.dictionary as? 
[String: Int] else { throw TokenizerError.missingVocab } if let addedTokens = tokenizerData.added_tokens?.dictionary["value"] as? [[String: Any]], - let pairs = addedTokens.compactMap({ ($0["content"] as? String, $0["id"] as? Int) }) as? [(String, Int)] { - vocab.merge(pairs, uniquingKeysWith: {$1}) + let pairs = addedTokens.compactMap({ ($0["content"] as? String, $0["id"] as? Int) }) as? [(String, Int)] + { + vocab.merge(pairs, uniquingKeysWith: { $1 }) } - vocab.merge(addedTokens, uniquingKeysWith: {$1}) + vocab.merge(addedTokens, uniquingKeysWith: { $1 }) let merges = tokenizerData.model?.merges?.value as? [String] let tokenizeChineseChars = tokenizerConfig.handleChineseChars?.boolValue ?? true let eosToken = tokenizerConfig.eosToken?.stringValue @@ -61,7 +62,6 @@ public class BertTokenizer { self.init(vocab: vocab, merges: merges, tokenizeChineseChars: tokenizeChineseChars, bosToken: bosToken, eosToken: eosToken, fuseUnknownTokens: fuseUnknown, doLowerCase: doLowerCase) } - public func tokenize(text: String) -> [String] { let text = tokenizeChineseCharsIfNeed(text) var tokens: [String] = [] @@ -88,22 +88,22 @@ public class BertTokenizer { /// Main entry point func tokenizeToIds(text: String) -> [Int] { - return try! convertTokensToIds(tokens: tokenize(text: text)) + try! convertTokensToIds(tokens: tokenize(text: text)) } func tokenToId(token: String) -> Int { - return vocab[token]! + vocab[token]! } /// Un-tokenization: get tokens from tokenIds func unTokenize(tokens: [Int]) -> [String] { - return tokens.compactMap { ids_to_tokens[$0] } + tokens.compactMap { ids_to_tokens[$0] } } /// Un-tokenization: func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String { var tokenList: [String] = [] - var individualToken: String = "" + var individualToken = "" for token in wordpieceTokenList { if token.starts(with: "##") { @@ -137,7 +137,6 @@ public class BertTokenizer { } } - extension BertTokenizer: PreTrainedTokenizerModel { public var unknownToken: String? { wordpieceTokenizer.unkToken } public var unknownTokenId: Int? { vocab[unknownToken!] } @@ -150,15 +149,14 @@ extension BertTokenizer: PreTrainedTokenizerModel { } public func convertTokenToId(_ token: String) -> Int? { - return vocab[token] ?? unknownTokenId + vocab[token] ?? unknownTokenId } public func convertIdToToken(_ id: Int) -> String? 
{ - return ids_to_tokens[id] + ids_to_tokens[id] } } - class BasicTokenizer { let doLowerCase: Bool @@ -167,7 +165,7 @@ class BasicTokenizer { } let neverSplit = [ - "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]" + "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", ] func maybeStripAccents(_ text: String) -> String { @@ -182,7 +180,7 @@ class BasicTokenizer { func tokenize(text: String) -> [String] { let splitTokens = maybeStripAccents(text).components(separatedBy: NSCharacterSet.whitespaces) - let tokens = splitTokens.flatMap({ (token: String) -> [String] in + let tokens = splitTokens.flatMap { (token: String) -> [String] in if neverSplit.contains(token) { return [token] } @@ -203,7 +201,7 @@ class BasicTokenizer { toks.append(currentTok) } return toks - }) + } return tokens } } @@ -214,11 +212,11 @@ extension Character { if isPunctuation { return true } if let value = unicodeScalars.first?.value { switch value { - case 33...47: return true - case 58...64: return true - case 91...96: return true - case 123...126: return true - default: return false + case 33...47: return true + case 58...64: return true + case 91...96: return true + case 123...126: return true + default: return false } } return false @@ -247,7 +245,7 @@ class WordpieceTokenizer { var subTokens: [String] = [] while start < word.count { var end = word.count - var cur_substr: String? = nil + var cur_substr: String? while start < end { var substr = Utils.substr(word, start..<end)! if start > 0 { diff --git a/Sources/Tokenizers/ByteEncoder.swift b/Sources/Tokenizers/ByteEncoder.swift index e636ca6..1d0532a 100644 --- a/Sources/Tokenizers/ByteEncoder.swift +++ b/Sources/Tokenizers/ByteEncoder.swift @@ -8,7 +8,7 @@ import Foundation -let byteEncoder: Dictionary<UTF8.CodeUnit, String> = [ +let byteEncoder: [UTF8.CodeUnit: String] = [ 33: "!", 34: "\"", 35: "#",
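The byteEncoder table above is the GPT-2-style byte-to-unicode map: every possible UTF-8 byte is assigned a printable character, so byte-level BPE can treat arbitrary bytes as ordinary text. A hedged, self-contained sketch of how such a table is conventionally derived (the function below is ours for illustration; the library ships the finished table as a literal):

func bytesToUnicode() -> [UInt8: Character] {
    var mapping: [UInt8: Character] = [:]
    var shifted: UInt32 = 0
    for byte in UInt8.min...UInt8.max {
        if (33...126).contains(byte) || (161...172).contains(byte) || (174...255).contains(byte) {
            // Printable, unambiguous bytes map to themselves.
            mapping[byte] = Character(UnicodeScalar(UInt32(byte))!)
        } else {
            // Control and whitespace bytes are shifted into the U+0100+ range.
            mapping[byte] = Character(UnicodeScalar(256 + shifted)!)
            shifted += 1
        }
    }
    return mapping
}

// bytesToUnicode()[32] is "Ġ": the shifted space seen throughout GPT-2 vocabularies.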
diff --git a/Sources/Tokenizers/Decoder.swift b/Sources/Tokenizers/Decoder.swift index 81d0d60..c508202 100644 --- a/Sources/Tokenizers/Decoder.swift +++ b/Sources/Tokenizers/Decoder.swift @@ -1,6 +1,6 @@ // // Decoder.swift -// +// // // Created by Pedro Cuenca on 17/7/23. // @@ -17,7 +17,7 @@ public protocol Decoder { extension Decoder { func callAsFunction(tokens: [String]) -> [String] { - return decode(tokens: tokens) + decode(tokens: tokens) } } @@ -36,19 +36,19 @@ enum DecoderType: String { struct DecoderFactory { static func fromConfig(config: Config?, addedTokens: Set<String>? = nil) -> Decoder? { // TODO: not sure if we need to include `addedTokens` in all the decoder initializers (and the protocol) - guard let config = config else { return nil } + guard let config else { return nil } guard let typeName = config.type?.stringValue else { return nil } let type = DecoderType(rawValue: typeName) switch type { - case .Sequence : return DecoderSequence(config: config) - case .ByteLevel : return ByteLevelDecoder(config: config, addedTokens: addedTokens) - case .Replace : return ReplaceDecoder(config: config) + case .Sequence: return DecoderSequence(config: config) + case .ByteLevel: return ByteLevelDecoder(config: config, addedTokens: addedTokens) + case .Replace: return ReplaceDecoder(config: config) case .ByteFallback: return ByteFallbackDecoder(config: config) - case .Fuse : return FuseDecoder(config: config) - case .Strip : return StripDecoder(config: config) - case .Metaspace : return MetaspaceDecoder(config: config) - case .WordPiece : return WordPieceDecoder(config: config) - default : fatalError("Unsupported Decoder type: \(typeName)") + case .Fuse: return FuseDecoder(config: config) + case .Strip: return StripDecoder(config: config) + case .Metaspace: return MetaspaceDecoder(config: config) + case .WordPiece: return WordPieceDecoder(config: config) + default: fatalError("Unsupported Decoder type: \(typeName)") } } } @@ -57,13 +57,13 @@ class WordPieceDecoder: Decoder { let prefix: String let cleanup: Bool - // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31 + /// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31 private let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'\\s|n't|'m|'s|'ve|'re)", options: []) - required public init(config: Config) { + public required init(config: Config) { guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") } self.prefix = prefix - self.cleanup = config.cleanup?.boolValue ?? false + cleanup = config.cleanup?.boolValue ?? false } func decode(tokens: [String]) -> [String] { @@ -74,7 +74,7 @@ class WordPieceDecoder: Decoder { } - // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40 + /// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40 private func cleanUpTokenization(_ token: String) -> String { let range = NSRange(location: 0, length: token.utf16.count) return re.stringByReplacingMatches(in: token, options: [], range: range, withTemplate: "$1") @@ -85,7 +85,7 @@ class DecoderSequence: Decoder { let decoders: [Decoder] - required public init(config: Config) { + public required init(config: Config) { guard let configs = config.decoders?.arrayValue else { fatalError("No decoders in Sequence") } decoders = configs.compactMap { DecoderFactory.fromConfig(config: $0) } } @@ -100,8 +100,8 @@ class DecoderSequence: Decoder { class ByteLevelDecoder: Decoder { let addedTokens: Set<String> - required public init(config: Config) { - self.addedTokens = [] + public required init(config: Config) { + addedTokens = [] } init(config: Config, addedTokens: Set<String>?) { @@ -142,25 +142,25 @@ class ByteLevelDecoder: Decoder { class ReplaceDecoder: Decoder { let pattern: StringReplacePattern?
- required public init(config: Config) { - self.pattern = StringReplacePattern.from(config: config) + public required init(config: Config) { + pattern = StringReplacePattern.from(config: config) } func decode(tokens: [String]) -> [String] { - guard let pattern = pattern else { return tokens } + guard let pattern else { return tokens } return tokens.map { pattern.replace($0) } } } class ByteFallbackDecoder: Decoder { - required public init(config: Config) {} + public required init(config: Config) { } func decode(tokens: [String]) -> [String] { var newTokens: [String] = [] var byteTokens: [Int] = [] func parseByte(_ token: String) -> Int? { - guard token.count == 6 && token.hasPrefix("<0x") && token.hasSuffix(">") else { + guard token.count == 6, token.hasPrefix("<0x"), token.hasSuffix(">") else { return nil } let startIndex = token.index(token.startIndex, offsetBy: 3) @@ -186,7 +186,7 @@ class ByteFallbackDecoder: Decoder { } class FuseDecoder: Decoder { - required public init(config: Config) {} + public required init(config: Config) { } func decode(tokens: [String]) -> [String] { [tokens.joined(separator: "")] @@ -198,7 +198,7 @@ class StripDecoder: Decoder { let start: Int let stop: Int - required public init(config: Config) { + public required init(config: Config) { guard let content = config.content?.stringValue else { fatalError("Incorrect StripDecoder configuration: can't parse `content`.") } guard let start = config.start?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `start`.") } guard let stop = config.stop?.intValue else { fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") } @@ -218,7 +218,7 @@ class MetaspaceDecoder: Decoder { let addPrefixSpace: Bool let replacement: String - required public init(config: Config) { + public required init(config: Config) { addPrefixSpace = config.addPrefixSpace?.boolValue ?? false replacement = config.replacement?.stringValue ?? "_" } @@ -227,19 +227,19 @@ class MetaspaceDecoder: Decoder { var replaced = tokens.map { token in token.replacingOccurrences(of: replacement, with: " ") } - if addPrefixSpace && replaced.first?.starts(with: " ") ?? false { + if addPrefixSpace, replaced.first?.starts(with: " ") ?? 
false { replaced[0].removeFirst() } return replaced } } -// We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once) +/// We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once) public extension String { func trimmingFromStart(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 - while trimmed < upto && result.first == character { + while trimmed < upto, result.first == character { result.removeFirst() trimmed += 1 } @@ -249,7 +249,7 @@ public extension String { func trimmingFromEnd(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 - while trimmed < upto && result.last == character { + while trimmed < upto, result.last == character { result.removeLast() trimmed += 1 } diff --git a/Sources/Tokenizers/Normalizer.swift b/Sources/Tokenizers/Normalizer.swift index e9730c8..578ecd1 100644 --- a/Sources/Tokenizers/Normalizer.swift +++ b/Sources/Tokenizers/Normalizer.swift @@ -17,7 +17,7 @@ public protocol Normalizer { extension Normalizer { func callAsFunction(text: String) -> String { - return normalize(text: text) + normalize(text: text) } } @@ -40,7 +40,7 @@ enum NormalizerType: String { struct NormalizerFactory { static func fromConfig(config: Config?) -> Normalizer? { - guard let config = config else { return nil } + guard let config else { return nil } guard let typeName = config.type?.stringValue else { return nil } let type = NormalizerType(rawValue: typeName) switch type { @@ -64,7 +64,7 @@ struct NormalizerFactory { class NormalizerSequence: Normalizer { let normalizers: [Normalizer] - required public init(config: Config) { + public required init(config: Config) { guard let configs = config.normalizers?.arrayValue else { fatalError("No normalizers in Sequence") } @@ -81,30 +81,30 @@ class NormalizerSequence: Normalizer { class PrependNormalizer: Normalizer { let prepend: String - required public init(config: Config) { + public required init(config: Config) { prepend = config.prepend?.stringValue ?? "" } public func normalize(text: String) -> String { - return prepend + text + prepend + text } } class ReplaceNormalizer: Normalizer { let pattern: StringReplacePattern? 
- required public init(config: Config) { - self.pattern = StringReplacePattern.from(config: config) + public required init(config: Config) { + pattern = StringReplacePattern.from(config: config) } public func normalize(text: String) -> String { - guard let pattern = pattern else { return text } + guard let pattern else { return text } return pattern.replace(text) } } class LowercaseNormalizer: Normalizer { - required public init(config: Config) {} + public required init(config: Config) { } public func normalize(text: String) -> String { text.lowercased() @@ -112,7 +112,7 @@ class LowercaseNormalizer: Normalizer { } class NFDNormalizer: Normalizer { - required public init(config: Config) {} + public required init(config: Config) { } public func normalize(text: String) -> String { text.decomposedStringWithCanonicalMapping @@ -120,7 +120,7 @@ class NFDNormalizer: Normalizer { } class NFCNormalizer: Normalizer { - required public init(config: Config) {} + public required init(config: Config) { } public func normalize(text: String) -> String { text.precomposedStringWithCanonicalMapping @@ -128,7 +128,7 @@ class NFCNormalizer: Normalizer { } class NFKDNormalizer: Normalizer { - required init(config: Config) {} + required init(config: Config) { } func normalize(text: String) -> String { text.decomposedStringWithCompatibilityMapping @@ -136,7 +136,7 @@ class NFKDNormalizer: Normalizer { } class NFKCNormalizer: Normalizer { - required init(config: Config) {} + required init(config: Config) { } func normalize(text: String) -> String { text.precomposedStringWithCompatibilityMapping @@ -150,10 +150,10 @@ class BertNormalizer: Normalizer { let shouldLowercase: Bool required init(config: Config) { - self.shouldCleanText = config.cleanText?.boolValue ?? true - self.shouldHandleChineseChars = config.handleChineseChars?.boolValue ?? true - self.shouldLowercase = config.lowercase?.boolValue ?? true - self.shouldStripAccents = config.stripAccents?.boolValue ?? shouldLowercase + shouldCleanText = config.cleanText?.boolValue ?? true + shouldHandleChineseChars = config.handleChineseChars?.boolValue ?? true + shouldLowercase = config.lowercase?.boolValue ?? true + shouldStripAccents = config.stripAccents?.boolValue ?? shouldLowercase } func normalize(text: String) -> String { @@ -177,9 +177,9 @@ class BertNormalizer: Normalizer { private func cleanText(text: String) -> String { text.map { c in guard let scalar = c.unicodeScalars.first, - scalar.value != 0x0, - scalar.value != 0xFFFD, - !isControl(scalar) + scalar.value != 0x0, + scalar.value != 0xFFFD, + !isControl(scalar) else { return "\(c)" } // Replace whitespace: \t, \n, \r @@ -195,11 +195,11 @@ class BertNormalizer: Normalizer { private func isControl(_ c: UnicodeScalar) -> Bool { if c.value == 0x009 || c.value == 0x00A || c.value == 0x000D { // Except \t, \n, \r that will be spaces. 
- return false + false } else { // https://unicode.org/reports/tr44/#GC_Values_Table // Other Cc | Cf | Cs | Co | Cn - return isOther(c.properties.generalCategory) + isOther(c.properties.generalCategory) } } @@ -221,14 +221,14 @@ class BertNormalizer: Normalizer { private func stripAccents(text: String) -> String { // This might be the same as `text.folding(options: .diacriticInsensitive, locale: nil)` String(text.decomposedStringWithCanonicalMapping.unicodeScalars.filter { scalar in - !(0x0300 <= scalar.value && scalar.value <= 0x036F) + !(scalar.value >= 0x0300 && scalar.value <= 0x036F) }) } } class PrecompiledNormalizer: Normalizer { // TODO: use `precompiledCharsmap` (base64-encoded string) from the configuration - required init(config: Config) {} + required init(config: Config) { } func normalize(text: String) -> String { // TODO: This is a simplified implementation. @@ -236,7 +236,7 @@ class PrecompiledNormalizer: Normalizer { // https://github.com/xenova/transformers.js/blob/main/src/tokenizers.js#L2237-L2247 // - For a proper implementation, see: // https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/normalizers/precompiled.rs#L36 - var output: String = "" + var output = "" var hasFullwidthTilde = false for scalar in text.unicodeScalars { @@ -245,7 +245,7 @@ class PrecompiledNormalizer: Normalizer { // Non-printing control characters output.append("") case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581, - 0xFEFF, 0xFFFD: + 0xFEFF, 0xFFFD: // Separators output.append(" ") case 0xFF5E: @@ -259,9 +259,9 @@ class PrecompiledNormalizer: Normalizer { if hasFullwidthTilde { return output - .split(by: "\u{FF5E}") - .map({ $0.precomposedStringWithCompatibilityMapping }) - .joined(separator: "\u{FF5E}") + .split(by: "\u{FF5E}") + .map { $0.precomposedStringWithCompatibilityMapping } + .joined(separator: "\u{FF5E}") } else { return output.precomposedStringWithCompatibilityMapping } @@ -269,7 +269,7 @@ class PrecompiledNormalizer: Normalizer { } class StripAccentsNormalizer: Normalizer { - required init(config: Config) {} + required init(config: Config) { } func normalize(text: String) -> String { text.precomposedStringWithCompatibilityMapping @@ -281,8 +281,8 @@ class StripNormalizer: Normalizer { let rightStrip: Bool required init(config: Config) { - self.leftStrip = config.stripLeft?.boolValue ?? true - self.rightStrip = config.stripRight?.boolValue ?? true + leftStrip = config.stripLeft?.boolValue ?? true + rightStrip = config.stripRight?.boolValue ?? 
true } func normalize(text: String) -> String { @@ -308,12 +308,13 @@ enum StringReplacePattern { extension StringReplacePattern { func replace(_ text: String) -> String { switch self { - case .regexp(let regexp, let replacement): + case let .regexp(regexp, replacement): let range = NSRange(text.startIndex..., in: text) let replaced = regexp.stringByReplacingMatches( - in: text, options: [], range: range, withTemplate: replacement) + in: text, options: [], range: range, withTemplate: replacement + ) return replaced - case .string(let toReplace, let replacement): + case let .string(toReplace, replacement): return text.replacingOccurrences(of: toReplace, with: replacement) } } diff --git a/Sources/Tokenizers/PostProcessor.swift b/Sources/Tokenizers/PostProcessor.swift index 0b26415..693cd75 100644 --- a/Sources/Tokenizers/PostProcessor.swift +++ b/Sources/Tokenizers/PostProcessor.swift @@ -1,6 +1,6 @@ // // PostProcessor.swift -// +// // // Created by Pedro Cuenca on 17/7/23. // @@ -17,7 +17,7 @@ public protocol PostProcessor { extension PostProcessor { func callAsFunction(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { - return postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens) + postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens) } } @@ -31,16 +31,16 @@ enum PostProcessorType: String { struct PostProcessorFactory { static func fromConfig(config: Config?) -> PostProcessor? { - guard let config = config else { return nil } + guard let config else { return nil } guard let typeName = config.type?.stringValue else { return nil } let type = PostProcessorType(rawValue: typeName) switch type { - case .TemplateProcessing : return TemplateProcessing(config: config) - case .ByteLevel : return ByteLevelPostProcessor(config: config) - case .RobertaProcessing : return RobertaProcessing(config: config) - case .BertProcessing : return BertProcessing(config: config) - case .Sequence : return SequenceProcessing(config: config) - default : fatalError("Unsupported PostProcessor type: \(typeName)") + case .TemplateProcessing: return TemplateProcessing(config: config) + case .ByteLevel: return ByteLevelPostProcessor(config: config) + case .RobertaProcessing: return RobertaProcessing(config: config) + case .BertProcessing: return BertProcessing(config: config) + case .Sequence: return SequenceProcessing(config: config) + default: fatalError("Unsupported PostProcessor type: \(typeName)") } } } @@ -49,7 +49,7 @@ class TemplateProcessing: PostProcessor { let single: [Config] let pair: [Config] - required public init(config: Config) { + public required init(config: Config) { guard let single = config.single?.arrayValue else { fatalError("Missing `single` processor configuration") } guard let pair = config.pair?.arrayValue else { fatalError("Missing `pair` processor configuration") } @@ -79,7 +79,7 @@ class TemplateProcessing: PostProcessor { } class ByteLevelPostProcessor: PostProcessor { - required public init(config: Config) {} + public required init(config: Config) { } func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { tokens } } @@ -91,13 +91,13 @@ class RobertaProcessing: PostProcessor { /// Keep one space character on each side. Depends on `trimOffsets` being `true`. 
private let addPrefixSpace: Bool - required public init(config: Config) { + public required init(config: Config) { guard let sep = config.sep?.tokenValue else { fatalError("Missing `sep` processor configuration") } guard let cls = config.cls?.tokenValue else { fatalError("Missing `cls` processor configuration") } self.sep = sep self.cls = cls - self.trimOffset = config.trimOffset?.boolValue ?? true - self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true + trimOffset = config.trimOffset?.boolValue ?? true + addPrefixSpace = config.addPrefixSpace?.boolValue ?? true } func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] { @@ -105,19 +105,19 @@ class RobertaProcessing: PostProcessor { var tokensPair = tokensPair if trimOffset { if addPrefixSpace { - outTokens = outTokens.map({ trimExtraSpaces(token: $0) }) - tokensPair = tokensPair?.map({ trimExtraSpaces(token: $0) }) - } else { - outTokens = outTokens.map({ $0.trimmingCharacters(in: .whitespaces) }) - tokensPair = tokensPair?.map({ $0.trimmingCharacters(in: .whitespaces) }) + outTokens = outTokens.map { trimExtraSpaces(token: $0) } + tokensPair = tokensPair?.map { trimExtraSpaces(token: $0) } + } else { + outTokens = outTokens.map { $0.trimmingCharacters(in: .whitespaces) } + tokensPair = tokensPair?.map { $0.trimmingCharacters(in: .whitespaces) } } } - outTokens = [self.cls.1] + outTokens + [self.sep.1] - if let tokensPair = tokensPair, !tokensPair.isEmpty { + outTokens = [cls.1] + outTokens + [sep.1] + if let tokensPair, !tokensPair.isEmpty { // Yes, it adds another `sep`. // https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/roberta/hub_interface.py#L58-L65 - outTokens += [self.sep.1] + tokensPair + [self.sep.1] + outTokens += [sep.1] + tokensPair + [sep.1] } return outTokens @@ -148,7 +148,7 @@ class BertProcessing: PostProcessor { private let sep: (UInt, String) private let cls: (UInt, String) - required public init(config: Config) { + public required init(config: Config) { guard let sep = config.sep?.tokenValue else { fatalError("Missing `sep` processor configuration") } guard let cls = config.cls?.tokenValue else { fatalError("Missing `cls` processor configuration") } self.sep = sep @@ -158,9 +158,9 @@ class BertProcessing: PostProcessor { func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] { guard addSpecialTokens else { return tokens + (tokensPair ?? 
[]) } - var outTokens = [self.cls.1] + tokens + [self.sep.1] - if let tokensPair = tokensPair, !tokensPair.isEmpty { - outTokens += tokensPair + [self.sep.1] + var outTokens = [cls.1] + tokens + [sep.1] + if let tokensPair, !tokensPair.isEmpty { + outTokens += tokensPair + [sep.1] } return outTokens @@ -170,12 +170,12 @@ class BertProcessing: PostProcessor { class SequenceProcessing: PostProcessor { private let processors: [PostProcessor] - required public init(config: Config) { + public required init(config: Config) { guard let processorConfigs = config.processors?.arrayValue else { fatalError("Missing `processors` configuration") } - self.processors = processorConfigs.compactMap { PostProcessorFactory.fromConfig(config: $0) } + processors = processorConfigs.compactMap { PostProcessorFactory.fromConfig(config: $0) } } func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] { @@ -185,7 +185,7 @@ class SequenceProcessing: PostProcessor { for processor in processors { let processed = processor.postProcess(tokens: currentTokens, tokensPair: currentTokensPair, addSpecialTokens: addSpecialTokens) currentTokens = processed - currentTokensPair = nil // After the first processor, we no longer have a separate pair + currentTokensPair = nil // After the first processor, we no longer have a separate pair } return currentTokens diff --git a/Sources/Tokenizers/PreTokenizer.swift b/Sources/Tokenizers/PreTokenizer.swift index d7eb972..9bb0ddf 100644 --- a/Sources/Tokenizers/PreTokenizer.swift +++ b/Sources/Tokenizers/PreTokenizer.swift @@ -1,6 +1,6 @@ // // PreTokenizer.swift -// +// // // Created by Pedro Cuenca on 18/7/23. // @@ -29,11 +29,11 @@ extension PreTokenizer { } func callAsFunction(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] { - return preTokenize(texts: texts, options: options) + preTokenize(texts: texts, options: options) } func callAsFunction(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { - return preTokenize(text: text, options: options) + preTokenize(text: text, options: options) } } @@ -47,17 +47,17 @@ enum PreTokenizerType: String { case WhitespaceSplit case Metaspace case BertPreTokenizer - // Several more to be supported + /// Several more to be supported case Unknown = "" } struct PreTokenizerFactory { static func fromConfig(config: Config?) -> PreTokenizer? 
{ - guard let config = config else { return nil } + guard let config else { return nil } guard let typeName = config.type?.stringValue else { return nil } let type = PreTokenizerType(rawValue: typeName) switch type { - case .Sequence : return PreTokenizerSequence(config: config) + case .Sequence: return PreTokenizerSequence(config: config) case .ByteLevel: return ByteLevelPreTokenizer(config: config) case .Punctuation: return PunctuationPreTokenizer(config: config) case .Digits: return DigitsPreTokenizer(config: config) @@ -79,7 +79,7 @@ class BertPreTokenizer: PreTokenizer { } func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { - return text.ranges(of: re).map { String(text[$0]) } + text.ranges(of: re).map { String(text[$0]) } } } @@ -106,7 +106,7 @@ class WhitespacePreTokenizer: PreTokenizer { } func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { - return text.ranges(of: re).map { String(text[$0]) } + text.ranges(of: re).map { String(text[$0]) } } } @@ -128,7 +128,7 @@ class MetaspacePreTokenizer: PreTokenizer { static var defaultScheme: PrependScheme { .always } static func from(rawValue value: String?) -> PrependScheme { - guard let value = value else { return defaultScheme } + guard let value else { return defaultScheme } return PrependScheme(rawValue: value) ?? defaultScheme } } @@ -143,8 +143,8 @@ class MetaspacePreTokenizer: PreTokenizer { prependScheme = PrependScheme.from(rawValue: config.prependScheme?.stringValue) } - // https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114 - // https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153 + /// https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114 + /// https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153 func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { let normalized = text.replacingOccurrences(of: " ", with: stringReplacement) @@ -157,11 +157,11 @@ class MetaspacePreTokenizer: PreTokenizer { // FIXME: (2b) always prepends, we are not passing section info var prepend = "" - if addPrefixSpace && !normalized.hasPrefix(replacement) { + if addPrefixSpace, !normalized.hasPrefix(replacement) { if prependScheme == .always { prepend = stringReplacement } - if prependScheme == .first && options.contains(.firstSection) { + if prependScheme == .first, options.contains(.firstSection) { prepend = stringReplacement } } @@ -186,14 +186,14 @@ class ByteLevelPreTokenizer: PreTokenizer { func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { // Split on whitespace and punctuation - let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text] + let tokens = useRegex ? text.ranges(of: RE).map { String(text[$0]) } : [text] return tokens.map { token in - if addPrefixSpace && !token.hasPrefix(" ") { + if addPrefixSpace, !token.hasPrefix(" ") { return " " + token } return token }.map { token in - return Array(token.utf8).map { byteEncoder[$0]! }.joined() + Array(token.utf8).map { byteEncoder[$0]! 
}.joined() } } } @@ -207,7 +207,7 @@ class PunctuationPreTokenizer: PreTokenizer { func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { // Ref: https://github.com/xenova/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1138 - return text.ranges(of: re).map { String(text[$0]) } + text.ranges(of: re).map { String(text[$0]) } } } @@ -220,7 +220,7 @@ class DigitsPreTokenizer: PreTokenizer { } func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { - return text.ranges(of: re).map { String(text[$0]) } + text.ranges(of: re).map { String(text[$0]) } } } @@ -234,7 +234,7 @@ class SplitPreTokenizer: PreTokenizer { } func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { - guard let pattern = pattern else { return [text] } + guard let pattern else { return [text] } return pattern.split(text, invert: invert) } } @@ -247,10 +247,10 @@ enum StringSplitPattern { extension StringSplitPattern { func split(_ text: String, invert: Bool = true) -> [String] { switch self { - case .regexp(let regexp): - return text.split(by: regexp, includeSeparators: true) - case .string(let substring): - return text.split(by: substring, options: [], includeSeparators: !invert) + case let .regexp(regexp): + text.split(by: regexp, includeSeparators: true) + case let .string(substring): + text.split(by: substring, options: [], includeSeparators: !invert) } } } @@ -283,7 +283,7 @@ public extension String { var start = startIndex while let range = range(of: string, options: options, range: start..<endIndex) { […] if score > bestScore { @@ -88,10 +88,10 @@ extension TokenLattice { // TODO: the reference implementations have a few more clones here: verify var result: [TokenLatticeNode] = [] - var node = prev //.clone() + var node = prev // .clone() while node.prev != nil { result.append(node.clone()) - node = node.prev! //.clone() + node = node.prev! // .clone() } return result.reversed() } @@ -125,7 +125,7 @@ class TokenLatticeNode { let length: Int let score: Float - var prev: TokenLatticeNode? = nil + var prev: TokenLatticeNode? var backtraceScore: Float = 0 init(tokenId: Int, startOffset: Int, length: Int, score: Float, prev: TokenLatticeNode? = nil, backtraceScore: Float = 0) { @@ -139,8 +139,8 @@ } extension TokenLatticeNode { - // This is a reference type because structs can't contain references to the same type - // We could implement NSCopying, but frankly I don't see the point + /// This is a reference type because structs can't contain references to the same type + /// We could implement NSCopying, but frankly I don't see the point func clone() -> TokenLatticeNode { TokenLatticeNode(tokenId: tokenId, startOffset: startOffset, length: length, score: score, prev: prev, backtraceScore: backtraceScore) } diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index 935a37e..1926933 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -5,8 +5,8 @@ // Created by Pedro Cuenca on 6/5/23. // -import Hub import Foundation +import Foundation +import Hub import Jinja public typealias Message = [String: Any] @@ -24,22 +24,22 @@ public enum TokenizerError: LocalizedError { public var errorDescription: String?
{ switch self { - case .missingConfig: - return String(localized: "Tokenizer configuration is missing.", comment: "Error when tokenizer config cannot be found") - case .missingTokenizerClassInConfig: - return String(localized: "The tokenizer class is not specified in the configuration.", comment: "Error when tokenizer_class is missing in config") - case .unsupportedTokenizer(let name): - return String(localized: "The tokenizer type '\(name)' is not supported.", comment: "Error when tokenizer type is not supported") - case .missingVocab: - return String(localized: "Vocabulary file is missing from the tokenizer configuration.", comment: "Error when vocab file is missing") - case .malformedVocab: - return String(localized: "The vocabulary file is malformed or corrupted.", comment: "Error when vocab file is malformed") - case .chatTemplate(let message): - return String(localized: "Chat template error: \(message)", comment: "Error with chat template") - case .tooLong(let message): - return String(localized: "Input is too long: \(message)", comment: "Error when input exceeds maximum length") - case .mismatchedConfig(let message): - return String(localized: "Tokenizer configuration mismatch: \(message)", comment: "Error when tokenizer configuration is inconsistent") + case .missingConfig: + String(localized: "Tokenizer configuration is missing.", comment: "Error when tokenizer config cannot be found") + case .missingTokenizerClassInConfig: + String(localized: "The tokenizer class is not specified in the configuration.", comment: "Error when tokenizer_class is missing in config") + case let .unsupportedTokenizer(name): + String(localized: "The tokenizer type '\(name)' is not supported.", comment: "Error when tokenizer type is not supported") + case .missingVocab: + String(localized: "Vocabulary file is missing from the tokenizer configuration.", comment: "Error when vocab file is missing") + case .malformedVocab: + String(localized: "The vocabulary file is malformed or corrupted.", comment: "Error when vocab file is malformed") + case let .chatTemplate(message): + String(localized: "Chat template error: \(message)", comment: "Error with chat template") + case let .tooLong(message): + String(localized: "Input is too long: \(message)", comment: "Error when input exceeds maximum length") + case let .mismatchedConfig(message): + String(localized: "Tokenizer configuration mismatch: \(message)", comment: "Error when tokenizer configuration is inconsistent") } } } @@ -47,7 +47,7 @@ public enum TokenizerError: LocalizedError { public protocol TokenizingModel { func tokenize(text: String) -> [String] - // Alias for `tokenize` + /// Alias for `tokenize` func callAsFunction(_ text: String) -> [String] func convertTokenToId(_ token: String) -> Int? @@ -66,9 +66,9 @@ public protocol TokenizingModel { var fuseUnknownTokens: Bool { get } } -// Helper - possibly to be moved somewhere else +/// Helper - possibly to be moved somewhere else func addedTokenAsString(_ addedToken: Config?) -> String? { - guard let addedToken = addedToken else { return nil } + guard let addedToken else { return nil } if let stringValue = addedToken.stringValue { return stringValue } @@ -83,42 +83,42 @@ public extension TokenizingModel { } func convertTokensToIds(_ tokens: [String]) -> [Int?] { - return tokens.map { convertTokenToId($0) } + tokens.map { convertTokenToId($0) } } func convertIdsToTokens(_ ids: [Int]) -> [String?] 
{ - return ids.map { convertIdToToken($0) } + ids.map { convertIdToToken($0) } } } /// A tokenizer model that is set up with Hub configuration data public protocol PreTrainedTokenizerModel: TokenizingModel { - init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws + init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws } struct TokenizerModel { - static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [ - "BertTokenizer" : BertTokenizer.self, + static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [ + "BertTokenizer": BertTokenizer.self, "DistilbertTokenizer": BertTokenizer.self, "DistilBertTokenizer": BertTokenizer.self, - "CodeGenTokenizer" : CodeGenTokenizer.self, - "CodeLlamaTokenizer" : CodeLlamaTokenizer.self, - "FalconTokenizer" : FalconTokenizer.self, - "GemmaTokenizer" : GemmaTokenizer.self, - "GPT2Tokenizer" : GPT2Tokenizer.self, - "LlamaTokenizer" : LlamaTokenizer.self, - "T5Tokenizer" : T5Tokenizer.self, - "WhisperTokenizer" : WhisperTokenizer.self, - "CohereTokenizer" : CohereTokenizer.self, - "Qwen2Tokenizer" : Qwen2Tokenizer.self, - "PreTrainedTokenizer": BPETokenizer.self + "CodeGenTokenizer": CodeGenTokenizer.self, + "CodeLlamaTokenizer": CodeLlamaTokenizer.self, + "FalconTokenizer": FalconTokenizer.self, + "GemmaTokenizer": GemmaTokenizer.self, + "GPT2Tokenizer": GPT2Tokenizer.self, + "LlamaTokenizer": LlamaTokenizer.self, + "T5Tokenizer": T5Tokenizer.self, + "WhisperTokenizer": WhisperTokenizer.self, + "CohereTokenizer": CohereTokenizer.self, + "Qwen2Tokenizer": Qwen2Tokenizer.self, + "PreTrainedTokenizer": BPETokenizer.self, ] static func unknownToken(from tokenizerConfig: Config) -> String? { - return tokenizerConfig.unkToken?.content?.stringValue ?? tokenizerConfig.unkToken?.stringValue + tokenizerConfig.unkToken?.content?.stringValue ?? tokenizerConfig.unkToken?.stringValue } - public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws -> TokenizingModel { + public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws -> TokenizingModel { guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else { throw TokenizerError.missingTokenizerClassInConfig } @@ -184,7 +184,7 @@ public protocol Tokenizer { func applyChatTemplate( messages: [Message], - /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + // A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, @@ -194,7 +194,7 @@ public protocol Tokenizer { func applyChatTemplate( messages: [Message], - /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + // A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. 
chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, @@ -210,7 +210,7 @@ extension Tokenizer { /// Call previous signature for backwards compatibility func applyChatTemplate( messages: [Message], - /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + // A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, @@ -236,11 +236,11 @@ public extension Tokenizer { } func convertTokensToIds(_ tokens: [String]) -> [Int?] { - return tokens.map { convertTokenToId($0) } + tokens.map { convertTokenToId($0) } } func convertIdsToTokens(_ ids: [Int]) -> [String?] { - return ids.map { convertIdToToken($0) } + ids.map { convertIdToToken($0) } } } @@ -252,7 +252,7 @@ let specialTokenAttributes: [String] = [ "pad_token", "cls_token", "mask_token", - "additional_special_tokens" + "additional_special_tokens", ] public class PreTrainedTokenizer: Tokenizer { @@ -278,9 +278,9 @@ public class PreTrainedTokenizer: Tokenizer { private let cleanUpTokenizationSpaces: Bool - required public init(tokenizerConfig: Config, tokenizerData: Config) throws { - var addedTokens: [String : Int] = [:] - var specialTokens: [String : Int] = [:] + public required init(tokenizerConfig: Config, tokenizerData: Config) throws { + var addedTokens: [String: Int] = [:] + var specialTokens: [String: Int] = [:] for addedToken in tokenizerData.addedTokens?.arrayValue ?? [] { guard let id = addedToken.id?.intValue else { continue /* malformed: token with no id */ } guard let content = addedToken.content?.stringValue else { continue /* malformed: token with no content */ } @@ -293,7 +293,7 @@ public class PreTrainedTokenizer: Tokenizer { // Convert to tuples for easier access, then sort by length (descending) to avoid early partial matches // (https://github.com/xenova/transformers.js/commit/c305c3824f628f1f02806a6310bd3b18b0f7f8f5) - let unwrappedAddedTokens : [(content: String, prefix: Bool, suffix: Bool)] = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in + let unwrappedAddedTokens: [(content: String, prefix: Bool, suffix: Bool)] = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in guard let content = addedToken.content?.stringValue else { return nil } let prefix = addedToken.lstrip?.boolValue ?? false let suffix = addedToken.rstrip?.boolValue ?? false @@ -315,28 +315,28 @@ public class PreTrainedTokenizer: Tokenizer { self.specialTokens = specialTokens self.addedTokens = Set(addedTokens.keys) - self.preTokenizer = PreTokenizerFactory.fromConfig(config: tokenizerData.preTokenizer) - self.normalizer = NormalizerFactory.fromConfig(config: tokenizerData.normalizer) - self.postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor) - self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder, addedTokens: self.addedTokens) - self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? 
true + preTokenizer = PreTokenizerFactory.fromConfig(config: tokenizerData.preTokenizer) + normalizer = NormalizerFactory.fromConfig(config: tokenizerData.normalizer) + postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor) + decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder, addedTokens: self.addedTokens) + cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true self.tokenizerConfig = tokenizerConfig model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) } func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] { - guard let preTokenizer = preTokenizer else { return [text] } + guard let preTokenizer else { return [text] } return preTokenizer(text: text, options: options) } func normalize(_ text: String) -> String { - guard let normalizer = normalizer else { return text } + guard let normalizer else { return text } return normalizer(text: text) } func postProcess(_ tokens: [String], addSpecialTokens: Bool = true) -> [String] { - guard let postProcessor = postProcessor else { return tokens } + guard let postProcessor else { return tokens } return postProcessor(tokens: tokens, addSpecialTokens: addSpecialTokens) } @@ -379,11 +379,10 @@ public class PreTrainedTokenizer: Tokenizer { public func tokenize(text: String) -> [String] { // Take care of special tokens first - let sections: [String] - if let regex = self.addedTokensRegex { - sections = text.split(by: regex) + let sections: [String] = if let regex = addedTokensRegex { + text.split(by: regex) } else { - sections = [text] + [text] } return sections.enumerated().map { section, x in if addedTokens.contains(x) { return [x] } @@ -393,11 +392,11 @@ public class PreTrainedTokenizer: Tokenizer { /// Main entry point public func encode(text: String, addSpecialTokens: Bool = true) -> [Int] { - return postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! } + postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! } } public func encode(text: String) -> [Int] { - return encode(text: text, addSpecialTokens: true) + encode(text: text, addSpecialTokens: true) } public func decode(tokens: [Int], skipSpecialTokens: Bool = false) -> String { @@ -425,7 +424,7 @@ public class PreTrainedTokenizer: Tokenizer { } public var hasChatTemplate: Bool { - return tokenizerConfig.chatTemplate != nil + tokenizerConfig.chatTemplate != nil } public func applyChatTemplate(messages: [Message]) throws -> [Int] { @@ -472,28 +471,28 @@ public class PreTrainedTokenizer: Tokenizer { addGenerationPrompt: Bool = false, truncation: Bool = false, maxLength: Int? = nil, - /// A list of tools (callable functions) that will be accessible to the model. If the template does not - /// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, - /// giving the name, description and argument types for the tool. See the - /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) - /// for more information. + // A list of tools (callable functions) that will be accessible to the model. If the template does not + // support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + // giving the name, description and argument types for the tool. 
See the + // [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + // for more information. tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil ) throws -> [Int] { var selectedChatTemplate: String? - if let chatTemplate, case .literal(let template) = chatTemplate { + if let chatTemplate, case let .literal(template) = chatTemplate { // Use chat template from argument selectedChatTemplate = template } else if let valueFromConfig = tokenizerConfig.chatTemplate { if let arrayValue = valueFromConfig.arrayValue { // If the config specifies a list of chat templates, convert them to a dictionary - let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in + let templateDict = [String: String](uniqueKeysWithValues: arrayValue.compactMap { item in guard let name = item.name?.stringValue, let template = item.template?.stringValue else { return nil } return (name, template) }) - if let chatTemplate, case .name(let name) = chatTemplate { + if let chatTemplate, case let .name(name) = chatTemplate { // Select chat template from config by name if let matchingDictEntry = templateDict[name] { selectedChatTemplate = matchingDictEntry @@ -537,11 +536,11 @@ public class PreTrainedTokenizer: Tokenizer { } // TODO: maybe keep NSString here - for (key, value) in tokenizerConfig.dictionary as [String : Any] { + for (key, value) in tokenizerConfig.dictionary as [String: Any] { if specialTokenAttributes.contains(key), !(value is NSNull) { if let stringValue = value as? String { context[key] = stringValue - } else if let dictionary = value as? [NSString:Any] { + } else if let dictionary = value as? [NSString: Any] { context[key] = addedTokenAsString(Config(dictionary)) } else { context[key] = value @@ -565,18 +564,18 @@ public class PreTrainedTokenizer: Tokenizer { } // MARK: - Building -public struct AutoTokenizer {} +public struct AutoTokenizer { } struct PreTrainedTokenizerClasses { /// Class overrides for custom behaviour /// Not to be confused with the TokenizerModel classes defined in TokenizerModel - static let tokenizerClasses: [String : PreTrainedTokenizer.Type] = [ - "LlamaTokenizer": LlamaPreTrainedTokenizer.self + static let tokenizerClasses: [String: PreTrainedTokenizer.Type] = [ + "LlamaTokenizer": LlamaPreTrainedTokenizer.self, ] } -extension AutoTokenizer { - static func tokenizerClass(for tokenizerConfig: Config) -> PreTrainedTokenizer.Type { +public extension AutoTokenizer { + internal static func tokenizerClass(for tokenizerConfig: Config) -> PreTrainedTokenizer.Type { guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else { return PreTrainedTokenizer.self } @@ -590,12 +589,12 @@ extension AutoTokenizer { return PreTrainedTokenizer.self } - public static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer { + static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer { let tokenizerClass = tokenizerClass(for: tokenizerConfig) return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } - public static func from( + static func from( pretrained model: String, hubApi: HubApi = .shared ) async throws -> Tokenizer { @@ -606,7 +605,7 @@ extension AutoTokenizer { return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } - public static func from( + static func from( modelFolder: URL, hubApi: HubApi = .shared ) async throws -> Tokenizer { @@
-620,45 +619,43 @@ extension AutoTokenizer { // MARK: - Tokenizer model classes -class GPT2Tokenizer : BPETokenizer {} -class FalconTokenizer : BPETokenizer {} -class LlamaTokenizer : BPETokenizer {} -class CodeGenTokenizer : BPETokenizer {} -class WhisperTokenizer : BPETokenizer {} -class GemmaTokenizer : BPETokenizer {} -class CodeLlamaTokenizer: BPETokenizer {} -class CohereTokenizer : BPETokenizer {} -class Qwen2Tokenizer : BPETokenizer {} - -class T5Tokenizer : UnigramTokenizer {} +class GPT2Tokenizer: BPETokenizer { } +class FalconTokenizer: BPETokenizer { } +class LlamaTokenizer: BPETokenizer { } +class CodeGenTokenizer: BPETokenizer { } +class WhisperTokenizer: BPETokenizer { } +class GemmaTokenizer: BPETokenizer { } +class CodeLlamaTokenizer: BPETokenizer { } +class CohereTokenizer: BPETokenizer { } +class Qwen2Tokenizer: BPETokenizer { } +class T5Tokenizer: UnigramTokenizer { } // MARK: - PreTrainedTokenizer classes let sentencePieceUnderline = "▁" -// Hack for Llama tokenizers, see https://github.com/huggingface/transformers/blob/bcb841f0073fcd7a4fb88ea8064313c17dcab04a/src/transformers/models/llama/tokenization_llama_fast.py#L181 -// Return updated config, or nil +/// Hack for Llama tokenizers, see https://github.com/huggingface/transformers/blob/bcb841f0073fcd7a4fb88ea8064313c17dcab04a/src/transformers/models/llama/tokenization_llama_fast.py#L181 +/// Return updated config, or nil func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) throws -> Config? { - // If it's already a Template processor (instead of a ByteLevel one), assume it's correct let postProcessor = PostProcessorFactory.fromConfig(config: processorConfig) guard !(postProcessor is TemplateProcessing) else { return nil } let addBosToken = tokenizerConfig.addBosToken?.boolValue ?? false let bosToken = addedTokenAsString(tokenizerConfig.bosToken) - if addBosToken && bosToken == nil { + if addBosToken, bosToken == nil { throw TokenizerError.mismatchedConfig("add_bos_token is True but bos_token is nil") } let addEosToken = tokenizerConfig.addEosToken?.boolValue ?? false let eosToken = addedTokenAsString(tokenizerConfig.eosToken) - if addEosToken && eosToken == nil { + if addEosToken, eosToken == nil { throw TokenizerError.mismatchedConfig("add_eos_token is True but eos_token is nil") } // alt implementation - var single: [[String : Any]] = [] + var single: [[String: Any]] = [] if addBosToken { single = single + [["SpecialToken": ["id": bosToken!, "type_id": 0]]] } @@ -667,7 +664,7 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) single = single + [["SpecialToken": ["id": eosToken!, "type_id": 0]]] } - var pair: [[String : Any]] = single + var pair: [[String: Any]] = single if addBosToken { pair = pair + [["SpecialToken": ["id": bosToken!, "type_id": 1]]] } @@ -680,7 +677,7 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) 
return postProcessorConfig } -// See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions +/// See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions class LlamaPreTrainedTokenizer: PreTrainedTokenizer { let isLegacy: Bool @@ -700,4 +697,3 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer { try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData) } } - diff --git a/Sources/Tokenizers/Trie.swift b/Sources/Tokenizers/Trie.swift index 6c7f79c..87263cc 100644 --- a/Sources/Tokenizers/Trie.swift +++ b/Sources/Tokenizers/Trie.swift @@ -34,7 +34,9 @@ public extension Trie { } func append(contentsOf container: any Sequence<[T]>) { - for t in container { insert(t) } + for t in container { + insert(t) + } } /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) @@ -57,12 +59,12 @@ public extension Trie { /// Find all leaf nodes that share a common prefix with the input sequence (usually a text) /// Returns an iterator func commonPrefixSearchIterator(_ text: any Sequence<T>) -> LeavesWithCommonPrefixIterator<T> { - return LeavesWithCommonPrefixIterator(node: root, text: text) + LeavesWithCommonPrefixIterator(node: root, text: text) } } public extension Trie { - // Only used for testing, could migrate to collection + /// Only used for testing, could migrate to collection func get(_ element: any Sequence<T>) -> Node? { var node = root for item in element { @@ -79,7 +81,7 @@ public class TrieNode<T: Hashable> { var children: [T: TrieNode] = [:] } -public struct LeavesWithCommonPrefixIterator<T: Hashable> : Sequence, IteratorProtocol { +public struct LeavesWithCommonPrefixIterator<T: Hashable>: Sequence, IteratorProtocol { var node: TrieNode<T> var text: any Sequence<T> var seq: [T] = []
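Before the Unigram diff that follows, a toy illustration of what the Trie's common-prefix search is for: collecting every vocabulary entry that is a prefix of the remaining input, which is what feeds the token lattice. This miniature is illustrative only; the real implementation is the generic Trie<T> above.

final class MiniTrie {
    final class Node {
        var isLeaf = false
        var children: [Character: Node] = [:]
    }

    let root = Node()

    func insert(_ word: String) {
        var node = root
        for ch in word {
            let next = node.children[ch] ?? Node()
            node.children[ch] = next
            node = next
        }
        node.isLeaf = true
    }

    // Every inserted word that prefixes `text`, shortest first.
    func commonPrefixSearch(_ text: String) -> [String] {
        var node = root
        var prefix = ""
        var hits: [String] = []
        for ch in text {
            guard let next = node.children[ch] else { break }
            prefix.append(ch)
            node = next
            if node.isLeaf { hits.append(prefix) }
        }
        return hits
    }
}

let trie = MiniTrie()
["he", "hell", "hello"].forEach { trie.insert($0) }
print(trie.commonPrefixSearch("hello world")) // ["he", "hell", "hello"]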
diff --git a/Sources/Tokenizers/UnigramTokenizer.swift b/Sources/Tokenizers/UnigramTokenizer.swift index 58a7217..5f88eaf 100644 --- a/Sources/Tokenizers/UnigramTokenizer.swift +++ b/Sources/Tokenizers/UnigramTokenizer.swift @@ -14,6 +14,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel { var token: String var score: Float } + let vocab: [SentencePieceToken] let unknownPiece: SentencePieceToken @@ -30,19 +31,20 @@ class UnigramTokenizer: PreTrainedTokenizerModel { let eosToken: String? let eosTokenId: Int? - // Hardcoded in Unigram tokenizers + /// Hardcoded in Unigram tokenizers let fuseUnknownTokens: Bool = true private let trie: Trie<Character> - required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws { + required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws { guard let configVocab = tokenizerData.model?.vocab?.value as? [[Any]] else { throw TokenizerError.missingVocab } vocab = try configVocab.map { piece in guard let token = piece.first as? String, - let scoreValue = piece.last else { + let scoreValue = piece.last + else { throw TokenizerError.malformedVocab } @@ -64,10 +66,10 @@ class UnigramTokenizer: PreTrainedTokenizerModel { guard let unknownTokenId = tokenizerData.model?.unkId?.intValue else { throw TokenizerError.malformedVocab } self.unknownTokenId = unknownTokenId - self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10) + unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10) tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token as NSString }.enumerated().map { ($1, $0) }) - bosTokenId = tokensToIds[bosToken! as NSString] // May be nil + bosTokenId = tokensToIds[bosToken! as NSString] // May be nil eosToken = tokenizerConfig.eosToken?.stringValue eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString] @@ -77,11 +79,11 @@ class UnigramTokenizer: PreTrainedTokenizerModel { } func convertTokenToId(_ token: String) -> Int? { - return tokensToIds[token as NSString] ?? self.unknownTokenId + tokensToIds[token as NSString] ?? unknownTokenId } func convertIdToToken(_ id: Int) -> String? { - return vocab[id].token + vocab[id].token } func tokenize(text: String) -> [String] { @@ -99,7 +101,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel { guard let tokenId = tokensToIds[token as NSString] else { fatalError("Token not in vocab: \(token)") } let tokenScore = vocab[tokenId].score lattice.insert(startOffset: beginPos, length: token.count, score: tokenScore, tokenId: tokenId) - if !hasSingleNode && token.count == mblen { + if !hasSingleNode, token.count == mblen { hasSingleNode = true } }
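The lattice that UnigramTokenizer fills is then searched for the segmentation with the highest total score. A hedged, self-contained toy of that dynamic-programming idea (the vocabulary and log-probability scores below are invented for illustration; the real search lives in TokenLattice):

func bestSegmentation(_ text: String, scores: [String: Float]) -> [String] {
    let chars = Array(text)
    let n = chars.count
    guard n > 0 else { return [] }
    // best[i] holds the (score, pieces) of the best split of the first i characters.
    var best: [(score: Float, pieces: [String])?] = Array(repeating: nil, count: n + 1)
    best[0] = (0, [])
    for end in 1...n {
        for start in 0..<end {
            let piece = String(chars[start..<end])
            guard let pieceScore = scores[piece], let prefix = best[start] else { continue }
            let total = prefix.score + pieceScore
            if best[end] == nil || total > best[end]!.score {
                best[end] = (total, prefix.pieces + [piece])
            }
        }
    }
    return best[n]?.pieces ?? []
}

let scores: [String: Float] = ["un": -2, "believ": -4, "able": -2, "believable": -7]
// "un" + "believ" + "able" totals -8, beating "un" + "believable" at -9.
print(bestSegmentation("unbelievable", scores: scores)) // ["un", "believ", "able"]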
diff --git a/Sources/Tokenizers/Utils.swift b/Sources/Tokenizers/Utils.swift index 9efacc2..c1f19ff 100644 --- a/Sources/Tokenizers/Utils.swift +++ b/Sources/Tokenizers/Utils.swift @@ -30,17 +30,17 @@ struct Utils { static func dateNow() -> Int64 { // Use `Int` when we don't support 32-bits devices/OSes anymore. // Int crashes on iPhone 5c. - return Int64(Date().timeIntervalSince1970 * 1000) + Int64(Date().timeIntervalSince1970 * 1000) } /// Clamp a val to [min, max] static func clamp<T: Comparable>(_ val: T, _ vmin: T, _ vmax: T) -> T { - return min(max(vmin, val), vmax) + min(max(vmin, val), vmax) } /// Fake func that can throw. static func fakeThrowable<T>(_ input: T) throws -> T { - return input + input } /// Substring @@ -55,7 +55,7 @@ } /// Invert a (k, v) dictionary - static func invert<K, V>(_ dict: Dictionary<K, V>) -> Dictionary<V, K> { + static func invert<K, V>(_ dict: [K: V]) -> [V: K] { var inverted: [V: K] = [:] for (k, v) in dict { inverted[v] = k @@ -67,17 +67,16 @@ struct Utils { /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) static func isChineseChar(_ c: UnicodeScalar) -> Bool { (c.value >= 0x4E00 && c.value <= 0x9FFF) || - (c.value >= 0x3400 && c.value <= 0x4DBF) || - (c.value >= 0x20000 && c.value <= 0x2A6DF) || - (c.value >= 0x2A700 && c.value <= 0x2B73F) || - (c.value >= 0x2B740 && c.value <= 0x2B81F) || - (c.value >= 0x2B820 && c.value <= 0x2CEAF) || - (c.value >= 0xF900 && c.value <= 0xFAFF) || - (c.value >= 0x2F800 && c.value <= 0x2FA1F) + (c.value >= 0x3400 && c.value <= 0x4DBF) || + (c.value >= 0x20000 && c.value <= 0x2A6DF) || + (c.value >= 0x2A700 && c.value <= 0x2B73F) || + (c.value >= 0x2B740 && c.value <= 0x2B81F) || + (c.value >= 0x2B820 && c.value <= 0x2CEAF) || + (c.value >= 0xF900 && c.value <= 0xFAFF) || + (c.value >= 0x2F800 && c.value <= 0x2FA1F) } } enum Constants { static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"# } - diff --git a/Sources/TransformersCLI/main.swift b/Sources/TransformersCLI/main.swift index a040946..c62c75d 100644 --- a/Sources/TransformersCLI/main.swift +++ b/Sources/TransformersCLI/main.swift @@ -2,8 +2,8 @@ import ArgumentParser import CoreML import Foundation -import Models import Generation +import Models @available(iOS 16.2, macOS 13.1, *) struct TransformersCLI: ParsableCommand { @@ -88,10 +88,10 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable { case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine var asMLComputeUnits: MLComputeUnits { switch self { - case .all: return .all - case .cpuAndGPU: return .cpuAndGPU - case .cpuOnly: return .cpuOnly - case .cpuAndNeuralEngine: return .cpuAndNeuralEngine + case .all: .all + case .cpuAndGPU: .cpuAndGPU + case .cpuOnly: .cpuOnly + case .cpuAndNeuralEngine: .cpuAndNeuralEngine } } } @@ -104,6 +104,6 @@ if #available(iOS 16.2, macOS 13.1, *) { extension Double { func formatted(_ format: String) -> String { - return String(format: "\(format)", self) + String(format: "\(format)", self) } }
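The downloader tests that follow all use one bridge from Combine to async/await: subscribe to the download-state publisher and resume a checked continuation when a terminal state arrives. A sketch of that shape, with a stand-in state enum rather than the library's own type, and assuming the publisher emits exactly one terminal state (a continuation must not be resumed twice):

import Combine

enum FakeDownloadState {
    case notStarted
    case downloading(Double)
    case completed
    case failed(Error)
}

func awaitCompletion(of states: AnyPublisher<FakeDownloadState, Never>) async throws {
    var subscription: AnyCancellable?
    defer { subscription?.cancel() }
    try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
        subscription = states.sink { state in
            switch state {
            case .completed: continuation.resume()
            case let .failed(error): continuation.resume(throwing: error)
            case .notStarted, .downloading: break // non-terminal states are ignored
            }
        }
    }
}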
/// Errors that can occur during the download process
enum DownloadError: LocalizedError {
    case invalidDownloadLocation
    case unexpectedError

    var errorDescription: String? {
        switch self {
-        case .invalidDownloadLocation:
-            return String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
-        case .unexpectedError:
-            return String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
+        case .invalidDownloadLocation:
+            String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
+        case .unexpectedError:
+            String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
        }
    }
}
@@ -34,7 +34,9 @@ final class DownloaderTests: XCTestCase {
     }

     override func tearDown() {
-        try? FileManager.default.removeItem(at: tempDir)
+        if let tempDir, FileManager.default.fileExists(atPath: tempDir.path) {
+            try? FileManager.default.removeItem(at: tempDir)
+        }
         super.tearDown()
     }
@@ -54,7 +56,7 @@ final class DownloaderTests: XCTestCase {
         "pad_token_id": 0,
         "vocab_size": 32000
         }
-        
+
         """

         let downloader = Downloader(
@@ -70,7 +72,7 @@ final class DownloaderTests: XCTestCase {
                 switch state {
                 case .completed:
                     continuation.resume()
-                case .failed(let error):
+                case let .failed(error):
                     continuation.resume(throwing: error)
                 case .downloading:
                     break
@@ -97,15 +99,13 @@ final class DownloaderTests: XCTestCase {
         let downloader = Downloader(
             from: url,
             to: destination,
-            expectedSize: 999999  // Incorrect size
+            expectedSize: 999999 // Incorrect size
         )

         do {
             try downloader.waitUntilDone()
             XCTFail("Download should have failed due to size mismatch")
-        } catch {
-
-        }
+        } catch { }

         // Verify no file was created at destination
         XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
@@ -119,7 +119,7 @@ final class DownloaderTests: XCTestCase {
         // Create parent directory if it doesn't exist
         try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(),
-                                                withIntermediateDirectories: true)
+            withIntermediateDirectories: true)

         let downloader = Downloader(
             from: url,
@@ -138,15 +138,15 @@ final class DownloaderTests: XCTestCase {
         try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
             subscriber = downloader.downloadState.sink { state in
                 switch state {
-                case .downloading(let progress):
-                    if threshold != 1.0 && progress >= threshold {
+                case let .downloading(progress):
+                    if threshold != 1.0, progress >= threshold {
                         // Move to next threshold and interrupt
                         threshold = threshold == 0.5 ?
0.75 : 1.0 downloader.cancel() } case .completed: continuation.resume() - case .failed(let error): + case let .failed(error): continuation.resume(throwing: error) case .notStarted: break diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index f64b35c..451816f 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -174,7 +174,7 @@ class SnapshotDownloadTests: XCTestCase { return base.appending(component: "huggingface-tests") }() - override func setUp() {} + override func setUp() { } override func tearDown() { do { @@ -851,7 +851,7 @@ class SnapshotDownloadTests: XCTestCase { XCTFail("Expected an error to be thrown") } catch let error as HubApi.EnvironmentError { switch error { - case .offlineModeError(let message): + case let .offlineModeError(message): XCTAssertEqual(message, "Repository not available locally") default: XCTFail("Wrong error type: \(error)") @@ -888,7 +888,7 @@ class SnapshotDownloadTests: XCTestCase { XCTFail("Expected an error to be thrown") } catch let error as HubApi.EnvironmentError { switch error { - case .offlineModeError(let message): + case let .offlineModeError(message): XCTAssertEqual(message, "Metadata not available for x.bin") default: XCTFail("Wrong error type: \(error)") @@ -924,7 +924,7 @@ class SnapshotDownloadTests: XCTestCase { XCTFail("Expected an error to be thrown") } catch let error as HubApi.EnvironmentError { switch error { - case .fileIntegrityError(let message): + case let .fileIntegrityError(message): XCTAssertEqual(message, "Hash mismatch for x.bin") default: XCTFail("Wrong error type: \(error)") @@ -959,7 +959,7 @@ class SnapshotDownloadTests: XCTestCase { XCTFail("Expected an error to be thrown") } catch let error as HubApi.EnvironmentError { switch error { - case .offlineModeError(let message): + case let .offlineModeError(message): XCTAssertEqual(message, "No files available locally for this repository") default: XCTFail("Wrong error type: \(error)") diff --git a/Tests/HubTests/HubTests.swift b/Tests/HubTests/HubTests.swift index 1d7bc86..1f726b0 100644 --- a/Tests/HubTests/HubTests.swift +++ b/Tests/HubTests/HubTests.swift @@ -4,9 +4,8 @@ // Created by Pedro Cuenca on 18/05/2023. 
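All of the download tests above share one pattern worth calling out: a Combine publisher is bridged into async/await with `withCheckedThrowingContinuation`, and the continuation must be resumed exactly once. A condensed, self-contained sketch of that bridge (this `DownloadState` enum is a stand-in for the one in the Hub module):

import Combine

// Stand-in for Hub's download state.
enum DownloadState {
    case notStarted
    case downloading(Double)
    case completed
    case failed(Error)
}

// Bridge a state publisher into async/await: resume the continuation exactly once,
// on the terminal states only.
func waitForCompletion(of publisher: AnyPublisher<DownloadState, Never>) async throws {
    var subscriber: AnyCancellable?
    try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
        subscriber = publisher.sink { state in
            switch state {
            case .completed:
                continuation.resume()
            case let .failed(error):
                continuation.resume(throwing: error)
            case .notStarted, .downloading:
                break // still in flight; keep waiting
            }
        }
    }
    _ = subscriber // keep the subscription alive until the continuation resumes
}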
// -import XCTest @testable import Hub - +import XCTest class HubTests: XCTestCase { let downloadDestination: URL = { @@ -14,7 +13,7 @@ class HubTests: XCTestCase { return base.appending(component: "huggingface-tests") }() - override func setUp() {} + override func setUp() { } override func tearDown() { do { diff --git a/Tests/NormalizerTests/NormalizerTests.swift b/Tests/NormalizerTests/NormalizerTests.swift index fea423a..71dfacf 100644 --- a/Tests/NormalizerTests/NormalizerTests.swift +++ b/Tests/NormalizerTests/NormalizerTests.swift @@ -4,7 +4,6 @@ import XCTest @testable import Tokenizers class NormalizerTests: XCTestCase { - func testLowercaseNormalizer() { let testCases: [(String, String)] = [ ("Café", "café"), @@ -125,8 +124,8 @@ class NormalizerTests: XCTestCase { ("département", "departement"), ] - //TODO: test combinations with/without lowercase - let config = Config(["stripAccents":true]) + // TODO: test combinations with/without lowercase + let config = Config(["stripAccents": true]) let normalizer = BertNormalizer(config: config) for (arg, expect) in testCases { XCTAssertEqual(normalizer.normalize(text: arg), expect) @@ -147,7 +146,7 @@ class NormalizerTests: XCTestCase { ] for (arg, expect) in testCases { - let config = Config(["stripAccents":false]) + let config = Config(["stripAccents": false]) let normalizer = BertNormalizer(config: config) XCTAssertEqual(normalizer.normalize(text: arg), expect) } @@ -248,11 +247,11 @@ class NormalizerTests: XCTestCase { let normalizer = StripNormalizer(config: config) XCTAssertEqual( normalizer.normalize(text: input), expected, - "Failed for input: '\(input)', leftStrip: \(leftStrip), rightStrip: \(rightStrip)") + "Failed for input: '\(input)', leftStrip: \(leftStrip), rightStrip: \(rightStrip)" + ) } let config = Config(["type": NormalizerType.Strip.rawValue]) XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? StripNormalizer) } - } diff --git a/Tests/PostProcessorTests/PostProcessorTests.swift b/Tests/PostProcessorTests/PostProcessorTests.swift index 347bc38..0e46cb2 100644 --- a/Tests/PostProcessorTests/PostProcessorTests.swift +++ b/Tests/PostProcessorTests/PostProcessorTests.swift @@ -1,17 +1,16 @@ -import XCTest -@testable import Tokenizers @testable import Hub +@testable import Tokenizers +import XCTest class PostProcessorTests: XCTestCase { func testRobertaProcessing() { - let testCases: [(Config, [String], [String]?, [String])] = [ + let testCases: [(Config, [String], [String]?, [String])] = [ // Should keep spaces; uneven spaces; ignore `addPrefixSpace`. 
( Config(["cls": (0, "[HEAD]") as (UInt, String), "sep": (0, "[END]") as (UInt, String), "trimOffset": false, - "addPrefixSpace": true, - ]), + "addPrefixSpace": true]), [" The", " sun", "sets ", " in ", " the ", "west"], nil, ["[HEAD]", " The", " sun", "sets ", " in ", " the ", "west", "[END]"] @@ -21,8 +20,7 @@ class PostProcessorTests: XCTestCase { Config(["cls": (0, "[START]") as (UInt, String), "sep": (0, "[BREAK]") as (UInt, String), "trimOffset": true, - "addPrefixSpace": true, - ]), + "addPrefixSpace": true]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] @@ -32,8 +30,7 @@ class PostProcessorTests: XCTestCase { Config(["cls": (0, "[START]") as (UInt, String), "sep": (0, "[BREAK]") as (UInt, String), "trimOffset": true, - "addPrefixSpace": true, - ]), + "addPrefixSpace": true]), [" The ", " sun", "sets ", " in ", " the ", "west"], [], ["[START]", " The ", " sun", "sets ", " in ", " the ", "west", "[BREAK]"] @@ -43,8 +40,7 @@ class PostProcessorTests: XCTestCase { Config(["cls": (0, "[CLS]") as (UInt, String), "sep": (0, "[SEP]") as (UInt, String), "trimOffset": true, - "addPrefixSpace": false, - ]), + "addPrefixSpace": false]), [" The ", " sun", "sets ", " in ", " the ", "west"], nil, ["[CLS]", "The", "sun", "sets", "in", "the", "west", "[SEP]"] @@ -54,20 +50,18 @@ class PostProcessorTests: XCTestCase { Config(["cls": (0, "[CLS]") as (UInt, String), "sep": (0, "[SEP]") as (UInt, String), "trimOffset": true, - "addPrefixSpace": true, - ]), + "addPrefixSpace": true]), [" The ", " sun", "sets ", " in ", " the ", "west"], [".", "The", " cat ", " is ", " sitting ", " on", "the ", "mat"], ["[CLS]", " The ", " sun", "sets ", " in ", " the ", "west", "[SEP]", - "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", + "[SEP]", ".", "The", " cat ", " is ", " sitting ", " on", "the ", "mat", "[SEP]"] ), ( Config(["cls": (0, "[CLS]") as (UInt, String), "sep": (0, "[SEP]") as (UInt, String), "trimOffset": true, - "addPrefixSpace": true, - ]), + "addPrefixSpace": true]), [" 你 ", " 好 ", ","], [" 凯 ", " 蒂 ", "!"], ["[CLS]", " 你 ", " 好 ", ",", "[SEP]", "[SEP]", " 凯 ", " 蒂 ", "!", "[SEP]"] diff --git a/Tests/PreTokenizerTests/PreTokenizerTests.swift b/Tests/PreTokenizerTests/PreTokenizerTests.swift index 34e29c5..9715bfa 100644 --- a/Tests/PreTokenizerTests/PreTokenizerTests.swift +++ b/Tests/PreTokenizerTests/PreTokenizerTests.swift @@ -4,12 +4,11 @@ // Created by Jan Krukowski on 23/11/2023. // -import XCTest import Hub @testable import Tokenizers +import XCTest class PreTokenizerTests: XCTestCase { - func testWhitespacePreTokenizer() { let preTokenizer = WhitespacePreTokenizer(config: Config([:])) @@ -152,13 +151,13 @@ class PreTokenizerTests: XCTestCase { ) } - // https://github.com/huggingface/tokenizers/pull/1357 + /// https://github.com/huggingface/tokenizers/pull/1357 func testMetaspacePreTokenizer() { // Prepend "always" let preTokenizer = MetaspacePreTokenizer(config: Config([ "add_prefix_space": true, "replacement": "▁", - "prepend_scheme": "always" + "prepend_scheme": "always", ])) // TODO: different sections on diff --git a/Tests/TensorUtilsTests/LogitsWarperTests.swift b/Tests/TensorUtilsTests/LogitsWarperTests.swift index 0260967..1d5c5d2 100644 --- a/Tests/TensorUtilsTests/LogitsWarperTests.swift +++ b/Tests/TensorUtilsTests/LogitsWarperTests.swift @@ -4,9 +4,9 @@ // Created by Jan Krukowski on 09/12/2023. 
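Before the logits-warper tests below, it helps to see why p = 0.99 over the logits 0...9 keeps exactly five tokens: the softmax probabilities from the top are roughly 0.63, 0.23, 0.09, 0.03 and 0.01, and the cumulative sum first reaches 0.99 at the fifth token. A standalone nucleus-filtering sketch (not the TensorUtils implementation, just the same arithmetic):

import Foundation

// Nucleus (top-p) filtering: keep the smallest set of tokens whose
// cumulative softmax probability reaches p.
func topPIndices(logits: [Float], p: Float) -> [Int] {
    let maxLogit = logits.max() ?? 0
    let exps = logits.map { exp($0 - maxLogit) } // shift for numerical stability
    let total = exps.reduce(0, +)
    let probs = exps.map { $0 / total }
    let sorted = probs.indices.sorted { probs[$0] > probs[$1] }
    var cumulative: Float = 0
    var kept: [Int] = []
    for index in sorted {
        kept.append(index)
        cumulative += probs[index]
        if cumulative >= p { break }
    }
    return kept
}

// Matches the expectation in testTopPLogitsWarper below.
let logits = (0..<10).map { Float($0) }
assert(topPIndices(logits: logits, p: 0.99) == [9, 8, 7, 6, 5])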
// -import XCTest import CoreML @testable import TensorUtils +import XCTest final class LogitsWarperTests: XCTestCase { private let accuracy: Float = 0.00001 @@ -68,7 +68,7 @@ final class LogitsWarperTests: XCTestCase { XCTAssertTrue(result1.indices.isEmpty) XCTAssertTrue(result1.logits.isEmpty) - let logits = (0 ..< 10).map { Float($0) } + let logits = (0..<10).map { Float($0) } let indices = Array(logits.indices) let result2 = TopPLogitsWarper(p: 0.99)(indices, logits) XCTAssertEqual(result2.indices, [9, 8, 7, 6, 5]) @@ -89,7 +89,7 @@ final class LogitsWarperTests: XCTestCase { func testRepetitionPenaltyWarper() { let indices = Array(0..<10) - let logits = indices.map({ Float($0) }) + let logits = indices.map { Float($0) } let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) XCTAssertEqual(result1.indices, indices) @@ -97,7 +97,7 @@ final class LogitsWarperTests: XCTestCase { let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits) XCTAssertEqual(result2.indices, indices) - let logits2 = indices.map({ Float($0) / 3.75 }) + let logits2 = indices.map { Float($0) / 3.75 } XCTAssertEqual(result2.logits, logits2, accuracy: accuracy) let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) diff --git a/Tests/TensorUtilsTests/TensorUtilsTests.swift b/Tests/TensorUtilsTests/TensorUtilsTests.swift index 6355165..8fcbcef 100644 --- a/Tests/TensorUtilsTests/TensorUtilsTests.swift +++ b/Tests/TensorUtilsTests/TensorUtilsTests.swift @@ -4,9 +4,9 @@ // Created by Jan Krukowski on 25/11/2023. // -import XCTest import CoreML @testable import TensorUtils +import XCTest final class TensorUtilsTests: XCTestCase { private let accuracy: Float = 0.00001 @@ -30,11 +30,11 @@ final class TensorUtilsTests: XCTestCase { XCTAssertEqual(result3.0, 1) XCTAssertEqual(result3.1, 4.0) - let result4 = Math.argmax32(try MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Float])) + let result4 = try Math.argmax32(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Float])) XCTAssertEqual(result4.0, 1) XCTAssertEqual(result4.1, 4.0) - let result5 = Math.argmax(try MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Double])) + let result5 = try Math.argmax(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Double])) XCTAssertEqual(result5.0, 1) XCTAssertEqual(result5.1, 4.0) @@ -44,7 +44,7 @@ final class TensorUtilsTests: XCTestCase { } func testSoftmax() { - XCTAssertEqual(Math.softmax([]), []) + XCTAssertEqual(Math.softmax([]), []) let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0]) XCTAssertEqual(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy) diff --git a/Tests/TensorUtilsTests/WeightsTests.swift b/Tests/TensorUtilsTests/WeightsTests.swift index 5d2e478..8a4eccc 100644 --- a/Tests/TensorUtilsTests/WeightsTests.swift +++ b/Tests/TensorUtilsTests/WeightsTests.swift @@ -1,12 +1,9 @@ -@testable import TensorUtils @testable import Hub +@testable import TensorUtils import XCTest class WeightsTests: XCTestCase { - - let downloadDestination: URL = { - FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests") - }() + let downloadDestination: URL = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests") var hubApi: HubApi { HubApi(downloadBase: downloadDestination) } diff --git a/Tests/TokenizersTests/AddedTokensTests.swift b/Tests/TokenizersTests/AddedTokensTests.swift index c82e45f..1880510 100644 --- a/Tests/TokenizersTests/AddedTokensTests.swift +++ 
b/Tests/TokenizersTests/AddedTokensTests.swift @@ -5,9 +5,9 @@ // Created by Pedro Cuenca on 20240426. // -import XCTest -import Tokenizers import Hub +import Tokenizers +import XCTest class AddedTokensTests: XCTestCase { func testPhiAddedTokens() async throws { @@ -71,6 +71,5 @@ class AddedTokensTests: XCTestCase { "<|raw|><|end|>".split(by: captureRegex), ["<|raw|>", "<|end|>"] ) - } } diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift index 93a929c..bae38f6 100644 --- a/Tests/TokenizersTests/BertTokenizerTests.swift +++ b/Tests/TokenizersTests/BertTokenizerTests.swift @@ -6,13 +6,11 @@ // Copyright © 2019 Hugging Face. All rights reserved. // -import XCTest -@testable import Tokenizers @testable import Hub - +@testable import Tokenizers +import XCTest class BertTokenizerTests: XCTestCase { - override func setUp() { // Put setup code here. This method is called before the invocation of each test method in the class. } @@ -45,7 +43,7 @@ class BertTokenizerTests: XCTestCase { XCTAssertEqual( basicTokenizer.tokenize(text: text), tokens ) - /// Verify that `XCTAssertEqual` does what deep equality checks on arrays of strings. + // Verify that `XCTAssertEqual` does what deep equality checks on arrays of strings. XCTAssertEqual(["foo", "bar"], ["foo", "bar"]) } @@ -95,7 +93,7 @@ class BertTokenizerTests: XCTestCase { func testPureChineseTokenization() { let tokenizer = bertTokenizer let text = "明日,大家上山看日出。" - let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出","。"] + let expectedTokens = ["明", "日", ",", "大", "家", "上", "山", "[UNK]", "日", "出", "。"] let tokens = tokenizer.tokenize(text: text) XCTAssertEqual(tokens, expectedTokens) @@ -123,7 +121,7 @@ class BertTokenizerTests: XCTestCase { let tokenizer = bertTokenizer // This is an example of a performance test case. - self.measure { + measure { // Put the code you want to measure the time of here. _ = tokenizer.tokenizeToIds(text: "Brave gaillard, d'où [UNK] êtes vous?") } diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index 88e1843..e8cb880 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -5,8 +5,8 @@ // Created by Anthony DePasquale on 2/10/24. // -import XCTest import Tokenizers +import XCTest class ChatTemplateTests: XCTestCase { let messages = [[ @@ -87,7 +87,7 @@ class ChatTemplateTests: XCTestCase { [ "role": "user", "content": "What is the weather in Paris today?", - ] + ], ] let getCurrentWeatherToolSpec: [String: Any] = [ @@ -100,16 +100,16 @@ class ChatTemplateTests: XCTestCase { "properties": [ "location": [ "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "description": "The city and state, e.g. San Francisco, CA", ], "unit": [ "type": "string", - "enum": ["celsius", "fahrenheit"] - ] + "enum": ["celsius", "fahrenheit"], + ], ], - "required": ["location"] - ] - ] + "required": ["location"], + ], + ], ] let encoded = try tokenizer.applyChatTemplate(messages: weatherQueryMessages, tools: [getCurrentWeatherToolSpec]) @@ -130,9 +130,10 @@ class ChatTemplateTests: XCTestCase { } if let startRange = decoded.range(of: "\n"), - let endRange = decoded.range(of: "\n", range: startRange.upperBound..", range: startRange.upperBound..system -You are Qwen, created by Alibaba Cloud. You are a helpful assistant. + <|im_start|>system + You are Qwen, created by Alibaba Cloud. You are a helpful assistant. 
-# Tools
+            # Tools

-You may call one or more functions to assist with the user query.
+            You may call one or more functions to assist with the user query.

-You are provided with function signatures within <tools></tools> XML tags:
-<tools>
-"""
+            You are provided with function signatures within <tools></tools> XML tags:
+            <tools>
+            """

         let expectedPromptEnd = """
-</tools>
+            </tools>

-For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
-<tool_call>
-{"name": <function-name>, "arguments": <args-json-object>}
-</tool_call><|im_end|>
-<|im_start|>user
-What is the weather in Paris today?<|im_end|>
-<|im_start|>assistant
+            For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+            <tool_call>
+            {"name": <function-name>, "arguments": <args-json-object>}
+            </tool_call><|im_end|>
+            <|im_start|>user
+            What is the weather in Paris today?<|im_end|>
+            <|im_start|>assistant

-"""
+            """

         XCTAssertTrue(decoded.hasPrefix(expectedPromptStart), "Prompt should start with expected system message")
         XCTAssertTrue(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format")
@@ -178,7 +179,7 @@ What is the weather in Paris today?<|im_end|>
         XCTAssertTrue(tokenizer.hasChatTemplate)
     }

-    // Test for vision models with a vision chat template in chat_template.json
+    /// Test for vision models with a vision chat template in chat_template.json
     func testChatTemplateFromChatTemplateJson() async throws {
         let visionMessages = [
             [
                 "role": "user",
                 "content": [
                     [
                         "type": "text",
                         "text": "What's in this image?",
                     ] as [String: String],
                     [
                         "type": "image",
@@ -193,7 +194,7 @@
                         "image_url": "example.jpg",
                     ] as [String: String],
                 ] as [[String: String]],
-            ] as [String: Any]
+            ] as [String: Any],
         ] as [[String: Any]]
         // Qwen 2 VL does not have a chat_template.json file. The chat template is in tokenizer_config.json.
         let qwen2VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2-VL-7B-Instruct-4bit")
@@ -204,13 +205,13 @@
         let qwen2_5VLEncoded = try qwen2_5VLTokenizer.applyChatTemplate(messages: visionMessages)
         let qwen2_5VLDecoded = qwen2_5VLTokenizer.decode(tokens: qwen2_5VLEncoded)
         let expectedOutput = """
-<|im_start|>system
-You are a helpful assistant.<|im_end|>
-<|im_start|>user
-What's in this image?<|vision_start|><|image_pad|><|vision_end|><|im_end|>
-<|im_start|>assistant
+            <|im_start|>system
+            You are a helpful assistant.<|im_end|>
+            <|im_start|>user
+            What's in this image?<|vision_start|><|image_pad|><|vision_end|><|im_end|>
+            <|im_start|>assistant

-"""
+            """
         XCTAssertEqual(qwen2VLEncoded, qwen2_5VLEncoded, "Encoded sequences should be equal")
         XCTAssertEqual(qwen2VLDecoded, qwen2_5VLDecoded, "Decoded sequences should be equal")
         XCTAssertEqual(qwen2_5VLDecoded, expectedOutput, "Decoded sequence should match expected output")
@@ -223,7 +224,7 @@
         do {
             _ = try tokenizer.applyChatTemplate(messages: [])
             XCTFail()
-        } catch TokenizerError.chatTemplate(let message) {
+        } catch let TokenizerError.chatTemplate(message) {
             XCTAssertEqual(message, "This tokenizer does not have a chat template, and no template was passed.")
         } catch {
             XCTFail()
diff --git a/Tests/TokenizersTests/DecoderTests.swift b/Tests/TokenizersTests/DecoderTests.swift
index b2ecf33..36779c0 100644
--- a/Tests/TokenizersTests/DecoderTests.swift
+++ b/Tests/TokenizersTests/DecoderTests.swift
@@ -4,12 +4,12 @@
 //  Created by Pedro Cuenca on 20231123.
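The Metaspace decoder exercised below reverses SentencePiece's pre-tokenization: pieces carry "▁" (U+2581) in place of spaces, so decoding joins the pieces and maps "▁" back to a space, dropping the leading one when add_prefix_space is set. A minimal sketch of that inverse mapping (illustrative, not the Tokenizers implementation):

import Foundation

// Sketch: invert the metaspace convention used by SentencePiece-style tokenizers.
func decodeMetaspace(_ tokens: [String], addPrefixSpace: Bool = true) -> String {
    var text = tokens.joined().replacingOccurrences(of: "\u{2581}", with: " ")
    if addPrefixSpace, text.hasPrefix(" ") {
        text.removeFirst()
    }
    return text
}

assert(decodeMetaspace(["▁Hey", "▁my", "▁friend"]) == "Hey my friend")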
// -import XCTest import Hub @testable import Tokenizers +import XCTest class DecoderTests: XCTestCase { - // https://github.com/huggingface/tokenizers/pull/1357 + /// https://github.com/huggingface/tokenizers/pull/1357 func testMetaspaceDecoder() { let decoder = MetaspaceDecoder(config: Config([ "add_prefix_space": true, diff --git a/Tests/TokenizersTests/FactoryTests.swift b/Tests/TokenizersTests/FactoryTests.swift index 88c1c57..e713acb 100644 --- a/Tests/TokenizersTests/FactoryTests.swift +++ b/Tests/TokenizersTests/FactoryTests.swift @@ -1,13 +1,13 @@ // // FactoryTests.swift -// +// // // Created by Pedro Cuenca on 4/8/23. // -import XCTest -import Tokenizers import Hub +import Tokenizers +import XCTest class TestWithCustomHubDownloadLocation: XCTestCase { let downloadDestination: URL = { @@ -15,7 +15,7 @@ class TestWithCustomHubDownloadLocation: XCTestCase { return base.appending(component: "huggingface-tests") }() - override func setUp() {} + override func setUp() { } override func tearDown() { do { @@ -26,7 +26,7 @@ class TestWithCustomHubDownloadLocation: XCTestCase { } var hubApi: HubApi { - return HubApi(downloadBase: downloadDestination) + HubApi(downloadBase: downloadDestination) } } diff --git a/Tests/TokenizersTests/SplitTests.swift b/Tests/TokenizersTests/SplitTests.swift index c14b212..95db797 100644 --- a/Tests/TokenizersTests/SplitTests.swift +++ b/Tests/TokenizersTests/SplitTests.swift @@ -5,11 +5,11 @@ // Created by Pedro Cuenca on 20240120. // -import XCTest import Tokenizers +import XCTest class SplitTests: XCTestCase { - func testSplitBehaviorMergedWithPrevious() { + func testSplitBehaviorMergedWithPrevious() { XCTAssertEqual( "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious), ["the-", "final-", "-", "countdown"] diff --git a/Tests/TokenizersTests/SquadDataset.swift b/Tests/TokenizersTests/SquadDataset.swift index 98781b6..2be2116 100644 --- a/Tests/TokenizersTests/SquadDataset.swift +++ b/Tests/TokenizersTests/SquadDataset.swift @@ -8,7 +8,6 @@ import Foundation - /// Our internal type, also used in unit tests struct SquadExample { let qaId: String @@ -24,19 +23,23 @@ struct SquadDataset: Codable { let data: [SquadDatum] let version: String } + struct SquadDatum: Codable { let paragraphs: [SquadParagraph] let title: String } + struct SquadParagraph: Codable { let context: String let qas: [SquadQA] } + struct SquadQA: Codable { let answers: [SquadAnswer] let id: String let question: String } + struct SquadAnswer: Codable { let answer_start: Int let text: String @@ -54,7 +57,7 @@ struct Squad { for datum in squadDataset.data { for paragraph in datum.paragraphs { for qa in paragraph.qas { - let example = SquadExample(qaId: qa.id, context: paragraph.context, question: qa.question, answerText: qa.answers[0].text, startPos: qa.answers[0].answer_start, endPos: -1) // todo: remove -1 + let example = SquadExample(qaId: qa.id, context: paragraph.context, question: qa.question, answerText: qa.answers[0].text, startPos: qa.answers[0].answer_start, endPos: -1) // TODO: remove -1 examples.append(example) } } diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift index eae7003..1911feb 100644 --- a/Tests/TokenizersTests/TokenizerTests.swift +++ b/Tests/TokenizersTests/TokenizerTests.swift @@ -6,10 +6,10 @@ // Copyright © 2023 Hugging Face. All rights reserved. 
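One consequence of the Codable conformances above: loading a SQuAD file is a single `JSONDecoder` call, which is what `Squad` builds its examples from. A sketch, with a placeholder path:

import Foundation

// Sketch: decode a SQuAD-format JSON file into the Codable types above.
// The path is a placeholder for wherever the dataset lives on disk.
let squadURL = URL(fileURLWithPath: "/tmp/dev-v1.1.json")
do {
    let data = try Data(contentsOf: squadURL)
    let dataset = try JSONDecoder().decode(SquadDataset.self, from: data)
    print("SQuAD v\(dataset.version): \(dataset.data.count) articles")
} catch {
    print("Failed to load dataset: \(error)")
}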
// -import XCTest import Hub -@testable import Tokenizers @testable import Models +@testable import Tokenizers +import XCTest class GPT2TokenizerTests: TokenizerTests { override class var hubModelName: String? { "distilgpt2" } @@ -84,7 +84,7 @@ class GemmaTokenizerTests: TokenizerTests { } // These are two different characters - let cases = ["à" /* 0x61 0x300 */, "à" /* 0xe0 */] + let cases = ["à" /* 0x61 0x300 */, "à" /* 0xe0 */ ] let expected = [217138, 1305] // These are different characters @@ -212,7 +212,6 @@ class BertSpacesTests: XCTestCase { } } - struct EncodedTokenizerSamplesDataset: Decodable { let text: String // Bad naming, not just for bpe. @@ -222,8 +221,7 @@ struct EncodedTokenizerSamplesDataset: Decodable { let decoded_text: String } - -typealias EdgeCasesDataset = [String : [EdgeCase]] +typealias EdgeCasesDataset = [String: [EdgeCase]] struct EdgeCase: Decodable { let input: String @@ -238,14 +236,13 @@ struct EncodedData: Decodable { let attention_mask: [Int] } - class TokenizerTester { let encodedSamplesFilename: String let unknownTokenId: Int? - private var configuration: LanguageModelConfigurationFromHub? = nil - private var edgeCases: [EdgeCase]? = nil - private var _tokenizer: Tokenizer? = nil + private var configuration: LanguageModelConfigurationFromHub? + private var edgeCases: [EdgeCase]? + private var _tokenizer: Tokenizer? init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) { configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi) @@ -271,7 +268,6 @@ class TokenizerTester { return dataset }() - var tokenizer: Tokenizer? { get async { guard _tokenizer == nil else { return _tokenizer! } @@ -322,7 +318,7 @@ class TokenizerTester { /// Test encode and decode for a few edge cases func testEdgeCases() async { - guard let edgeCases = edgeCases else { + guard let edgeCases else { print("Edge cases test ignored") return } @@ -363,13 +359,13 @@ class TokenizerTester { } class TokenizerTests: XCTestCase { - // Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem + /// Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem static var _tester: TokenizerTester? = nil class var hubModelName: String? { nil } class var encodedSamplesFilename: String? { nil } - // Known id retrieved from Python, to verify it was parsed correctly + /// Known id retrieved from Python, to verify it was parsed correctly class var unknownTokenId: Int? { nil } static var downloadDestination: URL = { @@ -380,7 +376,7 @@ class TokenizerTests: XCTestCase { class var hubApi: HubApi { HubApi(downloadBase: downloadDestination) } override class func setUp() { - if let hubModelName = hubModelName, let encodedSamplesFilename = encodedSamplesFilename { + if let hubModelName, let encodedSamplesFilename { _tester = TokenizerTester( hubModelName: hubModelName, encodedSamplesFilename: encodedSamplesFilename, diff --git a/Tests/TokenizersTests/TrieTests.swift b/Tests/TokenizersTests/TrieTests.swift index 15c54f8..64f7528 100644 --- a/Tests/TokenizersTests/TrieTests.swift +++ b/Tests/TokenizersTests/TrieTests.swift @@ -5,8 +5,8 @@ // Created by Pedro Cuenca on 12/1/24. 
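Worth spelling out for the Gemma case above: the two "à" literals are canonically equivalent, so Swift's `==` treats them as equal, but tokenizers operate on the underlying scalars, where they differ, hence the two distinct token ids. A small demonstration using Foundation's canonical mapping:

import Foundation

let decomposed = "a\u{300}" // "à" as LATIN SMALL LETTER A + COMBINING GRAVE ACCENT
let precomposed = "\u{E0}"  // "à" as a single precomposed scalar

// Swift String comparison is aware of canonical equivalence...
assert(decomposed == precomposed)
// ...but the underlying scalars, and therefore tokenizer inputs, differ.
assert(decomposed.unicodeScalars.count == 2)
assert(precomposed.unicodeScalars.count == 1)
// NFC normalization collapses the combining sequence to the single scalar.
assert(decomposed.precomposedStringWithCanonicalMapping.unicodeScalars.count == 1)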
// -import XCTest @testable import Tokenizers +import XCTest class TrieTests: XCTestCase { func testTrieBuilding() { @@ -19,11 +19,11 @@ class TrieTests: XCTestCase { let c = trie.get("c") XCTAssertNotNil(c) - XCTAssertEqual(c!.children.count, 1) // "a" + XCTAssertEqual(c!.children.count, 1) // "a" let ca = trie.get("ca") XCTAssertNotNil(ca) - XCTAssertEqual(ca!.children.count, 2) // "r", "t" + XCTAssertEqual(ca!.children.count, 2) // "r", "t" let car = trie.get("car") XCTAssertNotNil(car)
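The child counts asserted in this last test follow directly from the trie's shape: after inserting, say, "cat" and "car", the node for "c" has the single child "a", and "ca" has the two children "r" and "t". The real `Trie` is internal to the Tokenizers module; a generic sketch with the same shape:

// Minimal character trie matching the structure the assertions imply.
final class TrieNode {
    var children: [Character: TrieNode] = [:]
    var isLeaf = false
}

final class CharTrie {
    private let root = TrieNode()

    func insert(_ word: String) {
        var node = root
        for character in word {
            if node.children[character] == nil {
                node.children[character] = TrieNode()
            }
            node = node.children[character]!
        }
        node.isLeaf = true
    }

    func get(_ prefix: String) -> TrieNode? {
        var node = root
        for character in prefix {
            guard let next = node.children[character] else { return nil }
            node = next
        }
        return node
    }
}

let trie = CharTrie()
trie.insert("cat")
trie.insert("car")
assert(trie.get("c")?.children.count == 1)  // "a"
assert(trie.get("ca")?.children.count == 2) // "r", "t"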