From 49a760db5a3c2ae7c97c2e2e14bd6f602cb05ebb Mon Sep 17 00:00:00 2001 From: Piotr Kowalczuk Date: Sat, 22 Mar 2025 20:36:27 +0100 Subject: [PATCH 1/2] Sendable Hub.Downloader, Hub.Hub and Hub.HubApi --- Package.swift | 11 +- Sources/Hub/Downloader.swift | 350 +++++++++++++++++--------- Sources/Hub/Hub.swift | 36 +-- Sources/Hub/HubApi.swift | 358 +++++++++++++++------------ Tests/HubTests/DownloaderTests.swift | 146 +++++------ Tests/HubTests/HubApiTests.swift | 124 ++++------ 6 files changed, 565 insertions(+), 460 deletions(-) diff --git a/Package.swift b/Package.swift index 95f6b77..981bd35 100644 --- a/Package.swift +++ b/Package.swift @@ -1,8 +1,13 @@ -// swift-tools-version: 5.8 +// swift-tools-version: 5.9 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription +// Define the strict concurrency settings to be applied to all targets. +let swiftSettings: [SwiftSetting] = [ + .enableExperimentalFeature("StrictConcurrency") +] + let package = Package( name: "swift-transformers", platforms: [.iOS(.v16), .macOS(.v13)], @@ -24,13 +29,13 @@ let package = Package( ] ), .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]), - .target(name: "Hub", resources: [.process("FallbackConfigs")]), + .target(name: "Hub", resources: [.process("FallbackConfigs")], swiftSettings: swiftSettings), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), .target(name: "TensorUtils"), .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]), .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), - .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), + .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 0e16e7a..8073795 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -9,10 +9,11 @@ import Combine import Foundation -class Downloader: NSObject, ObservableObject { - private(set) var destination: URL - - private let chunkSize = 10 * 1024 * 1024 // 10MB +final class Downloader: NSObject, Sendable, ObservableObject { + private let destination: URL + private let incompleteDestination: URL + private let downloadResumeState: DownloadResumeState = .init() + private let chunkSize: Int enum DownloadState { case notStarted @@ -27,37 +28,25 @@ class Downloader: NSObject, ObservableObject { case tempFileNotFound } - private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted) - private var stateSubscriber: Cancellable? - - private(set) var tempFilePath: URL - private(set) var expectedSize: Int? - private(set) var downloadedSize: Int = 0 + private let broadcaster: Broadcaster = Broadcaster { + return DownloadState.notStarted + } - var session: URLSession? = nil - var downloadTask: Task? = nil + private let sessionConfig: URLSessionConfiguration + let session: SessionActor = SessionActor() + private let task: TaskActor = TaskActor() init( - from url: URL, to destination: URL, incompleteDestination: URL, - using authToken: String? = nil, inBackground: Bool = false, - headers: [String: String]? = nil, - expectedSize: Int? = nil, - timeout: TimeInterval = 10, - numRetries: Int = 5 + chunkSize: Int = 10 * 1024 * 1024 // 10MB ) { self.destination = destination - self.expectedSize = expectedSize - // Create incomplete file path based on destination - tempFilePath = incompleteDestination - - // If resume size wasn't specified, check for an existing incomplete file - let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + self.incompleteDestination = incompleteDestination + self.chunkSize = chunkSize - super.init() let sessionIdentifier = "swift-transformers.hub.downloader" var config = URLSessionConfiguration.default @@ -66,23 +55,33 @@ class Downloader: NSObject, ObservableObject { config.isDiscretionary = false config.sessionSendsLaunchEvents = true } - - session = URLSession(configuration: config, delegate: self, delegateQueue: nil) - - setUpDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries) + self.sessionConfig = config } - /// Check if an incomplete file exists for the destination and returns its size - /// - Parameter destination: The destination URL for the download - /// - Returns: Size of the incomplete file if it exists, otherwise 0 - static func incompleteFileSize(at incompletePath: URL) -> Int { - if FileManager.default.fileExists(atPath: incompletePath.path) { - if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int { - return fileSize - } + func download( + from url: URL, + using authToken: String? = nil, + headers: [String: String]? = nil, + expectedSize: Int? = nil, + timeout: TimeInterval = 10, + numRetries: Int = 5 + ) async -> AsyncStream { + if let task = await self.task.get() { + task.cancel() } - - return 0 + await self.downloadResumeState.setExpectedSize(expectedSize) + let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination) + await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil)) + await self.setUpDownload( + from: url, + with: authToken, + resumeSize: resumeSize, + headers: headers, + timeout: timeout, + numRetries: numRetries + ) + + return await self.broadcaster.subscribe() } /// Sets up and initiates a file download operation @@ -100,77 +99,92 @@ class Downloader: NSObject, ObservableObject { with authToken: String?, resumeSize: Int, headers: [String: String]?, - expectedSize: Int?, timeout: TimeInterval, numRetries: Int - ) { - session?.getAllTasks { tasks in - // If there's an existing pending background task with the same URL, let it proceed. - if let existing = tasks.filter({ $0.originalRequest?.url == url }).first { - switch existing.state { - case .running: - return - case .suspended: - existing.resume() - return - case .canceling, .completed: - existing.cancel() - @unknown default: - existing.cancel() - } + ) async { + let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination) + guard let tasks = await self.session.get()?.allTasks else { + return + } + + // If there's an existing pending background task with the same URL, let it proceed. + if let existing = tasks.filter({ $0.originalRequest?.url == url }).first { + switch existing.state { + case .running: + return + case .suspended: + existing.resume() + return + case .canceling, .completed: + existing.cancel() + break + @unknown default: + existing.cancel() } + } - self.downloadTask = Task { + await self.task.set( + Task { do { - // Set up the request with appropriate headers var request = URLRequest(url: url) + + // Use headers from argument else create an empty header dictionary var requestHeaders = headers ?? [:] - if let authToken { + // Populate header auth and range fields + if let authToken = authToken { requestHeaders["Authorization"] = "Bearer \(authToken)" } - self.downloadedSize = resumeSize + await self.downloadResumeState.setDownloadedSize(resumeSize) + + if resumeSize > 0 { + requestHeaders["Range"] = "bytes=\(resumeSize)-" + } // Set Range header if we're resuming if resumeSize > 0 { requestHeaders["Range"] = "bytes=\(resumeSize)-" // Calculate and show initial progress - if let expectedSize, expectedSize > 0 { + if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize > 0 { let initialProgress = Double(resumeSize) / Double(expectedSize) - self.downloadState.value = .downloading(initialProgress) + await self.broadcaster.broadcast(state: .downloading(initialProgress)) } else { - self.downloadState.value = .downloading(0) + await self.broadcaster.broadcast(state: .downloading(0)) } } else { - self.downloadState.value = .downloading(0) + await self.broadcaster.broadcast(state: .downloading(0)) } request.timeoutInterval = timeout request.allHTTPHeaderFields = requestHeaders // Open the incomplete file for writing - let tempFile = try FileHandle(forWritingTo: self.tempFilePath) + let tempFile = try FileHandle(forWritingTo: self.incompleteDestination) // If resuming, seek to end of file if resumeSize > 0 { try tempFile.seekToEnd() } - try await self.httpGet(request: request, tempFile: tempFile, resumeSize: self.downloadedSize, numRetries: numRetries, expectedSize: expectedSize) + defer { tempFile.closeFile() } - // Clean up and move the completed download to its final destination - tempFile.closeFile() + try await self.httpGet(request: request, tempFile: tempFile, numRetries: numRetries) try Task.checkCancellation() - try FileManager.default.moveDownloadedFile(from: self.tempFilePath, to: self.destination) - self.downloadState.value = .completed(self.destination) + try FileManager.default.moveDownloadedFile(from: self.incompleteDestination, to: self.destination) + + // // Clean up and move the completed download to its final destination + // tempFile.closeFile() + // try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination) + + await self.broadcaster.broadcast(state: .completed(self.destination)) } catch { - self.downloadState.value = .failed(error) + await self.broadcaster.broadcast(state: .failed(error)) } } - } + ) } /// Downloads a file from given URL using chunked transfer and handles retries. @@ -187,27 +201,26 @@ class Downloader: NSObject, ObservableObject { private func httpGet( request: URLRequest, tempFile: FileHandle, - resumeSize: Int, - numRetries: Int, - expectedSize: Int? + numRetries: Int ) async throws { - guard let session else { + guard let session = await self.session.get() else { throw DownloadError.unexpectedError } // Create a new request with Range header for resuming var newRequest = request - if resumeSize > 0 { - newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range") + if await self.downloadResumeState.downloadedSize > 0 { + newRequest.setValue("bytes=\(await self.downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range") } // Start the download and get the byte stream let (asyncBytes, response) = try await session.bytes(for: newRequest) - guard let httpResponse = response as? HTTPURLResponse else { + guard let response = response as? HTTPURLResponse else { throw DownloadError.unexpectedError } - guard (200..<300).contains(httpResponse.statusCode) else { + + guard (200..<300).contains(response.statusCode) else { throw DownloadError.unexpectedError } @@ -220,21 +233,22 @@ class Downloader: NSObject, ObservableObject { buffer.append(byte) // When buffer is full, write to disk if buffer.count == chunkSize { - if !buffer.isEmpty { // Filter out keep-alive chunks + if !buffer.isEmpty { // Filter out keep-alive chunks try tempFile.write(contentsOf: buffer) buffer.removeAll(keepingCapacity: true) - downloadedSize += chunkSize + + await self.downloadResumeState.incDownloadedSize(chunkSize) newNumRetries = 5 - guard let expectedSize else { continue } - let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0 - downloadState.value = .downloading(progress) + guard let expectedSize = await self.downloadResumeState.expectedSize else { continue } + let progress = expectedSize != 0 ? Double(await self.downloadResumeState.downloadedSize) / Double(expectedSize) : 0 + await self.broadcaster.broadcast(state: .downloading(progress)) } } } if !buffer.isEmpty { try tempFile.write(contentsOf: buffer) - downloadedSize += buffer.count + await self.downloadResumeState.incDownloadedSize(buffer.count) buffer.removeAll(keepingCapacity: true) newNumRetries = 5 } @@ -244,74 +258,73 @@ class Downloader: NSObject, ObservableObject { } try await Task.sleep(nanoseconds: 1_000_000_000) - let config = URLSessionConfiguration.default - self.session = URLSession(configuration: config, delegate: self, delegateQueue: nil) + await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil)) try await httpGet( request: request, tempFile: tempFile, - resumeSize: self.downloadedSize, - numRetries: newNumRetries - 1, - expectedSize: expectedSize + numRetries: newNumRetries - 1 ) + return } // Verify the downloaded file size matches the expected size let actualSize = try tempFile.seekToEnd() - if let expectedSize, expectedSize != actualSize { + if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize != actualSize { throw DownloadError.unexpectedError } } - @discardableResult - func waitUntilDone() throws -> URL { - // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky) - let semaphore = DispatchSemaphore(value: 0) - stateSubscriber = downloadState.sink { state in - switch state { - case .completed: semaphore.signal() - case .failed: semaphore.signal() - default: break - } - } - semaphore.wait() + func cancel() async { + await self.session.get()?.invalidateAndCancel() + await self.task.get()?.cancel() + await self.broadcaster.broadcast(state: .failed(URLError(.cancelled))) + } - switch downloadState.value { - case let .completed(url): return url - case let .failed(error): throw error - default: throw DownloadError.unexpectedError + /// Check if an incomplete file exists for the destination and returns its size + /// - Parameter destination: The destination URL for the download + /// - Returns: Size of the incomplete file if it exists, otherwise 0 + static func incompleteFileSize(at incompletePath: URL) -> Int { + if FileManager.default.fileExists(atPath: incompletePath.path) { + if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int { + return fileSize + } } - } - func cancel() { - session?.invalidateAndCancel() - downloadTask?.cancel() - downloadState.value = .failed(URLError(.cancelled)) + return 0 } } extension Downloader: URLSessionDownloadDelegate { func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) { - downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)) + Task { + await self.broadcaster.broadcast(state: .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))) + } } func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { do { // If the downloaded file already exists on the filesystem, overwrite it - try FileManager.default.moveDownloadedFile(from: location, to: destination) - downloadState.value = .completed(destination) + try FileManager.default.moveDownloadedFile(from: location, to: self.destination) + Task { + await self.broadcaster.broadcast(state: .completed(destination)) + } } catch { - downloadState.value = .failed(error) + Task { + await self.broadcaster.broadcast(state: .failed(error)) + } } } func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - if let error { - downloadState.value = .failed(error) -// } else if let response = task.response as? HTTPURLResponse { -// print("HTTP response status code: \(response.statusCode)") -// let headers = response.allHeaderFields -// print("HTTP response headers: \(headers)") + if let error = error { + Task { + await self.broadcaster.broadcast(state: .failed(error)) + } + // } else if let response = task.response as? HTTPURLResponse { + // print("HTTP response status code: \(response.statusCode)") + // let headers = response.allHeaderFields + // print("HTTP response headers: \(headers)") } } } @@ -328,3 +341,96 @@ extension FileManager { try moveItem(at: srcURL, to: dstURL) } } + +private actor DownloadResumeState { + var expectedSize: Int? + var downloadedSize: Int = 0 + + func setExpectedSize(_ size: Int?) { + self.expectedSize = size + } + + func setDownloadedSize(_ size: Int) { + self.downloadedSize = size + } + + func incDownloadedSize(_ size: Int) { + self.downloadedSize += size + } +} + +actor Broadcaster { + private let initialState: @Sendable () async -> E? + private var latestState: E? + private var continuations: [UUID: AsyncStream.Continuation] = [:] + + init(initialState: @Sendable @escaping () async -> E?) { + self.initialState = initialState + } + + deinit { + self.continuations.removeAll() + } + + func subscribe() -> AsyncStream { + return AsyncStream { continuation in + let id = UUID() + self.continuations[id] = continuation + + continuation.onTermination = { @Sendable status in + Task { + await self.unsubscribe(id) + } + } + + Task { + if let state = self.latestState { + continuation.yield(state) + return + } + if let state = await self.initialState() { + continuation.yield(state) + } + } + } + } + + private func unsubscribe(_ id: UUID) { + self.continuations.removeValue(forKey: id) + } + + func broadcast(state: E) async { + self.latestState = state + await withTaskGroup(of: Void.self) { group in + for continuation in continuations.values { + group.addTask { + continuation.yield(state) + } + } + } + } +} + +actor SessionActor { + private var urlSession: URLSession? = nil + + func set(_ urlSession: URLSession?) { + self.urlSession = urlSession + } + + func get() -> URLSession? { + return self.urlSession + } +} + +actor TaskActor { + private var task: Task? = nil + + func set(_ task: Task?) { + self.task = task + } + + func get() -> Task? { + return self.task + } +} diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 74ce0bf..85fd8fc 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -7,10 +7,10 @@ import Foundation -public struct Hub { } +public struct Hub: Sendable {} -public extension Hub { - enum HubClientError: LocalizedError { +extension Hub { + public enum HubClientError: LocalizedError { case authorizationRequired case httpStatusCode(Int) case parse @@ -51,13 +51,13 @@ public extension Hub { } } - enum RepoType: String, Codable { + public enum RepoType: String, Codable { case models case datasets case spaces } - struct Repo: Codable { + public struct Repo: Codable { public let id: String public let type: RepoType @@ -68,22 +68,22 @@ public extension Hub { } } -public class LanguageModelConfigurationFromHub { +public final class LanguageModelConfigurationFromHub: Sendable { struct Configurations { var modelConfig: Config var tokenizerConfig: Config? var tokenizerData: Config } - private var configPromise: Task? + private let configPromise: Task public init( modelName: String, revision: String = "main", hubApi: HubApi = .shared ) { - configPromise = Task.init { - try await self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) + self.configPromise = Task.init { + return try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) } } @@ -91,22 +91,22 @@ public class LanguageModelConfigurationFromHub { modelFolder: URL, hubApi: HubApi = .shared ) { - configPromise = Task { - try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) + self.configPromise = Task { + return try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) } } public var modelConfig: Config { get async throws { - try await configPromise!.value.modelConfig + try await configPromise.value.modelConfig } } public var tokenizerConfig: Config? { get async throws { - if let hubConfig = try await configPromise!.value.tokenizerConfig { + if let hubConfig = try await configPromise.value.tokenizerConfig { // Try to guess the class if it's not present and the modelType is - if let _: String = hubConfig.tokenizerClass?.string() { return hubConfig } + if hubConfig.tokenizerClass?.string() != nil { return hubConfig } guard let modelType = try await modelType else { return hubConfig } // If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it @@ -129,7 +129,7 @@ public class LanguageModelConfigurationFromHub { public var tokenizerData: Config { get async throws { - try await configPromise!.value.tokenizerData + try await configPromise.value.tokenizerData } } @@ -139,7 +139,7 @@ public class LanguageModelConfigurationFromHub { } } - func loadConfig( + static func loadConfig( modelName: String, revision: String, hubApi: HubApi = .shared @@ -167,7 +167,7 @@ public class LanguageModelConfigurationFromHub { } } - func loadConfig( + static func loadConfig( modelFolder: URL, hubApi: HubApi = .shared ) async throws -> Configurations { @@ -204,7 +204,7 @@ public class LanguageModelConfigurationFromHub { // Try to load .jinja template as plain text chatTemplate = try? String(contentsOf: chatTemplateJinjaURL, encoding: .utf8) } else if FileManager.default.fileExists(atPath: chatTemplateJsonURL.path), - let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL) + let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL) { // Fall back to .json template chatTemplate = chatTemplateConfig.chatTemplate.string() diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index adfbf4a..39afd5d 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -10,18 +10,24 @@ import Foundation import Network import os -public struct HubApi { +public struct HubApi: Sendable { var downloadBase: URL var hfToken: String? var endpoint: String var useBackgroundSession: Bool - var useOfflineMode: Bool? + var useOfflineMode: Bool? = nil private let networkMonitor = NetworkMonitor() public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) { + public init( + downloadBase: URL? = nil, + hfToken: String? = nil, + endpoint: String = "https://huggingface.co", + useBackgroundSession: Bool = false, + useOfflineMode: Bool? = nil + ) { self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -43,8 +49,8 @@ public struct HubApi { private static let logger = Logger() } -private extension HubApi { - static func hfTokenFromEnv() -> String? { +extension HubApi { + fileprivate static func hfTokenFromEnv() -> String? { let possibleTokens = [ { ProcessInfo.processInfo.environment["HF_TOKEN"] }, { ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"] }, @@ -76,18 +82,18 @@ private extension HubApi { } /// File retrieval -public extension HubApi { +extension HubApi { /// Model data for parsed filenames - struct Sibling: Codable { + public struct Sibling: Codable { let rfilename: String } - struct SiblingsResponse: Codable { + public struct SiblingsResponse: Codable { let siblings: [Sibling] } /// Throws error if the response code is not 20X - func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { + public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) if let hfToken { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") @@ -118,7 +124,7 @@ public extension HubApi { /// Throws error if page does not exist or is not accessible. /// Allows relative redirects but ignores absolute ones for LFS files. - func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { + public func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) request.httpMethod = "HEAD" if let hfToken { @@ -133,16 +139,15 @@ public extension HubApi { guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } switch response.statusCode { - case 200..<400: break // Allow redirects to pass through to the redirect delegate - case 401, 403: throw Hub.HubClientError.authorizationRequired - case 404: throw Hub.HubClientError.fileNotFound(url.lastPathComponent) + case 200..<400: break // Allow redirects to pass through to the redirect delegate + case 400..<500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } return (data, response) } - func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] { + public func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] { // Read repo info and only parse "siblings" let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)/revision/\(revision)")! let (data, _) = try await httpGet(for: url) @@ -157,22 +162,22 @@ public extension HubApi { return Array(selected) } - func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { - try await getFilenames(from: Repo(id: repoId), matching: globs) + public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + return try await getFilenames(from: Repo(id: repoId), matching: globs) } - func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { - try await getFilenames(from: repo, matching: [glob]) + public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + return try await getFilenames(from: repo, matching: [glob]) } - func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { - try await getFilenames(from: Repo(id: repoId), matching: [glob]) + public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + return try await getFilenames(from: Repo(id: repoId), matching: [glob]) } } /// Additional Errors -public extension HubApi { - enum EnvironmentError: LocalizedError { +extension HubApi { + public enum EnvironmentError: LocalizedError { case invalidMetadataError(String) case offlineModeError(String) case fileIntegrityError(String) @@ -180,31 +185,31 @@ public extension HubApi { public var errorDescription: String? { switch self { - case let .invalidMetadataError(message): - String(localized: "Invalid metadata: \(message)") - case let .offlineModeError(message): - String(localized: "Offline mode error: \(message)") - case let .fileIntegrityError(message): - String(localized: "File integrity check failed: \(message)") - case let .fileWriteError(message): - String(localized: "Failed to write file: \(message)") + case .invalidMetadataError(let message): + return String(localized: "Invalid metadata: \(message)") + case .offlineModeError(let message): + return String(localized: "Offline mode error: \(message)") + case .fileIntegrityError(let message): + return String(localized: "File integrity check failed: \(message)") + case .fileWriteError(let message): + return String(localized: "Failed to write file: \(message)") } } } } /// Configuration loading helpers -public extension HubApi { +extension HubApi { /// Assumes the file has already been downloaded. /// `filename` is relative to the download base. - func configuration(from filename: String, in repo: Repo) throws -> Config { + public func configuration(from filename: String, in repo: Repo) throws -> Config { let fileURL = localRepoLocation(repo).appending(path: filename) return try configuration(fileURL: fileURL) } /// Assumes the file is already present at local url. /// `fileURL` is a complete local file path for the given model - func configuration(fileURL: URL) throws -> Config { + public func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse } @@ -213,8 +218,8 @@ public extension HubApi { } /// Whoami -public extension HubApi { - func whoami() async throws -> Config { +extension HubApi { + public func whoami() async throws -> Config { guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired } let url = URL(string: "\(endpoint)/api/whoami-v2")! @@ -227,8 +232,8 @@ public extension HubApi { } /// Snaphsot download -public extension HubApi { - func localRepoLocation(_ repo: Repo) -> URL { +extension HubApi { + public func localRepoLocation(_ repo: Repo) -> URL { downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id) } @@ -241,7 +246,7 @@ public extension HubApi { /// - filePath: The path of the file for which metadata is being read. /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. - func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? { + public func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? { if FileManager.default.fileExists(atPath: metadataPath.path) { do { let contents = try String(contentsOf: metadataPath, encoding: .utf8) @@ -285,13 +290,13 @@ public extension HubApi { return nil } - func isValidHash(hash: String, pattern: String) -> Bool { + public func isValidHash(hash: String, pattern: String) -> Bool { let regex = try? NSRegularExpression(pattern: pattern) let range = NSRange(location: 0, length: hash.utf16.count) return regex?.firstMatch(in: hash, options: [], range: range) != nil } - func computeFileHash(file url: URL) throws -> String { + public func computeFileHash(file url: URL) throws -> String { // Open file for reading guard let fileHandle = try? FileHandle(forReadingFrom: url) else { throw Hub.HubClientError.fileNotFound(url.lastPathComponent) @@ -302,13 +307,13 @@ public extension HubApi { } var hasher = SHA256() - let chunkSize = 1024 * 1024 // 1MB chunks + let chunkSize = 1024 * 1024 // 1MB chunks while autoreleasepool(invoking: { let nextChunk = try? fileHandle.read(upToCount: chunkSize) guard let nextChunk, - !nextChunk.isEmpty + !nextChunk.isEmpty else { return false } @@ -316,14 +321,14 @@ public extension HubApi { hasher.update(data: nextChunk) return true - }) { } + }) {} let digest = hasher.finalize() return digest.map { String(format: "%02x", $0) }.joined() } /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 - func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws { + public func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws { let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" do { try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) @@ -333,8 +338,7 @@ public extension HubApi { } } - struct HubFileDownloader { - let hub: HubApi + public struct HubFileDownloader { let repo: Repo let revision: String let repoDestination: URL @@ -382,22 +386,24 @@ public extension HubApi { /// (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { - let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination) - let remoteMetadata = try await hub.getFileMetadata(url: source) + let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination) + let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) let localCommitHash = localMetadata?.commitHash ?? "" let remoteCommitHash = remoteMetadata.commitHash ?? "" // Local file exists + metadata exists + commit_hash matches => return file - if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil, localCommitHash == remoteCommitHash { + if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern), downloaded, localMetadata != nil, + localCommitHash == remoteCommitHash + { return destination } // From now on, etag, commit_hash, url and size are not empty guard let remoteCommitHash = remoteMetadata.commitHash, - let remoteEtag = remoteMetadata.etag, - let remoteSize = remoteMetadata.size, - remoteMetadata.location != "" + let remoteEtag = remoteMetadata.etag, + let remoteSize = remoteMetadata.size, + remoteMetadata.location != "" else { throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") } @@ -406,7 +412,7 @@ public extension HubApi { if downloaded { // etag matches => update metadata and return file if localMetadata?.etag == remoteEtag { - try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } @@ -414,10 +420,10 @@ public extension HubApi { // => means it's an LFS file (large) // => let's compute local hash and compare // => if match, update metadata and return file - if hub.isValidHash(hash: remoteEtag, pattern: hub.sha256Pattern) { - let fileHash = try hub.computeFileHash(file: destination) + if HubApi.shared.isValidHash(hash: remoteEtag, pattern: HubApi.shared.sha256Pattern) { + let fileHash = try HubApi.shared.computeFileHash(file: destination) if fileHash == remoteEtag { - try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } } @@ -427,51 +433,46 @@ public extension HubApi { let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete") try prepareCacheDestination(incompleteDestination) - let downloader = Downloader( - from: source, - to: destination, - incompleteDestination: incompleteDestination, - using: hfToken, - inBackground: backgroundSession, - expectedSize: remoteSize - ) + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination, inBackground: backgroundSession) - return try await withTaskCancellationHandler { - let downloadSubscriber = downloader.downloadState.sink { state in + try await withTaskCancellationHandler { + let sub = await downloader.download(from: source, using: hfToken, expectedSize: remoteSize) + listen: for await state in sub { switch state { - case let .downloading(progress): + case .notStarted: + continue + case .downloading(let progress): progressHandler(progress) - case .completed, .failed, .notStarted: - break - } - } - do { - _ = try withExtendedLifetime(downloadSubscriber) { - try downloader.waitUntilDone() + case .failed(let error): + throw error + case .completed: + break listen } - - try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) - - return destination - } catch { - // If download fails, leave the incomplete file in place for future resume - throw error } } onCancel: { - downloader.cancel() + Task { + await downloader.cancel() + } } + + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + + return destination } } @discardableResult - func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + public func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) + async throws -> URL + { let repoDestination = localRepoLocation(repo) - let repoMetadataDestination = repoDestination + let repoMetadataDestination = + repoDestination .appendingPathComponent(".cache") .appendingPathComponent("huggingface") .appendingPathComponent("download") - if useOfflineMode ?? NetworkMonitor.shared.shouldUseOfflineMode() { + if await NetworkMonitor.shared.state.shouldUseOfflineMode() || useOfflineMode == true { if !FileManager.default.fileExists(atPath: repoDestination.path) { throw EnvironmentError.offlineModeError(String(localized: "Repository not available locally")) } @@ -482,10 +483,12 @@ public extension HubApi { } for fileUrl in fileUrls { - let metadataPath = URL(fileURLWithPath: fileUrl.path.replacingOccurrences( - of: repoDestination.path, - with: repoMetadataDestination.path - ) + ".metadata") + let metadataPath = URL( + fileURLWithPath: fileUrl.path.replacingOccurrences( + of: repoDestination.path, + with: repoMetadataDestination.path + ) + ".metadata" + ) let localMetadata = try readDownloadMetadata(metadataPath: metadataPath) @@ -511,7 +514,6 @@ public extension HubApi { for filename in filenames { let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) let downloader = HubFileDownloader( - hub: self, repo: repo, revision: revision, repoDestination: repoDestination, @@ -521,36 +523,42 @@ public extension HubApi { endpoint: endpoint, backgroundSession: useBackgroundSession ) + try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) progressHandler(progress) } + if Task.isCancelled { + return repoDestination + } + fileProgress.completedUnitCount = 100 } + progressHandler(progress) return repoDestination } @discardableResult - func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler) + public func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } @discardableResult - func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler) + public func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler) } @discardableResult - func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler) + public func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler) } } /// Metadata -public extension HubApi { +extension HubApi { /// Data structure containing information about a file versioned on the Hub - struct FileMetadata { + public struct FileMetadata { /// The commit hash related to the file public let commitHash: String? @@ -565,7 +573,7 @@ public extension HubApi { } /// Metadata about a file in the local directory related to a download process - struct LocalDownloadFileMetadata { + public struct LocalDownloadFileMetadata { /// Commit hash of the file in the repo public let commitHash: String @@ -599,7 +607,7 @@ public extension HubApi { ) } - func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { + public func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { let files = try await getFilenames(from: repo, matching: globs) let url = URL(string: "\(endpoint)/\(repo.id)/resolve/\(revision)")! var selectedMetadata: [FileMetadata] = [] @@ -610,28 +618,45 @@ public extension HubApi { return selectedMetadata } - func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { + public func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: globs) } - func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { + public func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { try await getFileMetadata(from: repo, revision: revision, matching: [glob]) } - func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { + public func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: [glob]) } } /// Network monitor helper class to help decide whether to use offline mode -private extension HubApi { - private final class NetworkMonitor { - private var monitor: NWPathMonitor - private var queue: DispatchQueue +extension HubApi { + private actor NetworkStateActor { + public var isConnected: Bool = false + public var isExpensive: Bool = false + public var isConstrained: Bool = false + + func update(path: NWPath) { + self.isConnected = path.status == .satisfied + self.isExpensive = path.isExpensive + self.isConstrained = path.isConstrained + } - private(set) var isConnected: Bool = false - private(set) var isExpensive: Bool = false - private(set) var isConstrained: Bool = false + func shouldUseOfflineMode() -> Bool { + if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" { + return false + } + return !isConnected || isExpensive || isConstrained + } + } + + private final class NetworkMonitor: Sendable { + private let monitor: NWPathMonitor + private let queue: DispatchQueue + + public let state: NetworkStateActor = .init() static let shared = NetworkMonitor() @@ -643,27 +668,19 @@ private extension HubApi { func startMonitoring() { monitor.pathUpdateHandler = { [weak self] path in - guard let self else { return } - - isConnected = path.status == .satisfied - isExpensive = path.isExpensive - isConstrained = path.isConstrained + guard let self = self else { return } + Task { + await self.state.update(path: path) + } } - monitor.start(queue: queue) + monitor.start(queue: self.queue) } func stopMonitoring() { monitor.cancel() } - func shouldUseOfflineMode() -> Bool { - if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" { - return false - } - return !isConnected || isExpensive || isConstrained - } - deinit { stopMonitoring() } @@ -671,80 +688,84 @@ private extension HubApi { } /// Stateless wrappers that use `HubApi` instances -public extension Hub { - static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { - try await HubApi.shared.getFilenames(from: repo, matching: globs) +extension Hub { + public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { + return try await HubApi.shared.getFilenames(from: repo, matching: globs) } - static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { - try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) + public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) } - static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { - try await HubApi.shared.getFilenames(from: repo, matching: glob) + public static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + return try await HubApi.shared.getFilenames(from: repo, matching: glob) } - static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { - try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) + public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) } - static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) + public static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) } - static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) + public static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws + -> URL + { + return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } - static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) + public static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) } - static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) + public static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) } - static func whoami(token: String) async throws -> Config { - try await HubApi(hfToken: token).whoami() + public static func whoami(token: String) async throws -> Config { + return try await HubApi(hfToken: token).whoami() } - static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { - try await HubApi.shared.getFileMetadata(url: fileURL) + public static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { + return try await HubApi.shared.getFileMetadata(url: fileURL) } - static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { - try await HubApi.shared.getFileMetadata(from: repo, matching: globs) + public static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + return try await HubApi.shared.getFileMetadata(from: repo, matching: globs) } - static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { - try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs) + public static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs) } - static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { - try await HubApi.shared.getFileMetadata(from: repo, matching: [glob]) + public static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { + return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob]) } - static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] { - try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob]) + public static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] { + return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob]) } } -public extension [String] { - func matching(glob: String) -> [String] { +extension [String] { + public func matching(glob: String) -> [String] { filter { fnmatch(glob, $0, 0) == 0 } } } -public extension FileManager { - func getFileUrls(at directoryUrl: URL) throws -> [URL] { +extension FileManager { + public func getFileUrls(at directoryUrl: URL) throws -> [URL] { var fileUrls = [URL]() // Get all contents including subdirectories - guard let enumerator = FileManager.default.enumerator( - at: directoryUrl, - includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey], - options: [.skipsHiddenFiles] - ) else { + guard + let enumerator = FileManager.default.enumerator( + at: directoryUrl, + includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey], + options: [.skipsHiddenFiles] + ) + else { return fileUrls } @@ -765,19 +786,26 @@ public extension FileManager { /// Only allow relative redirects and reject others /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 -private class RedirectDelegate: NSObject, URLSessionTaskDelegate { - func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { +private final class RedirectDelegate: NSObject, URLSessionTaskDelegate, Sendable { + func urlSession( + _ session: URLSession, + task: URLSessionTask, + willPerformHTTPRedirection response: HTTPURLResponse, + newRequest request: URLRequest, + completionHandler: @escaping (URLRequest?) -> Void + ) { // Check if it's a redirect status code (300-399) if (300...399).contains(response.statusCode) { // Get the Location header if let locationString = response.value(forHTTPHeaderField: "Location"), - let locationUrl = URL(string: locationString) + let locationUrl = URL(string: locationString) { + // Check if it's a relative redirect (no host component) if locationUrl.host == nil { // For relative redirects, construct the new URL using the original request's base if let originalUrl = task.originalRequest?.url, - var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) + var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { // Update the path component with the relative path components.path = locationUrl.path diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift index d62d2e8..09e779c 100644 --- a/Tests/HubTests/DownloaderTests.swift +++ b/Tests/HubTests/DownloaderTests.swift @@ -6,6 +6,8 @@ // import Combine +import XCTest + @testable import Hub import XCTest @@ -17,19 +19,20 @@ enum DownloadError: LocalizedError { var errorDescription: String? { switch self { case .invalidDownloadLocation: - String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid") + return String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid") case .unexpectedError: - String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message") + return String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message") } } } private extension Downloader { - func interruptDownload() { - session?.invalidateAndCancel() + func interruptDownload() async { + await self.session.get()?.invalidateAndCancel() } } + final class DownloaderTests: XCTestCase { var tempDir: URL! @@ -53,18 +56,18 @@ final class DownloaderTests: XCTestCase { let etag = try await Hub.getFileMetadata(fileURL: url).etag! let destination = tempDir.appendingPathComponent("config.json") let fileContent = """ - { - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } + { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } - """ + """ let cacheDir = tempDir.appendingPathComponent("cache") try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) @@ -72,33 +75,22 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination - ) - - // Store subscriber outside the continuation to maintain its lifecycle - var subscriber: AnyCancellable? - - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - subscriber = downloader.downloadState.sink { state in - switch state { - case .completed: - continuation.resume() - case let .failed(error): - continuation.resume(throwing: error) - case .downloading: - break - case .notStarted: - break - } + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + let sub = await downloader.download(from: url) + + listen: for await state in sub { + switch state { + case .notStarted: + continue + case .downloading(let progress): + continue + case .failed(let error): + throw error + case .completed: + break listen } } - // Cancel subscription after continuation completes - subscriber?.cancel() - // Verify download completed successfully XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path)) XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent) @@ -116,18 +108,22 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - // Create downloader with incorrect expected size - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination, - expectedSize: 999999 // Incorrect size - ) - - do { - try downloader.waitUntilDone() - XCTFail("Download should have failed due to size mismatch") - } catch { } + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + // Download with incorrect expected size + let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size + listen: for await state in sub { + switch state { + case .notStarted: + continue + case .downloading(let progress): + continue + case .failed: + break listen + case .completed: + XCTFail("Download should have failed due to size mismatch") + break listen + } + } // Verify no file was created at destination XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path)) @@ -141,8 +137,10 @@ final class DownloaderTests: XCTestCase { let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip") // Create parent directory if it doesn't exist - try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(), - withIntermediateDirectories: true) + try FileManager.default.createDirectory( + at: destination.deletingLastPathComponent(), + withIntermediateDirectories: true + ) let cacheDir = tempDir.appendingPathComponent("cache") try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) @@ -150,42 +148,32 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination, - expectedSize: 73194001 // Correct size for verification - ) + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification // First interruption point at 50% var threshold = 0.5 - var subscriber: AnyCancellable? - do { // Monitor download progress and interrupt at thresholds to test if // download continues from where it left off - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - subscriber = downloader.downloadState.sink { state in - switch state { - case let .downloading(progress): - if threshold != 1.0, progress >= threshold { - // Move to next threshold and interrupt - threshold = threshold == 0.5 ? 0.75 : 1.0 - downloader.interruptDownload() - } - case .completed: - continuation.resume() - case let .failed(error): - continuation.resume(throwing: error) - case .notStarted: - break + listen: for await state in sub { + switch state { + case .notStarted: + continue + case .downloading(let progress): + if threshold != 1.0 && progress >= threshold { + // Move to next threshold and interrupt + threshold = threshold == 0.5 ? 0.75 : 1.0 + await downloader.interruptDownload() } + case .failed(let error): + throw error + case .completed: + break listen } } - - subscriber?.cancel() - + // Verify the file exists and is complete if FileManager.default.fileExists(atPath: destination.path) { let attributes = try FileManager.default.attributesOfItem(atPath: destination.path) diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index b03716f..5a53452 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -4,9 +4,10 @@ // Created by Pedro Cuenca on 20231230. // -@testable import Hub import XCTest +@testable import Hub + class HubApiTests: XCTestCase { override func setUp() { // Put setup code here. This method is called before the invocation of each test method in the class. @@ -150,10 +151,14 @@ class HubApiTests: XCTestCase { do { let revision = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" let etag = "fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" - let location = "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" + let location = + "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" let size = 504766 - let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + let url = URL( + string: + "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel" + ) let metadata = try await Hub.getFileMetadata(fileURL: url!) XCTAssertEqual(metadata.commitHash, revision) @@ -188,7 +193,12 @@ class SnapshotDownloadTests: XCTestCase { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") - if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) { + if let enumerator = FileManager.default.enumerator( + at: url, + includingPropertiesForKeys: [.isRegularFileKey], + options: [.skipsHiddenFiles], + errorHandler: nil + ) { for case let fileURL as URL in enumerator { do { let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey]) @@ -256,7 +266,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedFilenames), Set([ - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" ]) ) } @@ -405,7 +415,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedMetadataFilenames), Set([ - ".cache/huggingface/download/tokenizer.json.metadata", + ".cache/huggingface/download/tokenizer.json.metadata" ]) ) @@ -534,7 +544,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedMetadataFilenames), Set([ - ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata" ]) ) @@ -915,7 +925,11 @@ class SnapshotDownloadTests: XCTestCase { let metadataDestination = downloadedTo.appendingPathComponent(".cache/huggingface/download").appendingPathComponent("x.bin.metadata") - try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(to: metadataDestination, atomically: true, encoding: .utf8) + try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write( + to: metadataDestination, + atomically: true, + encoding: .utf8 + ) hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) @@ -972,7 +986,9 @@ class SnapshotDownloadTests: XCTestCase { func testResumeDownloadFromEmptyIncomplete() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml") + var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent( + "Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml" + ) let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") @@ -993,17 +1009,17 @@ class SnapshotDownloadTests: XCTestCase { let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) let expected = """ - { - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } - """ + { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ XCTAssertTrue(fileContents.contains(expected)) } @@ -1032,17 +1048,17 @@ class SnapshotDownloadTests: XCTestCase { let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) let expected = """ - X - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } - """ + X + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ XCTAssertTrue(fileContents.contains(expected)) } @@ -1070,6 +1086,7 @@ class SnapshotDownloadTests: XCTestCase { // Cancel the download once we've seen progress downloadTask.cancel() + try await Task.sleep(nanoseconds: 5_000_000_000) // Resume download with a new task @@ -1078,48 +1095,9 @@ class SnapshotDownloadTests: XCTestCase { } let filePath = downloadedTo.appendingPathComponent(targetFile) - XCTAssertTrue(FileManager.default.fileExists(atPath: filePath.path), - "Downloaded file should exist at \(filePath.path)") - } - - func testDownloadWithRevision() async throws { - let hubApi = HubApi(downloadBase: downloadDestination) - var lastProgress: Progress? = nil - - let commitHash = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" - let downloadedTo = try await hubApi.snapshot(from: repo, revision: commitHash, matching: "*.json") { progress in - print("Total Progress: \(progress.fractionCompleted)") - print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress - } - - let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) - XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 6) - XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - XCTAssertEqual( - Set(downloadedFilenames), - Set([ - "config.json", "tokenizer.json", "tokenizer_config.json", - "llama-2-7b-chat.mlpackage/Manifest.json", - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", - ]) + XCTAssertTrue( + FileManager.default.fileExists(atPath: filePath.path), + "Downloaded file should exist at \(filePath.path)" ) - - do { - let revision = "nonexistent-revision" - try await hubApi.snapshot(from: repo, revision: revision, matching: "*.json") - XCTFail("Expected an error to be thrown") - } catch let error as Hub.HubClientError { - switch error { - case .fileNotFound: - break // Error type is correct - default: - XCTFail("Wrong error type: \(error)") - } - } catch { - XCTFail("Unexpected error: \(error)") - } } } From 8de7f498fddfcdd8218fdad7a58ff2d7d1562709 Mon Sep 17 00:00:00 2001 From: Piotr Kowalczuk Date: Mon, 9 Jun 2025 11:55:45 +0200 Subject: [PATCH 2/2] Sendable Hub.Downloader, Hub.Hub and Hub.HubApi .2 --- Package.swift | 4 +- Sources/Hub/Downloader.swift | 80 +++++----- Sources/Hub/Hub.swift | 20 +-- Sources/Hub/HubApi.swift | 228 ++++++++++++++------------- Tests/HubTests/DownloaderTests.swift | 47 +++--- Tests/HubTests/HubApiTests.swift | 93 ++++++++--- 6 files changed, 257 insertions(+), 215 deletions(-) diff --git a/Package.swift b/Package.swift index 981bd35..43c5696 100644 --- a/Package.swift +++ b/Package.swift @@ -3,9 +3,9 @@ import PackageDescription -// Define the strict concurrency settings to be applied to all targets. +/// Define the strict concurrency settings to be applied to all targets. let swiftSettings: [SwiftSetting] = [ - .enableExperimentalFeature("StrictConcurrency") + .enableExperimentalFeature("StrictConcurrency"), ] let package = Package( diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 8073795..cc0171f 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -29,18 +29,18 @@ final class Downloader: NSObject, Sendable, ObservableObject { } private let broadcaster: Broadcaster = Broadcaster { - return DownloadState.notStarted + DownloadState.notStarted } private let sessionConfig: URLSessionConfiguration - let session: SessionActor = SessionActor() - private let task: TaskActor = TaskActor() + let session: SessionActor = .init() + private let task: TaskActor = .init() init( to destination: URL, incompleteDestination: URL, inBackground: Bool = false, - chunkSize: Int = 10 * 1024 * 1024 // 10MB + chunkSize: Int = 10 * 1024 * 1024 // 10MB ) { self.destination = destination // Create incomplete file path based on destination @@ -55,7 +55,7 @@ final class Downloader: NSObject, Sendable, ObservableObject { config.isDiscretionary = false config.sessionSendsLaunchEvents = true } - self.sessionConfig = config + sessionConfig = config } func download( @@ -66,13 +66,13 @@ final class Downloader: NSObject, Sendable, ObservableObject { timeout: TimeInterval = 10, numRetries: Int = 5 ) async -> AsyncStream { - if let task = await self.task.get() { + if let task = await task.get() { task.cancel() } - await self.downloadResumeState.setExpectedSize(expectedSize) - let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination) - await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil)) - await self.setUpDownload( + await downloadResumeState.setExpectedSize(expectedSize) + let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + await session.set(URLSession(configuration: sessionConfig, delegate: self, delegateQueue: nil)) + await setUpDownload( from: url, with: authToken, resumeSize: resumeSize, @@ -81,7 +81,7 @@ final class Downloader: NSObject, Sendable, ObservableObject { numRetries: numRetries ) - return await self.broadcaster.subscribe() + return await broadcaster.subscribe() } /// Sets up and initiates a file download operation @@ -102,8 +102,8 @@ final class Downloader: NSObject, Sendable, ObservableObject { timeout: TimeInterval, numRetries: Int ) async { - let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination) - guard let tasks = await self.session.get()?.allTasks else { + let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + guard let tasks = await session.get()?.allTasks else { return } @@ -123,7 +123,7 @@ final class Downloader: NSObject, Sendable, ObservableObject { } } - await self.task.set( + await task.set( Task { do { var request = URLRequest(url: url) @@ -132,7 +132,7 @@ final class Downloader: NSObject, Sendable, ObservableObject { var requestHeaders = headers ?? [:] // Populate header auth and range fields - if let authToken = authToken { + if let authToken { requestHeaders["Authorization"] = "Bearer \(authToken)" } @@ -203,14 +203,14 @@ final class Downloader: NSObject, Sendable, ObservableObject { tempFile: FileHandle, numRetries: Int ) async throws { - guard let session = await self.session.get() else { + guard let session = await session.get() else { throw DownloadError.unexpectedError } // Create a new request with Range header for resuming var newRequest = request - if await self.downloadResumeState.downloadedSize > 0 { - newRequest.setValue("bytes=\(await self.downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range") + if await downloadResumeState.downloadedSize > 0 { + await newRequest.setValue("bytes=\(downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range") } // Start the download and get the byte stream @@ -233,22 +233,22 @@ final class Downloader: NSObject, Sendable, ObservableObject { buffer.append(byte) // When buffer is full, write to disk if buffer.count == chunkSize { - if !buffer.isEmpty { // Filter out keep-alive chunks + if !buffer.isEmpty { // Filter out keep-alive chunks try tempFile.write(contentsOf: buffer) buffer.removeAll(keepingCapacity: true) - await self.downloadResumeState.incDownloadedSize(chunkSize) + await downloadResumeState.incDownloadedSize(chunkSize) newNumRetries = 5 - guard let expectedSize = await self.downloadResumeState.expectedSize else { continue } - let progress = expectedSize != 0 ? Double(await self.downloadResumeState.downloadedSize) / Double(expectedSize) : 0 - await self.broadcaster.broadcast(state: .downloading(progress)) + guard let expectedSize = await downloadResumeState.expectedSize else { continue } + let progress = await expectedSize != 0 ? Double(downloadResumeState.downloadedSize) / Double(expectedSize) : 0 + await broadcaster.broadcast(state: .downloading(progress)) } } } if !buffer.isEmpty { try tempFile.write(contentsOf: buffer) - await self.downloadResumeState.incDownloadedSize(buffer.count) + await downloadResumeState.incDownloadedSize(buffer.count) buffer.removeAll(keepingCapacity: true) newNumRetries = 5 } @@ -270,15 +270,15 @@ final class Downloader: NSObject, Sendable, ObservableObject { // Verify the downloaded file size matches the expected size let actualSize = try tempFile.seekToEnd() - if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize != actualSize { + if let expectedSize = await downloadResumeState.expectedSize, expectedSize != actualSize { throw DownloadError.unexpectedError } } func cancel() async { - await self.session.get()?.invalidateAndCancel() - await self.task.get()?.cancel() - await self.broadcaster.broadcast(state: .failed(URLError(.cancelled))) + await session.get()?.invalidateAndCancel() + await task.get()?.cancel() + await broadcaster.broadcast(state: .failed(URLError(.cancelled))) } /// Check if an incomplete file exists for the destination and returns its size @@ -305,7 +305,7 @@ extension Downloader: URLSessionDownloadDelegate { func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { do { // If the downloaded file already exists on the filesystem, overwrite it - try FileManager.default.moveDownloadedFile(from: location, to: self.destination) + try FileManager.default.moveDownloadedFile(from: location, to: destination) Task { await self.broadcaster.broadcast(state: .completed(destination)) } @@ -317,7 +317,7 @@ extension Downloader: URLSessionDownloadDelegate { } func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - if let error = error { + if let error { Task { await self.broadcaster.broadcast(state: .failed(error)) } @@ -347,15 +347,15 @@ private actor DownloadResumeState { var downloadedSize: Int = 0 func setExpectedSize(_ size: Int?) { - self.expectedSize = size + expectedSize = size } func setDownloadedSize(_ size: Int) { - self.downloadedSize = size + downloadedSize = size } func incDownloadedSize(_ size: Int) { - self.downloadedSize += size + downloadedSize += size } } @@ -373,7 +373,7 @@ actor Broadcaster { } func subscribe() -> AsyncStream { - return AsyncStream { continuation in + AsyncStream { continuation in let id = UUID() self.continuations[id] = continuation @@ -396,11 +396,11 @@ actor Broadcaster { } private func unsubscribe(_ id: UUID) { - self.continuations.removeValue(forKey: id) + continuations.removeValue(forKey: id) } func broadcast(state: E) async { - self.latestState = state + latestState = state await withTaskGroup(of: Void.self) { group in for continuation in continuations.values { group.addTask { @@ -412,25 +412,25 @@ actor Broadcaster { } actor SessionActor { - private var urlSession: URLSession? = nil + private var urlSession: URLSession? func set(_ urlSession: URLSession?) { self.urlSession = urlSession } func get() -> URLSession? { - return self.urlSession + urlSession } } actor TaskActor { - private var task: Task? = nil + private var task: Task? func set(_ task: Task?) { self.task = task } func get() -> Task? { - return self.task + task } } diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 85fd8fc..00b7b2c 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -7,10 +7,10 @@ import Foundation -public struct Hub: Sendable {} +public struct Hub: Sendable { } -extension Hub { - public enum HubClientError: LocalizedError { +public extension Hub { + enum HubClientError: LocalizedError { case authorizationRequired case httpStatusCode(Int) case parse @@ -51,13 +51,13 @@ extension Hub { } } - public enum RepoType: String, Codable { + enum RepoType: String, Codable { case models case datasets case spaces } - public struct Repo: Codable { + struct Repo: Codable { public let id: String public let type: RepoType @@ -82,8 +82,8 @@ public final class LanguageModelConfigurationFromHub: Sendable { revision: String = "main", hubApi: HubApi = .shared ) { - self.configPromise = Task.init { - return try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) + configPromise = Task.init { + try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) } } @@ -91,8 +91,8 @@ public final class LanguageModelConfigurationFromHub: Sendable { modelFolder: URL, hubApi: HubApi = .shared ) { - self.configPromise = Task { - return try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) + configPromise = Task { + try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) } } @@ -204,7 +204,7 @@ public final class LanguageModelConfigurationFromHub: Sendable { // Try to load .jinja template as plain text chatTemplate = try? String(contentsOf: chatTemplateJinjaURL, encoding: .utf8) } else if FileManager.default.fileExists(atPath: chatTemplateJsonURL.path), - let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL) + let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL) { // Fall back to .json template chatTemplate = chatTemplateConfig.chatTemplate.string() diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 39afd5d..7c8f61e 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -49,8 +49,8 @@ public struct HubApi: Sendable { private static let logger = Logger() } -extension HubApi { - fileprivate static func hfTokenFromEnv() -> String? { +private extension HubApi { + static func hfTokenFromEnv() -> String? { let possibleTokens = [ { ProcessInfo.processInfo.environment["HF_TOKEN"] }, { ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"] }, @@ -82,18 +82,18 @@ extension HubApi { } /// File retrieval -extension HubApi { +public extension HubApi { /// Model data for parsed filenames - public struct Sibling: Codable { + struct Sibling: Codable { let rfilename: String } - public struct SiblingsResponse: Codable { + struct SiblingsResponse: Codable { let siblings: [Sibling] } /// Throws error if the response code is not 20X - public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { + func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) if let hfToken { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") @@ -124,7 +124,7 @@ extension HubApi { /// Throws error if page does not exist or is not accessible. /// Allows relative redirects but ignores absolute ones for LFS files. - public func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { + func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) request.httpMethod = "HEAD" if let hfToken { @@ -139,15 +139,16 @@ extension HubApi { guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } switch response.statusCode { - case 200..<400: break // Allow redirects to pass through to the redirect delegate - case 400..<500: throw Hub.HubClientError.authorizationRequired + case 200..<400: break // Allow redirects to pass through to the redirect delegate + case 401, 403: throw Hub.HubClientError.authorizationRequired + case 404: throw Hub.HubClientError.fileNotFound(url.lastPathComponent) default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } return (data, response) } - public func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] { + func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] { // Read repo info and only parse "siblings" let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)/revision/\(revision)")! let (data, _) = try await httpGet(for: url) @@ -162,22 +163,22 @@ extension HubApi { return Array(selected) } - public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { - return try await getFilenames(from: Repo(id: repoId), matching: globs) + func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + try await getFilenames(from: Repo(id: repoId), matching: globs) } - public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { - return try await getFilenames(from: repo, matching: [glob]) + func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + try await getFilenames(from: repo, matching: [glob]) } - public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { - return try await getFilenames(from: Repo(id: repoId), matching: [glob]) + func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + try await getFilenames(from: Repo(id: repoId), matching: [glob]) } } /// Additional Errors -extension HubApi { - public enum EnvironmentError: LocalizedError { +public extension HubApi { + enum EnvironmentError: LocalizedError { case invalidMetadataError(String) case offlineModeError(String) case fileIntegrityError(String) @@ -185,31 +186,31 @@ extension HubApi { public var errorDescription: String? { switch self { - case .invalidMetadataError(let message): - return String(localized: "Invalid metadata: \(message)") - case .offlineModeError(let message): - return String(localized: "Offline mode error: \(message)") - case .fileIntegrityError(let message): - return String(localized: "File integrity check failed: \(message)") - case .fileWriteError(let message): - return String(localized: "Failed to write file: \(message)") + case let .invalidMetadataError(message): + String(localized: "Invalid metadata: \(message)") + case let .offlineModeError(message): + String(localized: "Offline mode error: \(message)") + case let .fileIntegrityError(message): + String(localized: "File integrity check failed: \(message)") + case let .fileWriteError(message): + String(localized: "Failed to write file: \(message)") } } } } /// Configuration loading helpers -extension HubApi { +public extension HubApi { /// Assumes the file has already been downloaded. /// `filename` is relative to the download base. - public func configuration(from filename: String, in repo: Repo) throws -> Config { + func configuration(from filename: String, in repo: Repo) throws -> Config { let fileURL = localRepoLocation(repo).appending(path: filename) return try configuration(fileURL: fileURL) } /// Assumes the file is already present at local url. /// `fileURL` is a complete local file path for the given model - public func configuration(fileURL: URL) throws -> Config { + func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse } @@ -218,8 +219,8 @@ extension HubApi { } /// Whoami -extension HubApi { - public func whoami() async throws -> Config { +public extension HubApi { + func whoami() async throws -> Config { guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired } let url = URL(string: "\(endpoint)/api/whoami-v2")! @@ -232,8 +233,8 @@ extension HubApi { } /// Snaphsot download -extension HubApi { - public func localRepoLocation(_ repo: Repo) -> URL { +public extension HubApi { + func localRepoLocation(_ repo: Repo) -> URL { downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id) } @@ -246,7 +247,7 @@ extension HubApi { /// - filePath: The path of the file for which metadata is being read. /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. - public func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? { + func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? { if FileManager.default.fileExists(atPath: metadataPath.path) { do { let contents = try String(contentsOf: metadataPath, encoding: .utf8) @@ -290,13 +291,13 @@ extension HubApi { return nil } - public func isValidHash(hash: String, pattern: String) -> Bool { + func isValidHash(hash: String, pattern: String) -> Bool { let regex = try? NSRegularExpression(pattern: pattern) let range = NSRange(location: 0, length: hash.utf16.count) return regex?.firstMatch(in: hash, options: [], range: range) != nil } - public func computeFileHash(file url: URL) throws -> String { + func computeFileHash(file url: URL) throws -> String { // Open file for reading guard let fileHandle = try? FileHandle(forReadingFrom: url) else { throw Hub.HubClientError.fileNotFound(url.lastPathComponent) @@ -307,13 +308,13 @@ extension HubApi { } var hasher = SHA256() - let chunkSize = 1024 * 1024 // 1MB chunks + let chunkSize = 1024 * 1024 // 1MB chunks while autoreleasepool(invoking: { let nextChunk = try? fileHandle.read(upToCount: chunkSize) guard let nextChunk, - !nextChunk.isEmpty + !nextChunk.isEmpty else { return false } @@ -321,14 +322,14 @@ extension HubApi { hasher.update(data: nextChunk) return true - }) {} + }) { } let digest = hasher.finalize() return digest.map { String(format: "%02x", $0) }.joined() } /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 - public func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws { + func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws { let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" do { try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) @@ -338,7 +339,8 @@ extension HubApi { } } - public struct HubFileDownloader { + struct HubFileDownloader { + let hub: HubApi let repo: Repo let revision: String let repoDestination: URL @@ -386,24 +388,24 @@ extension HubApi { /// (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { - let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination) - let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) + let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination) + let remoteMetadata = try await hub.getFileMetadata(url: source) let localCommitHash = localMetadata?.commitHash ?? "" let remoteCommitHash = remoteMetadata.commitHash ?? "" // Local file exists + metadata exists + commit_hash matches => return file - if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern), downloaded, localMetadata != nil, - localCommitHash == remoteCommitHash + if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil, + localCommitHash == remoteCommitHash { return destination } // From now on, etag, commit_hash, url and size are not empty guard let remoteCommitHash = remoteMetadata.commitHash, - let remoteEtag = remoteMetadata.etag, - let remoteSize = remoteMetadata.size, - remoteMetadata.location != "" + let remoteEtag = remoteMetadata.etag, + let remoteSize = remoteMetadata.size, + remoteMetadata.location != "" else { throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") } @@ -412,7 +414,7 @@ extension HubApi { if downloaded { // etag matches => update metadata and return file if localMetadata?.etag == remoteEtag { - try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } @@ -420,10 +422,10 @@ extension HubApi { // => means it's an LFS file (large) // => let's compute local hash and compare // => if match, update metadata and return file - if HubApi.shared.isValidHash(hash: remoteEtag, pattern: HubApi.shared.sha256Pattern) { - let fileHash = try HubApi.shared.computeFileHash(file: destination) + if hub.isValidHash(hash: remoteEtag, pattern: hub.sha256Pattern) { + let fileHash = try hub.computeFileHash(file: destination) if fileHash == remoteEtag { - try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } } @@ -441,9 +443,9 @@ extension HubApi { switch state { case .notStarted: continue - case .downloading(let progress): + case let .downloading(progress): progressHandler(progress) - case .failed(let error): + case let .failed(error): throw error case .completed: break listen @@ -455,22 +457,22 @@ extension HubApi { } } - try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } } @discardableResult - public func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) + func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { let repoDestination = localRepoLocation(repo) let repoMetadataDestination = repoDestination - .appendingPathComponent(".cache") - .appendingPathComponent("huggingface") - .appendingPathComponent("download") + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + .appendingPathComponent("download") if await NetworkMonitor.shared.state.shouldUseOfflineMode() || useOfflineMode == true { if !FileManager.default.fileExists(atPath: repoDestination.path) { @@ -514,6 +516,7 @@ extension HubApi { for filename in filenames { let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) let downloader = HubFileDownloader( + hub: self, repo: repo, revision: revision, repoDestination: repoDestination, @@ -540,25 +543,25 @@ extension HubApi { } @discardableResult - public func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) + func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler) } @discardableResult - public func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler) + func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler) } @discardableResult - public func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler) + func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler) } } /// Metadata -extension HubApi { +public extension HubApi { /// Data structure containing information about a file versioned on the Hub - public struct FileMetadata { + struct FileMetadata { /// The commit hash related to the file public let commitHash: String? @@ -573,7 +576,7 @@ extension HubApi { } /// Metadata about a file in the local directory related to a download process - public struct LocalDownloadFileMetadata { + struct LocalDownloadFileMetadata { /// Commit hash of the file in the repo public let commitHash: String @@ -607,7 +610,7 @@ extension HubApi { ) } - public func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { + func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { let files = try await getFilenames(from: repo, matching: globs) let url = URL(string: "\(endpoint)/\(repo.id)/resolve/\(revision)")! var selectedMetadata: [FileMetadata] = [] @@ -618,15 +621,15 @@ extension HubApi { return selectedMetadata } - public func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { + func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] { try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: globs) } - public func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { + func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { try await getFileMetadata(from: repo, revision: revision, matching: [glob]) } - public func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { + func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] { try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: [glob]) } } @@ -639,9 +642,9 @@ extension HubApi { public var isConstrained: Bool = false func update(path: NWPath) { - self.isConnected = path.status == .satisfied - self.isExpensive = path.isExpensive - self.isConstrained = path.isConstrained + isConnected = path.status == .satisfied + isExpensive = path.isExpensive + isConstrained = path.isConstrained } func shouldUseOfflineMode() -> Bool { @@ -668,13 +671,13 @@ extension HubApi { func startMonitoring() { monitor.pathUpdateHandler = { [weak self] path in - guard let self = self else { return } + guard let self else { return } Task { await self.state.update(path: path) } } - monitor.start(queue: self.queue) + monitor.start(queue: queue) } func stopMonitoring() { @@ -688,74 +691,74 @@ extension HubApi { } /// Stateless wrappers that use `HubApi` instances -extension Hub { - public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { - return try await HubApi.shared.getFilenames(from: repo, matching: globs) +public extension Hub { + static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { + try await HubApi.shared.getFilenames(from: repo, matching: globs) } - public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { - return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) + static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] { + try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs) } - public static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { - return try await HubApi.shared.getFilenames(from: repo, matching: glob) + static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] { + try await HubApi.shared.getFilenames(from: repo, matching: glob) } - public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { - return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) + static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] { + try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob) } - public static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) + static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) } - public static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws + static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) + try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } - public static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) + static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler) } - public static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { - return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) + static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler) } - public static func whoami(token: String) async throws -> Config { - return try await HubApi(hfToken: token).whoami() + static func whoami(token: String) async throws -> Config { + try await HubApi(hfToken: token).whoami() } - public static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { - return try await HubApi.shared.getFileMetadata(url: fileURL) + static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata { + try await HubApi.shared.getFileMetadata(url: fileURL) } - public static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { - return try await HubApi.shared.getFileMetadata(from: repo, matching: globs) + static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + try await HubApi.shared.getFileMetadata(from: repo, matching: globs) } - public static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { - return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs) + static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] { + try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs) } - public static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { - return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob]) + static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] { + try await HubApi.shared.getFileMetadata(from: repo, matching: [glob]) } - public static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] { - return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob]) + static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] { + try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob]) } } -extension [String] { - public func matching(glob: String) -> [String] { +public extension [String] { + func matching(glob: String) -> [String] { filter { fnmatch(glob, $0, 0) == 0 } } } -extension FileManager { - public func getFileUrls(at directoryUrl: URL) throws -> [URL] { +public extension FileManager { + func getFileUrls(at directoryUrl: URL) throws -> [URL] { var fileUrls = [URL]() // Get all contents including subdirectories @@ -798,14 +801,13 @@ private final class RedirectDelegate: NSObject, URLSessionTaskDelegate, Sendable if (300...399).contains(response.statusCode) { // Get the Location header if let locationString = response.value(forHTTPHeaderField: "Location"), - let locationUrl = URL(string: locationString) + let locationUrl = URL(string: locationString) { - // Check if it's a relative redirect (no host component) if locationUrl.host == nil { // For relative redirects, construct the new URL using the original request's base if let originalUrl = task.originalRequest?.url, - var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) + var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { // Update the path component with the relative path components.path = locationUrl.path diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift index 09e779c..d1533d2 100644 --- a/Tests/HubTests/DownloaderTests.swift +++ b/Tests/HubTests/DownloaderTests.swift @@ -19,20 +19,19 @@ enum DownloadError: LocalizedError { var errorDescription: String? { switch self { case .invalidDownloadLocation: - return String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid") + String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid") case .unexpectedError: - return String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message") + String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message") } } } private extension Downloader { func interruptDownload() async { - await self.session.get()?.invalidateAndCancel() + await session.get()?.invalidateAndCancel() } } - final class DownloaderTests: XCTestCase { var tempDir: URL! @@ -56,18 +55,18 @@ final class DownloaderTests: XCTestCase { let etag = try await Hub.getFileMetadata(fileURL: url).etag! let destination = tempDir.appendingPathComponent("config.json") let fileContent = """ - { - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } + { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } - """ + """ let cacheDir = tempDir.appendingPathComponent("cache") try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) @@ -82,9 +81,9 @@ final class DownloaderTests: XCTestCase { switch state { case .notStarted: continue - case .downloading(let progress): + case .downloading: continue - case .failed(let error): + case let .failed(error): throw error case .completed: break listen @@ -110,12 +109,12 @@ final class DownloaderTests: XCTestCase { let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) // Download with incorrect expected size - let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size + let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size listen: for await state in sub { switch state { case .notStarted: continue - case .downloading(let progress): + case .downloading: continue case .failed: break listen @@ -149,7 +148,7 @@ final class DownloaderTests: XCTestCase { FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) - let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification + let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification // First interruption point at 50% var threshold = 0.5 @@ -161,19 +160,19 @@ final class DownloaderTests: XCTestCase { switch state { case .notStarted: continue - case .downloading(let progress): - if threshold != 1.0 && progress >= threshold { + case let .downloading(progress): + if threshold != 1.0, progress >= threshold { // Move to next threshold and interrupt threshold = threshold == 0.5 ? 0.75 : 1.0 await downloader.interruptDownload() } - case .failed(let error): + case let .failed(error): throw error case .completed: break listen } } - + // Verify the file exists and is complete if FileManager.default.fileExists(atPath: destination.path) { let attributes = try FileManager.default.attributesOfItem(atPath: destination.path) diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 5a53452..61d4968 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -157,7 +157,7 @@ class HubApiTests: XCTestCase { let url = URL( string: - "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel" + "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel" ) let metadata = try await Hub.getFileMetadata(fileURL: url!) @@ -266,7 +266,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedFilenames), Set([ - "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json" + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", ]) ) } @@ -415,7 +415,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedMetadataFilenames), Set([ - ".cache/huggingface/download/tokenizer.json.metadata" + ".cache/huggingface/download/tokenizer.json.metadata", ]) ) @@ -544,7 +544,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual( Set(downloadedMetadataFilenames), Set([ - ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata" + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata", ]) ) @@ -1009,17 +1009,17 @@ class SnapshotDownloadTests: XCTestCase { let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) let expected = """ - { - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } - """ + { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ XCTAssertTrue(fileContents.contains(expected)) } @@ -1048,17 +1048,17 @@ class SnapshotDownloadTests: XCTestCase { let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) let expected = """ - X - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "model_type": "llama", - "pad_token_id": 0, - "vocab_size": 32000 - } - """ + X + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ XCTAssertTrue(fileContents.contains(expected)) } @@ -1100,4 +1100,45 @@ class SnapshotDownloadTests: XCTestCase { "Downloaded file should exist at \(filePath.path)" ) } + + func testDownloadWithRevision() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let commitHash = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" + let downloadedTo = try await hubApi.snapshot(from: repo, revision: commitHash, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + do { + let revision = "nonexistent-revision" + try await hubApi.snapshot(from: repo, revision: revision, matching: "*.json") + XCTFail("Expected an error to be thrown") + } catch let error as Hub.HubClientError { + switch error { + case .fileNotFound: + break // Error type is correct + default: + XCTFail("Wrong error type: \(error)") + } + } catch { + XCTFail("Unexpected error: \(error)") + } + } }