From 49a760db5a3c2ae7c97c2e2e14bd6f602cb05ebb Mon Sep 17 00:00:00 2001
From: Piotr Kowalczuk
Date: Sat, 22 Mar 2025 20:36:27 +0100
Subject: [PATCH 1/2] Sendable Hub.Downloader, Hub.Hub and Hub.HubApi
---
Package.swift | 11 +-
Sources/Hub/Downloader.swift | 350 +++++++++++++++++---------
Sources/Hub/Hub.swift | 36 +--
Sources/Hub/HubApi.swift | 358 +++++++++++++++------------
Tests/HubTests/DownloaderTests.swift | 146 +++++------
Tests/HubTests/HubApiTests.swift | 124 ++++------
6 files changed, 565 insertions(+), 460 deletions(-)
diff --git a/Package.swift b/Package.swift
index 95f6b77..981bd35 100644
--- a/Package.swift
+++ b/Package.swift
@@ -1,8 +1,13 @@
-// swift-tools-version: 5.8
+// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
+// Define the strict concurrency settings to be applied to all targets.
+let swiftSettings: [SwiftSetting] = [
+ .enableExperimentalFeature("StrictConcurrency")
+]
+
let package = Package(
name: "swift-transformers",
platforms: [.iOS(.v16), .macOS(.v13)],
@@ -24,13 +29,13 @@ let package = Package(
]
),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
- .target(name: "Hub", resources: [.process("FallbackConfigs")]),
+ .target(name: "Hub", resources: [.process("FallbackConfigs")], swiftSettings: swiftSettings),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
- .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
+ .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings),
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift
index 0e16e7a..8073795 100644
--- a/Sources/Hub/Downloader.swift
+++ b/Sources/Hub/Downloader.swift
@@ -9,10 +9,11 @@
import Combine
import Foundation
-class Downloader: NSObject, ObservableObject {
- private(set) var destination: URL
-
- private let chunkSize = 10 * 1024 * 1024 // 10MB
+final class Downloader: NSObject, Sendable, ObservableObject {
+ private let destination: URL
+ private let incompleteDestination: URL
+ private let downloadResumeState: DownloadResumeState = .init()
+ private let chunkSize: Int
enum DownloadState {
case notStarted
@@ -27,37 +28,25 @@ class Downloader: NSObject, ObservableObject {
case tempFileNotFound
}
- private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted)
- private var stateSubscriber: Cancellable?
-
- private(set) var tempFilePath: URL
- private(set) var expectedSize: Int?
- private(set) var downloadedSize: Int = 0
+ private let broadcaster: Broadcaster = Broadcaster {
+ return DownloadState.notStarted
+ }
- var session: URLSession? = nil
- var downloadTask: Task? = nil
+ private let sessionConfig: URLSessionConfiguration
+ let session: SessionActor = SessionActor()
+ private let task: TaskActor = TaskActor()
init(
- from url: URL,
to destination: URL,
incompleteDestination: URL,
- using authToken: String? = nil,
inBackground: Bool = false,
- headers: [String: String]? = nil,
- expectedSize: Int? = nil,
- timeout: TimeInterval = 10,
- numRetries: Int = 5
+ chunkSize: Int = 10 * 1024 * 1024 // 10MB
) {
self.destination = destination
- self.expectedSize = expectedSize
-
// Create incomplete file path based on destination
- tempFilePath = incompleteDestination
-
- // If resume size wasn't specified, check for an existing incomplete file
- let resumeSize = Self.incompleteFileSize(at: incompleteDestination)
+ self.incompleteDestination = incompleteDestination
+ self.chunkSize = chunkSize
- super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"
var config = URLSessionConfiguration.default
@@ -66,23 +55,33 @@ class Downloader: NSObject, ObservableObject {
config.isDiscretionary = false
config.sessionSendsLaunchEvents = true
}
-
- session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
-
- setUpDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
+ self.sessionConfig = config
}
- /// Check if an incomplete file exists for the destination and returns its size
- /// - Parameter destination: The destination URL for the download
- /// - Returns: Size of the incomplete file if it exists, otherwise 0
- static func incompleteFileSize(at incompletePath: URL) -> Int {
- if FileManager.default.fileExists(atPath: incompletePath.path) {
- if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int {
- return fileSize
- }
+ func download(
+ from url: URL,
+ using authToken: String? = nil,
+ headers: [String: String]? = nil,
+ expectedSize: Int? = nil,
+ timeout: TimeInterval = 10,
+ numRetries: Int = 5
+ ) async -> AsyncStream {
+ if let task = await self.task.get() {
+ task.cancel()
}
-
- return 0
+ await self.downloadResumeState.setExpectedSize(expectedSize)
+ let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination)
+ await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil))
+ await self.setUpDownload(
+ from: url,
+ with: authToken,
+ resumeSize: resumeSize,
+ headers: headers,
+ timeout: timeout,
+ numRetries: numRetries
+ )
+
+ return await self.broadcaster.subscribe()
}
/// Sets up and initiates a file download operation
@@ -100,77 +99,92 @@ class Downloader: NSObject, ObservableObject {
with authToken: String?,
resumeSize: Int,
headers: [String: String]?,
- expectedSize: Int?,
timeout: TimeInterval,
numRetries: Int
- ) {
- session?.getAllTasks { tasks in
- // If there's an existing pending background task with the same URL, let it proceed.
- if let existing = tasks.filter({ $0.originalRequest?.url == url }).first {
- switch existing.state {
- case .running:
- return
- case .suspended:
- existing.resume()
- return
- case .canceling, .completed:
- existing.cancel()
- @unknown default:
- existing.cancel()
- }
+ ) async {
+ let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination)
+ guard let tasks = await self.session.get()?.allTasks else {
+ return
+ }
+
+ // If there's an existing pending background task with the same URL, let it proceed.
+ if let existing = tasks.filter({ $0.originalRequest?.url == url }).first {
+ switch existing.state {
+ case .running:
+ return
+ case .suspended:
+ existing.resume()
+ return
+ case .canceling, .completed:
+ existing.cancel()
+ break
+ @unknown default:
+ existing.cancel()
}
+ }
- self.downloadTask = Task {
+ await self.task.set(
+ Task {
do {
- // Set up the request with appropriate headers
var request = URLRequest(url: url)
+
+ // Use headers from argument else create an empty header dictionary
var requestHeaders = headers ?? [:]
- if let authToken {
+ // Populate header auth and range fields
+ if let authToken = authToken {
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
- self.downloadedSize = resumeSize
+ await self.downloadResumeState.setDownloadedSize(resumeSize)
+
+ if resumeSize > 0 {
+ requestHeaders["Range"] = "bytes=\(resumeSize)-"
+ }
// Set Range header if we're resuming
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"
// Calculate and show initial progress
- if let expectedSize, expectedSize > 0 {
+ if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize > 0 {
let initialProgress = Double(resumeSize) / Double(expectedSize)
- self.downloadState.value = .downloading(initialProgress)
+ await self.broadcaster.broadcast(state: .downloading(initialProgress))
} else {
- self.downloadState.value = .downloading(0)
+ await self.broadcaster.broadcast(state: .downloading(0))
}
} else {
- self.downloadState.value = .downloading(0)
+ await self.broadcaster.broadcast(state: .downloading(0))
}
request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders
// Open the incomplete file for writing
- let tempFile = try FileHandle(forWritingTo: self.tempFilePath)
+ let tempFile = try FileHandle(forWritingTo: self.incompleteDestination)
// If resuming, seek to end of file
if resumeSize > 0 {
try tempFile.seekToEnd()
}
- try await self.httpGet(request: request, tempFile: tempFile, resumeSize: self.downloadedSize, numRetries: numRetries, expectedSize: expectedSize)
+ defer { tempFile.closeFile() }
- // Clean up and move the completed download to its final destination
- tempFile.closeFile()
+ try await self.httpGet(request: request, tempFile: tempFile, numRetries: numRetries)
try Task.checkCancellation()
- try FileManager.default.moveDownloadedFile(from: self.tempFilePath, to: self.destination)
- self.downloadState.value = .completed(self.destination)
+ try FileManager.default.moveDownloadedFile(from: self.incompleteDestination, to: self.destination)
+
+ // // Clean up and move the completed download to its final destination
+ // tempFile.closeFile()
+ // try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)
+
+ await self.broadcaster.broadcast(state: .completed(self.destination))
} catch {
- self.downloadState.value = .failed(error)
+ await self.broadcaster.broadcast(state: .failed(error))
}
}
- }
+ )
}
/// Downloads a file from given URL using chunked transfer and handles retries.
@@ -187,27 +201,26 @@ class Downloader: NSObject, ObservableObject {
private func httpGet(
request: URLRequest,
tempFile: FileHandle,
- resumeSize: Int,
- numRetries: Int,
- expectedSize: Int?
+ numRetries: Int
) async throws {
- guard let session else {
+ guard let session = await self.session.get() else {
throw DownloadError.unexpectedError
}
// Create a new request with Range header for resuming
var newRequest = request
- if resumeSize > 0 {
- newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
+ if await self.downloadResumeState.downloadedSize > 0 {
+ newRequest.setValue("bytes=\(await self.downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range")
}
// Start the download and get the byte stream
let (asyncBytes, response) = try await session.bytes(for: newRequest)
- guard let httpResponse = response as? HTTPURLResponse else {
+ guard let response = response as? HTTPURLResponse else {
throw DownloadError.unexpectedError
}
- guard (200..<300).contains(httpResponse.statusCode) else {
+
+ guard (200..<300).contains(response.statusCode) else {
throw DownloadError.unexpectedError
}
@@ -220,21 +233,22 @@ class Downloader: NSObject, ObservableObject {
buffer.append(byte)
// When buffer is full, write to disk
if buffer.count == chunkSize {
- if !buffer.isEmpty { // Filter out keep-alive chunks
+ if !buffer.isEmpty { // Filter out keep-alive chunks
try tempFile.write(contentsOf: buffer)
buffer.removeAll(keepingCapacity: true)
- downloadedSize += chunkSize
+
+ await self.downloadResumeState.incDownloadedSize(chunkSize)
newNumRetries = 5
- guard let expectedSize else { continue }
- let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
- downloadState.value = .downloading(progress)
+ guard let expectedSize = await self.downloadResumeState.expectedSize else { continue }
+ let progress = expectedSize != 0 ? Double(await self.downloadResumeState.downloadedSize) / Double(expectedSize) : 0
+ await self.broadcaster.broadcast(state: .downloading(progress))
}
}
}
if !buffer.isEmpty {
try tempFile.write(contentsOf: buffer)
- downloadedSize += buffer.count
+ await self.downloadResumeState.incDownloadedSize(buffer.count)
buffer.removeAll(keepingCapacity: true)
newNumRetries = 5
}
@@ -244,74 +258,73 @@ class Downloader: NSObject, ObservableObject {
}
try await Task.sleep(nanoseconds: 1_000_000_000)
- let config = URLSessionConfiguration.default
- self.session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
+ await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil))
try await httpGet(
request: request,
tempFile: tempFile,
- resumeSize: self.downloadedSize,
- numRetries: newNumRetries - 1,
- expectedSize: expectedSize
+ numRetries: newNumRetries - 1
)
+ return
}
// Verify the downloaded file size matches the expected size
let actualSize = try tempFile.seekToEnd()
- if let expectedSize, expectedSize != actualSize {
+ if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize != actualSize {
throw DownloadError.unexpectedError
}
}
- @discardableResult
- func waitUntilDone() throws -> URL {
- // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
- let semaphore = DispatchSemaphore(value: 0)
- stateSubscriber = downloadState.sink { state in
- switch state {
- case .completed: semaphore.signal()
- case .failed: semaphore.signal()
- default: break
- }
- }
- semaphore.wait()
+ func cancel() async {
+ await self.session.get()?.invalidateAndCancel()
+ await self.task.get()?.cancel()
+ await self.broadcaster.broadcast(state: .failed(URLError(.cancelled)))
+ }
- switch downloadState.value {
- case let .completed(url): return url
- case let .failed(error): throw error
- default: throw DownloadError.unexpectedError
+ /// Check if an incomplete file exists for the destination and returns its size
+ /// - Parameter destination: The destination URL for the download
+ /// - Returns: Size of the incomplete file if it exists, otherwise 0
+ static func incompleteFileSize(at incompletePath: URL) -> Int {
+ if FileManager.default.fileExists(atPath: incompletePath.path) {
+ if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int {
+ return fileSize
+ }
}
- }
- func cancel() {
- session?.invalidateAndCancel()
- downloadTask?.cancel()
- downloadState.value = .failed(URLError(.cancelled))
+ return 0
}
}
extension Downloader: URLSessionDownloadDelegate {
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
- downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
+ Task {
+ await self.broadcaster.broadcast(state: .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)))
+ }
}
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
do {
// If the downloaded file already exists on the filesystem, overwrite it
- try FileManager.default.moveDownloadedFile(from: location, to: destination)
- downloadState.value = .completed(destination)
+ try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
+ Task {
+ await self.broadcaster.broadcast(state: .completed(destination))
+ }
} catch {
- downloadState.value = .failed(error)
+ Task {
+ await self.broadcaster.broadcast(state: .failed(error))
+ }
}
}
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
- if let error {
- downloadState.value = .failed(error)
-// } else if let response = task.response as? HTTPURLResponse {
-// print("HTTP response status code: \(response.statusCode)")
-// let headers = response.allHeaderFields
-// print("HTTP response headers: \(headers)")
+ if let error = error {
+ Task {
+ await self.broadcaster.broadcast(state: .failed(error))
+ }
+ // } else if let response = task.response as? HTTPURLResponse {
+ // print("HTTP response status code: \(response.statusCode)")
+ // let headers = response.allHeaderFields
+ // print("HTTP response headers: \(headers)")
}
}
}
@@ -328,3 +341,96 @@ extension FileManager {
try moveItem(at: srcURL, to: dstURL)
}
}
+
+private actor DownloadResumeState {
+ var expectedSize: Int?
+ var downloadedSize: Int = 0
+
+ func setExpectedSize(_ size: Int?) {
+ self.expectedSize = size
+ }
+
+ func setDownloadedSize(_ size: Int) {
+ self.downloadedSize = size
+ }
+
+ func incDownloadedSize(_ size: Int) {
+ self.downloadedSize += size
+ }
+}
+
+actor Broadcaster {
+ private let initialState: @Sendable () async -> E?
+ private var latestState: E?
+ private var continuations: [UUID: AsyncStream.Continuation] = [:]
+
+ init(initialState: @Sendable @escaping () async -> E?) {
+ self.initialState = initialState
+ }
+
+ deinit {
+ self.continuations.removeAll()
+ }
+
+ func subscribe() -> AsyncStream {
+ return AsyncStream { continuation in
+ let id = UUID()
+ self.continuations[id] = continuation
+
+ continuation.onTermination = { @Sendable status in
+ Task {
+ await self.unsubscribe(id)
+ }
+ }
+
+ Task {
+ if let state = self.latestState {
+ continuation.yield(state)
+ return
+ }
+ if let state = await self.initialState() {
+ continuation.yield(state)
+ }
+ }
+ }
+ }
+
+ private func unsubscribe(_ id: UUID) {
+ self.continuations.removeValue(forKey: id)
+ }
+
+ func broadcast(state: E) async {
+ self.latestState = state
+ await withTaskGroup(of: Void.self) { group in
+ for continuation in continuations.values {
+ group.addTask {
+ continuation.yield(state)
+ }
+ }
+ }
+ }
+}
+
+actor SessionActor {
+ private var urlSession: URLSession? = nil
+
+ func set(_ urlSession: URLSession?) {
+ self.urlSession = urlSession
+ }
+
+ func get() -> URLSession? {
+ return self.urlSession
+ }
+}
+
+actor TaskActor {
+ private var task: Task? = nil
+
+ func set(_ task: Task?) {
+ self.task = task
+ }
+
+ func get() -> Task? {
+ return self.task
+ }
+}
diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift
index 74ce0bf..85fd8fc 100644
--- a/Sources/Hub/Hub.swift
+++ b/Sources/Hub/Hub.swift
@@ -7,10 +7,10 @@
import Foundation
-public struct Hub { }
+public struct Hub: Sendable {}
-public extension Hub {
- enum HubClientError: LocalizedError {
+extension Hub {
+ public enum HubClientError: LocalizedError {
case authorizationRequired
case httpStatusCode(Int)
case parse
@@ -51,13 +51,13 @@ public extension Hub {
}
}
- enum RepoType: String, Codable {
+ public enum RepoType: String, Codable {
case models
case datasets
case spaces
}
- struct Repo: Codable {
+ public struct Repo: Codable {
public let id: String
public let type: RepoType
@@ -68,22 +68,22 @@ public extension Hub {
}
}
-public class LanguageModelConfigurationFromHub {
+public final class LanguageModelConfigurationFromHub: Sendable {
struct Configurations {
var modelConfig: Config
var tokenizerConfig: Config?
var tokenizerData: Config
}
- private var configPromise: Task?
+ private let configPromise: Task
public init(
modelName: String,
revision: String = "main",
hubApi: HubApi = .shared
) {
- configPromise = Task.init {
- try await self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
+ self.configPromise = Task.init {
+ return try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
}
}
@@ -91,22 +91,22 @@ public class LanguageModelConfigurationFromHub {
modelFolder: URL,
hubApi: HubApi = .shared
) {
- configPromise = Task {
- try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
+ self.configPromise = Task {
+ return try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
}
}
public var modelConfig: Config {
get async throws {
- try await configPromise!.value.modelConfig
+ try await configPromise.value.modelConfig
}
}
public var tokenizerConfig: Config? {
get async throws {
- if let hubConfig = try await configPromise!.value.tokenizerConfig {
+ if let hubConfig = try await configPromise.value.tokenizerConfig {
// Try to guess the class if it's not present and the modelType is
- if let _: String = hubConfig.tokenizerClass?.string() { return hubConfig }
+ if hubConfig.tokenizerClass?.string() != nil { return hubConfig }
guard let modelType = try await modelType else { return hubConfig }
// If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it
@@ -129,7 +129,7 @@ public class LanguageModelConfigurationFromHub {
public var tokenizerData: Config {
get async throws {
- try await configPromise!.value.tokenizerData
+ try await configPromise.value.tokenizerData
}
}
@@ -139,7 +139,7 @@ public class LanguageModelConfigurationFromHub {
}
}
- func loadConfig(
+ static func loadConfig(
modelName: String,
revision: String,
hubApi: HubApi = .shared
@@ -167,7 +167,7 @@ public class LanguageModelConfigurationFromHub {
}
}
- func loadConfig(
+ static func loadConfig(
modelFolder: URL,
hubApi: HubApi = .shared
) async throws -> Configurations {
@@ -204,7 +204,7 @@ public class LanguageModelConfigurationFromHub {
// Try to load .jinja template as plain text
chatTemplate = try? String(contentsOf: chatTemplateJinjaURL, encoding: .utf8)
} else if FileManager.default.fileExists(atPath: chatTemplateJsonURL.path),
- let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL)
+ let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL)
{
// Fall back to .json template
chatTemplate = chatTemplateConfig.chatTemplate.string()
diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift
index adfbf4a..39afd5d 100644
--- a/Sources/Hub/HubApi.swift
+++ b/Sources/Hub/HubApi.swift
@@ -10,18 +10,24 @@ import Foundation
import Network
import os
-public struct HubApi {
+public struct HubApi: Sendable {
var downloadBase: URL
var hfToken: String?
var endpoint: String
var useBackgroundSession: Bool
- var useOfflineMode: Bool?
+ var useOfflineMode: Bool? = nil
private let networkMonitor = NetworkMonitor()
public typealias RepoType = Hub.RepoType
public typealias Repo = Hub.Repo
- public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) {
+ public init(
+ downloadBase: URL? = nil,
+ hfToken: String? = nil,
+ endpoint: String = "https://huggingface.co",
+ useBackgroundSession: Bool = false,
+ useOfflineMode: Bool? = nil
+ ) {
self.hfToken = hfToken ?? Self.hfTokenFromEnv()
if let downloadBase {
self.downloadBase = downloadBase
@@ -43,8 +49,8 @@ public struct HubApi {
private static let logger = Logger()
}
-private extension HubApi {
- static func hfTokenFromEnv() -> String? {
+extension HubApi {
+ fileprivate static func hfTokenFromEnv() -> String? {
let possibleTokens = [
{ ProcessInfo.processInfo.environment["HF_TOKEN"] },
{ ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"] },
@@ -76,18 +82,18 @@ private extension HubApi {
}
/// File retrieval
-public extension HubApi {
+extension HubApi {
/// Model data for parsed filenames
- struct Sibling: Codable {
+ public struct Sibling: Codable {
let rfilename: String
}
- struct SiblingsResponse: Codable {
+ public struct SiblingsResponse: Codable {
let siblings: [Sibling]
}
/// Throws error if the response code is not 20X
- func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) {
+ public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
if let hfToken {
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
@@ -118,7 +124,7 @@ public extension HubApi {
/// Throws error if page does not exist or is not accessible.
/// Allows relative redirects but ignores absolute ones for LFS files.
- func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
+ public func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
request.httpMethod = "HEAD"
if let hfToken {
@@ -133,16 +139,15 @@ public extension HubApi {
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }
switch response.statusCode {
- case 200..<400: break // Allow redirects to pass through to the redirect delegate
- case 401, 403: throw Hub.HubClientError.authorizationRequired
- case 404: throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
+ case 200..<400: break // Allow redirects to pass through to the redirect delegate
+ case 400..<500: throw Hub.HubClientError.authorizationRequired
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
}
return (data, response)
}
- func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
+ public func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
// Read repo info and only parse "siblings"
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)/revision/\(revision)")!
let (data, _) = try await httpGet(for: url)
@@ -157,22 +162,22 @@ public extension HubApi {
return Array(selected)
}
- func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
- try await getFilenames(from: Repo(id: repoId), matching: globs)
+ public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
+ return try await getFilenames(from: Repo(id: repoId), matching: globs)
}
- func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
- try await getFilenames(from: repo, matching: [glob])
+ public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
+ return try await getFilenames(from: repo, matching: [glob])
}
- func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
- try await getFilenames(from: Repo(id: repoId), matching: [glob])
+ public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
+ return try await getFilenames(from: Repo(id: repoId), matching: [glob])
}
}
/// Additional Errors
-public extension HubApi {
- enum EnvironmentError: LocalizedError {
+extension HubApi {
+ public enum EnvironmentError: LocalizedError {
case invalidMetadataError(String)
case offlineModeError(String)
case fileIntegrityError(String)
@@ -180,31 +185,31 @@ public extension HubApi {
public var errorDescription: String? {
switch self {
- case let .invalidMetadataError(message):
- String(localized: "Invalid metadata: \(message)")
- case let .offlineModeError(message):
- String(localized: "Offline mode error: \(message)")
- case let .fileIntegrityError(message):
- String(localized: "File integrity check failed: \(message)")
- case let .fileWriteError(message):
- String(localized: "Failed to write file: \(message)")
+ case .invalidMetadataError(let message):
+ return String(localized: "Invalid metadata: \(message)")
+ case .offlineModeError(let message):
+ return String(localized: "Offline mode error: \(message)")
+ case .fileIntegrityError(let message):
+ return String(localized: "File integrity check failed: \(message)")
+ case .fileWriteError(let message):
+ return String(localized: "Failed to write file: \(message)")
}
}
}
}
/// Configuration loading helpers
-public extension HubApi {
+extension HubApi {
/// Assumes the file has already been downloaded.
/// `filename` is relative to the download base.
- func configuration(from filename: String, in repo: Repo) throws -> Config {
+ public func configuration(from filename: String, in repo: Repo) throws -> Config {
let fileURL = localRepoLocation(repo).appending(path: filename)
return try configuration(fileURL: fileURL)
}
/// Assumes the file is already present at local url.
/// `fileURL` is a complete local file path for the given model
- func configuration(fileURL: URL) throws -> Config {
+ public func configuration(fileURL: URL) throws -> Config {
let data = try Data(contentsOf: fileURL)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
@@ -213,8 +218,8 @@ public extension HubApi {
}
/// Whoami
-public extension HubApi {
- func whoami() async throws -> Config {
+extension HubApi {
+ public func whoami() async throws -> Config {
guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired }
let url = URL(string: "\(endpoint)/api/whoami-v2")!
@@ -227,8 +232,8 @@ public extension HubApi {
}
/// Snaphsot download
-public extension HubApi {
- func localRepoLocation(_ repo: Repo) -> URL {
+extension HubApi {
+ public func localRepoLocation(_ repo: Repo) -> URL {
downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id)
}
@@ -241,7 +246,7 @@ public extension HubApi {
/// - filePath: The path of the file for which metadata is being read.
/// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed.
/// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid.
- func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? {
+ public func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? {
if FileManager.default.fileExists(atPath: metadataPath.path) {
do {
let contents = try String(contentsOf: metadataPath, encoding: .utf8)
@@ -285,13 +290,13 @@ public extension HubApi {
return nil
}
- func isValidHash(hash: String, pattern: String) -> Bool {
+ public func isValidHash(hash: String, pattern: String) -> Bool {
let regex = try? NSRegularExpression(pattern: pattern)
let range = NSRange(location: 0, length: hash.utf16.count)
return regex?.firstMatch(in: hash, options: [], range: range) != nil
}
- func computeFileHash(file url: URL) throws -> String {
+ public func computeFileHash(file url: URL) throws -> String {
// Open file for reading
guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
@@ -302,13 +307,13 @@ public extension HubApi {
}
var hasher = SHA256()
- let chunkSize = 1024 * 1024 // 1MB chunks
+ let chunkSize = 1024 * 1024 // 1MB chunks
while autoreleasepool(invoking: {
let nextChunk = try? fileHandle.read(upToCount: chunkSize)
guard let nextChunk,
- !nextChunk.isEmpty
+ !nextChunk.isEmpty
else {
return false
}
@@ -316,14 +321,14 @@ public extension HubApi {
hasher.update(data: nextChunk)
return true
- }) { }
+ }) {}
let digest = hasher.finalize()
return digest.map { String(format: "%02x", $0) }.joined()
}
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
- func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws {
+ public func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws {
let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"
do {
try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true)
@@ -333,8 +338,7 @@ public extension HubApi {
}
}
- struct HubFileDownloader {
- let hub: HubApi
+ public struct HubFileDownloader {
let repo: Repo
let revision: String
let repoDestination: URL
@@ -382,22 +386,24 @@ public extension HubApi {
/// (See for example PipelineLoader in swift-coreml-diffusers)
@discardableResult
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
- let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination)
- let remoteMetadata = try await hub.getFileMetadata(url: source)
+ let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination)
+ let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source)
let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""
// Local file exists + metadata exists + commit_hash matches => return file
- if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil, localCommitHash == remoteCommitHash {
+ if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern), downloaded, localMetadata != nil,
+ localCommitHash == remoteCommitHash
+ {
return destination
}
// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
- let remoteEtag = remoteMetadata.etag,
- let remoteSize = remoteMetadata.size,
- remoteMetadata.location != ""
+ let remoteEtag = remoteMetadata.etag,
+ let remoteSize = remoteMetadata.size,
+ remoteMetadata.location != ""
else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}
@@ -406,7 +412,7 @@ public extension HubApi {
if downloaded {
// etag matches => update metadata and return file
if localMetadata?.etag == remoteEtag {
- try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+ try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
@@ -414,10 +420,10 @@ public extension HubApi {
// => means it's an LFS file (large)
// => let's compute local hash and compare
// => if match, update metadata and return file
- if hub.isValidHash(hash: remoteEtag, pattern: hub.sha256Pattern) {
- let fileHash = try hub.computeFileHash(file: destination)
+ if HubApi.shared.isValidHash(hash: remoteEtag, pattern: HubApi.shared.sha256Pattern) {
+ let fileHash = try HubApi.shared.computeFileHash(file: destination)
if fileHash == remoteEtag {
- try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+ try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
}
@@ -427,51 +433,46 @@ public extension HubApi {
let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete")
try prepareCacheDestination(incompleteDestination)
- let downloader = Downloader(
- from: source,
- to: destination,
- incompleteDestination: incompleteDestination,
- using: hfToken,
- inBackground: backgroundSession,
- expectedSize: remoteSize
- )
+ let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination, inBackground: backgroundSession)
- return try await withTaskCancellationHandler {
- let downloadSubscriber = downloader.downloadState.sink { state in
+ try await withTaskCancellationHandler {
+ let sub = await downloader.download(from: source, using: hfToken, expectedSize: remoteSize)
+ listen: for await state in sub {
switch state {
- case let .downloading(progress):
+ case .notStarted:
+ continue
+ case .downloading(let progress):
progressHandler(progress)
- case .completed, .failed, .notStarted:
- break
- }
- }
- do {
- _ = try withExtendedLifetime(downloadSubscriber) {
- try downloader.waitUntilDone()
+ case .failed(let error):
+ throw error
+ case .completed:
+ break listen
}
-
- try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
-
- return destination
- } catch {
- // If download fails, leave the incomplete file in place for future resume
- throw error
}
} onCancel: {
- downloader.cancel()
+ Task {
+ await downloader.cancel()
+ }
}
+
+ try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+
+ return destination
}
}
@discardableResult
- func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ public func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in })
+ async throws -> URL
+ {
let repoDestination = localRepoLocation(repo)
- let repoMetadataDestination = repoDestination
+ let repoMetadataDestination =
+ repoDestination
.appendingPathComponent(".cache")
.appendingPathComponent("huggingface")
.appendingPathComponent("download")
- if useOfflineMode ?? NetworkMonitor.shared.shouldUseOfflineMode() {
+ if await NetworkMonitor.shared.state.shouldUseOfflineMode() || useOfflineMode == true {
if !FileManager.default.fileExists(atPath: repoDestination.path) {
throw EnvironmentError.offlineModeError(String(localized: "Repository not available locally"))
}
@@ -482,10 +483,12 @@ public extension HubApi {
}
for fileUrl in fileUrls {
- let metadataPath = URL(fileURLWithPath: fileUrl.path.replacingOccurrences(
- of: repoDestination.path,
- with: repoMetadataDestination.path
- ) + ".metadata")
+ let metadataPath = URL(
+ fileURLWithPath: fileUrl.path.replacingOccurrences(
+ of: repoDestination.path,
+ with: repoMetadataDestination.path
+ ) + ".metadata"
+ )
let localMetadata = try readDownloadMetadata(metadataPath: metadataPath)
@@ -511,7 +514,6 @@ public extension HubApi {
for filename in filenames {
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
let downloader = HubFileDownloader(
- hub: self,
repo: repo,
revision: revision,
repoDestination: repoDestination,
@@ -521,36 +523,42 @@ public extension HubApi {
endpoint: endpoint,
backgroundSession: useBackgroundSession
)
+
try await downloader.download { fractionDownloaded in
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
progressHandler(progress)
}
+ if Task.isCancelled {
+ return repoDestination
+ }
+
fileProgress.completedUnitCount = 100
}
+
progressHandler(progress)
return repoDestination
}
@discardableResult
- func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
+ public func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
}
@discardableResult
- func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
+ public func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler)
}
@discardableResult
- func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
+ public func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler)
}
}
/// Metadata
-public extension HubApi {
+extension HubApi {
/// Data structure containing information about a file versioned on the Hub
- struct FileMetadata {
+ public struct FileMetadata {
/// The commit hash related to the file
public let commitHash: String?
@@ -565,7 +573,7 @@ public extension HubApi {
}
/// Metadata about a file in the local directory related to a download process
- struct LocalDownloadFileMetadata {
+ public struct LocalDownloadFileMetadata {
/// Commit hash of the file in the repo
public let commitHash: String
@@ -599,7 +607,7 @@ public extension HubApi {
)
}
- func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
+ public func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
let files = try await getFilenames(from: repo, matching: globs)
let url = URL(string: "\(endpoint)/\(repo.id)/resolve/\(revision)")!
var selectedMetadata: [FileMetadata] = []
@@ -610,28 +618,45 @@ public extension HubApi {
return selectedMetadata
}
- func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
+ public func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: globs)
}
- func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
+ public func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
try await getFileMetadata(from: repo, revision: revision, matching: [glob])
}
- func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
+ public func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: [glob])
}
}
/// Network monitor helper class to help decide whether to use offline mode
-private extension HubApi {
- private final class NetworkMonitor {
- private var monitor: NWPathMonitor
- private var queue: DispatchQueue
+extension HubApi {
+ private actor NetworkStateActor {
+ public var isConnected: Bool = false
+ public var isExpensive: Bool = false
+ public var isConstrained: Bool = false
+
+ func update(path: NWPath) {
+ self.isConnected = path.status == .satisfied
+ self.isExpensive = path.isExpensive
+ self.isConstrained = path.isConstrained
+ }
- private(set) var isConnected: Bool = false
- private(set) var isExpensive: Bool = false
- private(set) var isConstrained: Bool = false
+ func shouldUseOfflineMode() -> Bool {
+ if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" {
+ return false
+ }
+ return !isConnected || isExpensive || isConstrained
+ }
+ }
+
+ private final class NetworkMonitor: Sendable {
+ private let monitor: NWPathMonitor
+ private let queue: DispatchQueue
+
+ public let state: NetworkStateActor = .init()
static let shared = NetworkMonitor()
@@ -643,27 +668,19 @@ private extension HubApi {
func startMonitoring() {
monitor.pathUpdateHandler = { [weak self] path in
- guard let self else { return }
-
- isConnected = path.status == .satisfied
- isExpensive = path.isExpensive
- isConstrained = path.isConstrained
+ guard let self = self else { return }
+ Task {
+ await self.state.update(path: path)
+ }
}
- monitor.start(queue: queue)
+ monitor.start(queue: self.queue)
}
func stopMonitoring() {
monitor.cancel()
}
- func shouldUseOfflineMode() -> Bool {
- if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" {
- return false
- }
- return !isConnected || isExpensive || isConstrained
- }
-
deinit {
stopMonitoring()
}
@@ -671,80 +688,84 @@ private extension HubApi {
}
/// Stateless wrappers that use `HubApi` instances
-public extension Hub {
- static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
- try await HubApi.shared.getFilenames(from: repo, matching: globs)
+extension Hub {
+ public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
+ return try await HubApi.shared.getFilenames(from: repo, matching: globs)
}
- static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
- try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
+ public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
+ return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
}
- static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
- try await HubApi.shared.getFilenames(from: repo, matching: glob)
+ public static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
+ return try await HubApi.shared.getFilenames(from: repo, matching: glob)
}
- static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
- try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
+ public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
+ return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
}
- static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
+ public static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
}
- static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
+ public static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws
+ -> URL
+ {
+ return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
}
- static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
+ public static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
}
- static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
+ public static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
}
- static func whoami(token: String) async throws -> Config {
- try await HubApi(hfToken: token).whoami()
+ public static func whoami(token: String) async throws -> Config {
+ return try await HubApi(hfToken: token).whoami()
}
- static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
- try await HubApi.shared.getFileMetadata(url: fileURL)
+ public static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
+ return try await HubApi.shared.getFileMetadata(url: fileURL)
}
- static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
- try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
+ public static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+ return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
}
- static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
- try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
+ public static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+ return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
}
- static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
- try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
+ public static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
+ return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
}
- static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
- try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
+ public static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
+ return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
}
}
-public extension [String] {
- func matching(glob: String) -> [String] {
+extension [String] {
+ public func matching(glob: String) -> [String] {
filter { fnmatch(glob, $0, 0) == 0 }
}
}
-public extension FileManager {
- func getFileUrls(at directoryUrl: URL) throws -> [URL] {
+extension FileManager {
+ public func getFileUrls(at directoryUrl: URL) throws -> [URL] {
var fileUrls = [URL]()
// Get all contents including subdirectories
- guard let enumerator = FileManager.default.enumerator(
- at: directoryUrl,
- includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey],
- options: [.skipsHiddenFiles]
- ) else {
+ guard
+ let enumerator = FileManager.default.enumerator(
+ at: directoryUrl,
+ includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey],
+ options: [.skipsHiddenFiles]
+ )
+ else {
return fileUrls
}
@@ -765,19 +786,26 @@ public extension FileManager {
/// Only allow relative redirects and reject others
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258
-private class RedirectDelegate: NSObject, URLSessionTaskDelegate {
- func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) {
+private final class RedirectDelegate: NSObject, URLSessionTaskDelegate, Sendable {
+ func urlSession(
+ _ session: URLSession,
+ task: URLSessionTask,
+ willPerformHTTPRedirection response: HTTPURLResponse,
+ newRequest request: URLRequest,
+ completionHandler: @escaping (URLRequest?) -> Void
+ ) {
// Check if it's a redirect status code (300-399)
if (300...399).contains(response.statusCode) {
// Get the Location header
if let locationString = response.value(forHTTPHeaderField: "Location"),
- let locationUrl = URL(string: locationString)
+ let locationUrl = URL(string: locationString)
{
+
// Check if it's a relative redirect (no host component)
if locationUrl.host == nil {
// For relative redirects, construct the new URL using the original request's base
if let originalUrl = task.originalRequest?.url,
- var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
+ var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
{
// Update the path component with the relative path
components.path = locationUrl.path
diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift
index d62d2e8..09e779c 100644
--- a/Tests/HubTests/DownloaderTests.swift
+++ b/Tests/HubTests/DownloaderTests.swift
@@ -6,6 +6,8 @@
//
import Combine
+import XCTest
+
@testable import Hub
import XCTest
@@ -17,19 +19,20 @@ enum DownloadError: LocalizedError {
var errorDescription: String? {
switch self {
case .invalidDownloadLocation:
- String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
+ return String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
case .unexpectedError:
- String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
+ return String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
}
}
}
private extension Downloader {
- func interruptDownload() {
- session?.invalidateAndCancel()
+ func interruptDownload() async {
+ await self.session.get()?.invalidateAndCancel()
}
}
+
final class DownloaderTests: XCTestCase {
var tempDir: URL!
@@ -53,18 +56,18 @@ final class DownloaderTests: XCTestCase {
let etag = try await Hub.getFileMetadata(fileURL: url).etag!
let destination = tempDir.appendingPathComponent("config.json")
let fileContent = """
- {
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
+ {
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
- """
+ """
let cacheDir = tempDir.appendingPathComponent("cache")
try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true)
@@ -72,33 +75,22 @@ final class DownloaderTests: XCTestCase {
let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete")
FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil)
- let downloader = Downloader(
- from: url,
- to: destination,
- incompleteDestination: incompleteDestination
- )
-
- // Store subscriber outside the continuation to maintain its lifecycle
- var subscriber: AnyCancellable?
-
- try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in
- subscriber = downloader.downloadState.sink { state in
- switch state {
- case .completed:
- continuation.resume()
- case let .failed(error):
- continuation.resume(throwing: error)
- case .downloading:
- break
- case .notStarted:
- break
- }
+ let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination)
+ let sub = await downloader.download(from: url)
+
+ listen: for await state in sub {
+ switch state {
+ case .notStarted:
+ continue
+ case .downloading(let progress):
+ continue
+ case .failed(let error):
+ throw error
+ case .completed:
+ break listen
}
}
- // Cancel subscription after continuation completes
- subscriber?.cancel()
-
// Verify download completed successfully
XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path))
XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent)
@@ -116,18 +108,22 @@ final class DownloaderTests: XCTestCase {
let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete")
FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil)
- // Create downloader with incorrect expected size
- let downloader = Downloader(
- from: url,
- to: destination,
- incompleteDestination: incompleteDestination,
- expectedSize: 999999 // Incorrect size
- )
-
- do {
- try downloader.waitUntilDone()
- XCTFail("Download should have failed due to size mismatch")
- } catch { }
+ let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination)
+ // Download with incorrect expected size
+ let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size
+ listen: for await state in sub {
+ switch state {
+ case .notStarted:
+ continue
+ case .downloading(let progress):
+ continue
+ case .failed:
+ break listen
+ case .completed:
+ XCTFail("Download should have failed due to size mismatch")
+ break listen
+ }
+ }
// Verify no file was created at destination
XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
@@ -141,8 +137,10 @@ final class DownloaderTests: XCTestCase {
let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip")
// Create parent directory if it doesn't exist
- try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(),
- withIntermediateDirectories: true)
+ try FileManager.default.createDirectory(
+ at: destination.deletingLastPathComponent(),
+ withIntermediateDirectories: true
+ )
let cacheDir = tempDir.appendingPathComponent("cache")
try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true)
@@ -150,42 +148,32 @@ final class DownloaderTests: XCTestCase {
let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete")
FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil)
- let downloader = Downloader(
- from: url,
- to: destination,
- incompleteDestination: incompleteDestination,
- expectedSize: 73194001 // Correct size for verification
- )
+ let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination)
+ let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification
// First interruption point at 50%
var threshold = 0.5
- var subscriber: AnyCancellable?
-
do {
// Monitor download progress and interrupt at thresholds to test if
// download continues from where it left off
- try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in
- subscriber = downloader.downloadState.sink { state in
- switch state {
- case let .downloading(progress):
- if threshold != 1.0, progress >= threshold {
- // Move to next threshold and interrupt
- threshold = threshold == 0.5 ? 0.75 : 1.0
- downloader.interruptDownload()
- }
- case .completed:
- continuation.resume()
- case let .failed(error):
- continuation.resume(throwing: error)
- case .notStarted:
- break
+ listen: for await state in sub {
+ switch state {
+ case .notStarted:
+ continue
+ case .downloading(let progress):
+ if threshold != 1.0 && progress >= threshold {
+ // Move to next threshold and interrupt
+ threshold = threshold == 0.5 ? 0.75 : 1.0
+ await downloader.interruptDownload()
}
+ case .failed(let error):
+ throw error
+ case .completed:
+ break listen
}
}
-
- subscriber?.cancel()
-
+
// Verify the file exists and is complete
if FileManager.default.fileExists(atPath: destination.path) {
let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift
index b03716f..5a53452 100644
--- a/Tests/HubTests/HubApiTests.swift
+++ b/Tests/HubTests/HubApiTests.swift
@@ -4,9 +4,10 @@
// Created by Pedro Cuenca on 20231230.
//
-@testable import Hub
import XCTest
+@testable import Hub
+
class HubApiTests: XCTestCase {
override func setUp() {
// Put setup code here. This method is called before the invocation of each test method in the class.
@@ -150,10 +151,14 @@ class HubApiTests: XCTestCase {
do {
let revision = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb"
let etag = "fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107"
- let location = "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel"
+ let location =
+ "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel"
let size = 504766
- let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel")
+ let url = URL(
+ string:
+ "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"
+ )
let metadata = try await Hub.getFileMetadata(fileURL: url!)
XCTAssertEqual(metadata.commitHash, revision)
@@ -188,7 +193,12 @@ class SnapshotDownloadTests: XCTestCase {
var filenames: [String] = []
let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/")
- if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) {
+ if let enumerator = FileManager.default.enumerator(
+ at: url,
+ includingPropertiesForKeys: [.isRegularFileKey],
+ options: [.skipsHiddenFiles],
+ errorHandler: nil
+ ) {
for case let fileURL as URL in enumerator {
do {
let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey])
@@ -256,7 +266,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedFilenames),
Set([
- "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
+ "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json"
])
)
}
@@ -405,7 +415,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedMetadataFilenames),
Set([
- ".cache/huggingface/download/tokenizer.json.metadata",
+ ".cache/huggingface/download/tokenizer.json.metadata"
])
)
@@ -534,7 +544,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedMetadataFilenames),
Set([
- ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata",
+ ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata"
])
)
@@ -915,7 +925,11 @@ class SnapshotDownloadTests: XCTestCase {
let metadataDestination = downloadedTo.appendingPathComponent(".cache/huggingface/download").appendingPathComponent("x.bin.metadata")
- try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(to: metadataDestination, atomically: true, encoding: .utf8)
+ try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(
+ to: metadataDestination,
+ atomically: true,
+ encoding: .utf8
+ )
hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true)
@@ -972,7 +986,9 @@ class SnapshotDownloadTests: XCTestCase {
func testResumeDownloadFromEmptyIncomplete() async throws {
let hubApi = HubApi(downloadBase: downloadDestination)
var lastProgress: Progress? = nil
- var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml")
+ var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent(
+ "Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml"
+ )
let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download")
@@ -993,17 +1009,17 @@ class SnapshotDownloadTests: XCTestCase {
let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
let expected = """
- {
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
- """
+ {
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
+ """
XCTAssertTrue(fileContents.contains(expected))
}
@@ -1032,17 +1048,17 @@ class SnapshotDownloadTests: XCTestCase {
let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
let expected = """
- X
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
- """
+ X
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
+ """
XCTAssertTrue(fileContents.contains(expected))
}
@@ -1070,6 +1086,7 @@ class SnapshotDownloadTests: XCTestCase {
// Cancel the download once we've seen progress
downloadTask.cancel()
+
try await Task.sleep(nanoseconds: 5_000_000_000)
// Resume download with a new task
@@ -1078,48 +1095,9 @@ class SnapshotDownloadTests: XCTestCase {
}
let filePath = downloadedTo.appendingPathComponent(targetFile)
- XCTAssertTrue(FileManager.default.fileExists(atPath: filePath.path),
- "Downloaded file should exist at \(filePath.path)")
- }
-
- func testDownloadWithRevision() async throws {
- let hubApi = HubApi(downloadBase: downloadDestination)
- var lastProgress: Progress? = nil
-
- let commitHash = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb"
- let downloadedTo = try await hubApi.snapshot(from: repo, revision: commitHash, matching: "*.json") { progress in
- print("Total Progress: \(progress.fractionCompleted)")
- print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
- lastProgress = progress
- }
-
- let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo)
- XCTAssertEqual(lastProgress?.fractionCompleted, 1)
- XCTAssertEqual(lastProgress?.completedUnitCount, 6)
- XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
- XCTAssertEqual(
- Set(downloadedFilenames),
- Set([
- "config.json", "tokenizer.json", "tokenizer_config.json",
- "llama-2-7b-chat.mlpackage/Manifest.json",
- "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
- "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
- ])
+ XCTAssertTrue(
+ FileManager.default.fileExists(atPath: filePath.path),
+ "Downloaded file should exist at \(filePath.path)"
)
-
- do {
- let revision = "nonexistent-revision"
- try await hubApi.snapshot(from: repo, revision: revision, matching: "*.json")
- XCTFail("Expected an error to be thrown")
- } catch let error as Hub.HubClientError {
- switch error {
- case .fileNotFound:
- break // Error type is correct
- default:
- XCTFail("Wrong error type: \(error)")
- }
- } catch {
- XCTFail("Unexpected error: \(error)")
- }
}
}
From 8de7f498fddfcdd8218fdad7a58ff2d7d1562709 Mon Sep 17 00:00:00 2001
From: Piotr Kowalczuk
Date: Mon, 9 Jun 2025 11:55:45 +0200
Subject: [PATCH 2/2] Sendable Hub.Downloader, Hub.Hub and Hub.HubApi .2
---
Package.swift | 4 +-
Sources/Hub/Downloader.swift | 80 +++++-----
Sources/Hub/Hub.swift | 20 +--
Sources/Hub/HubApi.swift | 228 ++++++++++++++-------------
Tests/HubTests/DownloaderTests.swift | 47 +++---
Tests/HubTests/HubApiTests.swift | 93 ++++++++---
6 files changed, 257 insertions(+), 215 deletions(-)
diff --git a/Package.swift b/Package.swift
index 981bd35..43c5696 100644
--- a/Package.swift
+++ b/Package.swift
@@ -3,9 +3,9 @@
import PackageDescription
-// Define the strict concurrency settings to be applied to all targets.
+/// Define the strict concurrency settings to be applied to all targets.
let swiftSettings: [SwiftSetting] = [
- .enableExperimentalFeature("StrictConcurrency")
+ .enableExperimentalFeature("StrictConcurrency"),
]
let package = Package(
diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift
index 8073795..cc0171f 100644
--- a/Sources/Hub/Downloader.swift
+++ b/Sources/Hub/Downloader.swift
@@ -29,18 +29,18 @@ final class Downloader: NSObject, Sendable, ObservableObject {
}
private let broadcaster: Broadcaster = Broadcaster {
- return DownloadState.notStarted
+ DownloadState.notStarted
}
private let sessionConfig: URLSessionConfiguration
- let session: SessionActor = SessionActor()
- private let task: TaskActor = TaskActor()
+ let session: SessionActor = .init()
+ private let task: TaskActor = .init()
init(
to destination: URL,
incompleteDestination: URL,
inBackground: Bool = false,
- chunkSize: Int = 10 * 1024 * 1024 // 10MB
+ chunkSize: Int = 10 * 1024 * 1024 // 10MB
) {
self.destination = destination
// Create incomplete file path based on destination
@@ -55,7 +55,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
config.isDiscretionary = false
config.sessionSendsLaunchEvents = true
}
- self.sessionConfig = config
+ sessionConfig = config
}
func download(
@@ -66,13 +66,13 @@ final class Downloader: NSObject, Sendable, ObservableObject {
timeout: TimeInterval = 10,
numRetries: Int = 5
) async -> AsyncStream {
- if let task = await self.task.get() {
+ if let task = await task.get() {
task.cancel()
}
- await self.downloadResumeState.setExpectedSize(expectedSize)
- let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination)
- await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil))
- await self.setUpDownload(
+ await downloadResumeState.setExpectedSize(expectedSize)
+ let resumeSize = Self.incompleteFileSize(at: incompleteDestination)
+ await session.set(URLSession(configuration: sessionConfig, delegate: self, delegateQueue: nil))
+ await setUpDownload(
from: url,
with: authToken,
resumeSize: resumeSize,
@@ -81,7 +81,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
numRetries: numRetries
)
- return await self.broadcaster.subscribe()
+ return await broadcaster.subscribe()
}
/// Sets up and initiates a file download operation
@@ -102,8 +102,8 @@ final class Downloader: NSObject, Sendable, ObservableObject {
timeout: TimeInterval,
numRetries: Int
) async {
- let resumeSize = Self.incompleteFileSize(at: self.incompleteDestination)
- guard let tasks = await self.session.get()?.allTasks else {
+ let resumeSize = Self.incompleteFileSize(at: incompleteDestination)
+ guard let tasks = await session.get()?.allTasks else {
return
}
@@ -123,7 +123,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
}
}
- await self.task.set(
+ await task.set(
Task {
do {
var request = URLRequest(url: url)
@@ -132,7 +132,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
var requestHeaders = headers ?? [:]
// Populate header auth and range fields
- if let authToken = authToken {
+ if let authToken {
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
@@ -203,14 +203,14 @@ final class Downloader: NSObject, Sendable, ObservableObject {
tempFile: FileHandle,
numRetries: Int
) async throws {
- guard let session = await self.session.get() else {
+ guard let session = await session.get() else {
throw DownloadError.unexpectedError
}
// Create a new request with Range header for resuming
var newRequest = request
- if await self.downloadResumeState.downloadedSize > 0 {
- newRequest.setValue("bytes=\(await self.downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range")
+ if await downloadResumeState.downloadedSize > 0 {
+ await newRequest.setValue("bytes=\(downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range")
}
// Start the download and get the byte stream
@@ -233,22 +233,22 @@ final class Downloader: NSObject, Sendable, ObservableObject {
buffer.append(byte)
// When buffer is full, write to disk
if buffer.count == chunkSize {
- if !buffer.isEmpty { // Filter out keep-alive chunks
+ if !buffer.isEmpty { // Filter out keep-alive chunks
try tempFile.write(contentsOf: buffer)
buffer.removeAll(keepingCapacity: true)
- await self.downloadResumeState.incDownloadedSize(chunkSize)
+ await downloadResumeState.incDownloadedSize(chunkSize)
newNumRetries = 5
- guard let expectedSize = await self.downloadResumeState.expectedSize else { continue }
- let progress = expectedSize != 0 ? Double(await self.downloadResumeState.downloadedSize) / Double(expectedSize) : 0
- await self.broadcaster.broadcast(state: .downloading(progress))
+ guard let expectedSize = await downloadResumeState.expectedSize else { continue }
+ let progress = await expectedSize != 0 ? Double(downloadResumeState.downloadedSize) / Double(expectedSize) : 0
+ await broadcaster.broadcast(state: .downloading(progress))
}
}
}
if !buffer.isEmpty {
try tempFile.write(contentsOf: buffer)
- await self.downloadResumeState.incDownloadedSize(buffer.count)
+ await downloadResumeState.incDownloadedSize(buffer.count)
buffer.removeAll(keepingCapacity: true)
newNumRetries = 5
}
@@ -270,15 +270,15 @@ final class Downloader: NSObject, Sendable, ObservableObject {
// Verify the downloaded file size matches the expected size
let actualSize = try tempFile.seekToEnd()
- if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize != actualSize {
+ if let expectedSize = await downloadResumeState.expectedSize, expectedSize != actualSize {
throw DownloadError.unexpectedError
}
}
func cancel() async {
- await self.session.get()?.invalidateAndCancel()
- await self.task.get()?.cancel()
- await self.broadcaster.broadcast(state: .failed(URLError(.cancelled)))
+ await session.get()?.invalidateAndCancel()
+ await task.get()?.cancel()
+ await broadcaster.broadcast(state: .failed(URLError(.cancelled)))
}
/// Check if an incomplete file exists for the destination and returns its size
@@ -305,7 +305,7 @@ extension Downloader: URLSessionDownloadDelegate {
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
do {
// If the downloaded file already exists on the filesystem, overwrite it
- try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
+ try FileManager.default.moveDownloadedFile(from: location, to: destination)
Task {
await self.broadcaster.broadcast(state: .completed(destination))
}
@@ -317,7 +317,7 @@ extension Downloader: URLSessionDownloadDelegate {
}
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
- if let error = error {
+ if let error {
Task {
await self.broadcaster.broadcast(state: .failed(error))
}
@@ -347,15 +347,15 @@ private actor DownloadResumeState {
var downloadedSize: Int = 0
func setExpectedSize(_ size: Int?) {
- self.expectedSize = size
+ expectedSize = size
}
func setDownloadedSize(_ size: Int) {
- self.downloadedSize = size
+ downloadedSize = size
}
func incDownloadedSize(_ size: Int) {
- self.downloadedSize += size
+ downloadedSize += size
}
}
@@ -373,7 +373,7 @@ actor Broadcaster {
}
func subscribe() -> AsyncStream {
- return AsyncStream { continuation in
+ AsyncStream { continuation in
let id = UUID()
self.continuations[id] = continuation
@@ -396,11 +396,11 @@ actor Broadcaster {
}
private func unsubscribe(_ id: UUID) {
- self.continuations.removeValue(forKey: id)
+ continuations.removeValue(forKey: id)
}
func broadcast(state: E) async {
- self.latestState = state
+ latestState = state
await withTaskGroup(of: Void.self) { group in
for continuation in continuations.values {
group.addTask {
@@ -412,25 +412,25 @@ actor Broadcaster {
}
actor SessionActor {
- private var urlSession: URLSession? = nil
+ private var urlSession: URLSession?
func set(_ urlSession: URLSession?) {
self.urlSession = urlSession
}
func get() -> URLSession? {
- return self.urlSession
+ urlSession
}
}
actor TaskActor {
- private var task: Task? = nil
+ private var task: Task?
func set(_ task: Task?) {
self.task = task
}
func get() -> Task? {
- return self.task
+ task
}
}
diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift
index 85fd8fc..00b7b2c 100644
--- a/Sources/Hub/Hub.swift
+++ b/Sources/Hub/Hub.swift
@@ -7,10 +7,10 @@
import Foundation
-public struct Hub: Sendable {}
+public struct Hub: Sendable { }
-extension Hub {
- public enum HubClientError: LocalizedError {
+public extension Hub {
+ enum HubClientError: LocalizedError {
case authorizationRequired
case httpStatusCode(Int)
case parse
@@ -51,13 +51,13 @@ extension Hub {
}
}
- public enum RepoType: String, Codable {
+ enum RepoType: String, Codable {
case models
case datasets
case spaces
}
- public struct Repo: Codable {
+ struct Repo: Codable {
public let id: String
public let type: RepoType
@@ -82,8 +82,8 @@ public final class LanguageModelConfigurationFromHub: Sendable {
revision: String = "main",
hubApi: HubApi = .shared
) {
- self.configPromise = Task.init {
- return try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
+ configPromise = Task.init {
+ try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
}
}
@@ -91,8 +91,8 @@ public final class LanguageModelConfigurationFromHub: Sendable {
modelFolder: URL,
hubApi: HubApi = .shared
) {
- self.configPromise = Task {
- return try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
+ configPromise = Task {
+ try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
}
}
@@ -204,7 +204,7 @@ public final class LanguageModelConfigurationFromHub: Sendable {
// Try to load .jinja template as plain text
chatTemplate = try? String(contentsOf: chatTemplateJinjaURL, encoding: .utf8)
} else if FileManager.default.fileExists(atPath: chatTemplateJsonURL.path),
- let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL)
+ let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL)
{
// Fall back to .json template
chatTemplate = chatTemplateConfig.chatTemplate.string()
diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift
index 39afd5d..7c8f61e 100644
--- a/Sources/Hub/HubApi.swift
+++ b/Sources/Hub/HubApi.swift
@@ -49,8 +49,8 @@ public struct HubApi: Sendable {
private static let logger = Logger()
}
-extension HubApi {
- fileprivate static func hfTokenFromEnv() -> String? {
+private extension HubApi {
+ static func hfTokenFromEnv() -> String? {
let possibleTokens = [
{ ProcessInfo.processInfo.environment["HF_TOKEN"] },
{ ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"] },
@@ -82,18 +82,18 @@ extension HubApi {
}
/// File retrieval
-extension HubApi {
+public extension HubApi {
/// Model data for parsed filenames
- public struct Sibling: Codable {
+ struct Sibling: Codable {
let rfilename: String
}
- public struct SiblingsResponse: Codable {
+ struct SiblingsResponse: Codable {
let siblings: [Sibling]
}
/// Throws error if the response code is not 20X
- public func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) {
+ func httpGet(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
if let hfToken {
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
@@ -124,7 +124,7 @@ extension HubApi {
/// Throws error if page does not exist or is not accessible.
/// Allows relative redirects but ignores absolute ones for LFS files.
- public func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
+ func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
request.httpMethod = "HEAD"
if let hfToken {
@@ -139,15 +139,16 @@ extension HubApi {
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }
switch response.statusCode {
- case 200..<400: break // Allow redirects to pass through to the redirect delegate
- case 400..<500: throw Hub.HubClientError.authorizationRequired
+ case 200..<400: break // Allow redirects to pass through to the redirect delegate
+ case 401, 403: throw Hub.HubClientError.authorizationRequired
+ case 404: throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
}
return (data, response)
}
- public func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
+ func getFilenames(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [String] {
// Read repo info and only parse "siblings"
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)/revision/\(revision)")!
let (data, _) = try await httpGet(for: url)
@@ -162,22 +163,22 @@ extension HubApi {
return Array(selected)
}
- public func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
- return try await getFilenames(from: Repo(id: repoId), matching: globs)
+ func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
+ try await getFilenames(from: Repo(id: repoId), matching: globs)
}
- public func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
- return try await getFilenames(from: repo, matching: [glob])
+ func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
+ try await getFilenames(from: repo, matching: [glob])
}
- public func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
- return try await getFilenames(from: Repo(id: repoId), matching: [glob])
+ func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
+ try await getFilenames(from: Repo(id: repoId), matching: [glob])
}
}
/// Additional Errors
-extension HubApi {
- public enum EnvironmentError: LocalizedError {
+public extension HubApi {
+ enum EnvironmentError: LocalizedError {
case invalidMetadataError(String)
case offlineModeError(String)
case fileIntegrityError(String)
@@ -185,31 +186,31 @@ extension HubApi {
public var errorDescription: String? {
switch self {
- case .invalidMetadataError(let message):
- return String(localized: "Invalid metadata: \(message)")
- case .offlineModeError(let message):
- return String(localized: "Offline mode error: \(message)")
- case .fileIntegrityError(let message):
- return String(localized: "File integrity check failed: \(message)")
- case .fileWriteError(let message):
- return String(localized: "Failed to write file: \(message)")
+ case let .invalidMetadataError(message):
+ String(localized: "Invalid metadata: \(message)")
+ case let .offlineModeError(message):
+ String(localized: "Offline mode error: \(message)")
+ case let .fileIntegrityError(message):
+ String(localized: "File integrity check failed: \(message)")
+ case let .fileWriteError(message):
+ String(localized: "Failed to write file: \(message)")
}
}
}
}
/// Configuration loading helpers
-extension HubApi {
+public extension HubApi {
/// Assumes the file has already been downloaded.
/// `filename` is relative to the download base.
- public func configuration(from filename: String, in repo: Repo) throws -> Config {
+ func configuration(from filename: String, in repo: Repo) throws -> Config {
let fileURL = localRepoLocation(repo).appending(path: filename)
return try configuration(fileURL: fileURL)
}
/// Assumes the file is already present at local url.
/// `fileURL` is a complete local file path for the given model
- public func configuration(fileURL: URL) throws -> Config {
+ func configuration(fileURL: URL) throws -> Config {
let data = try Data(contentsOf: fileURL)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
@@ -218,8 +219,8 @@ extension HubApi {
}
/// Whoami
-extension HubApi {
- public func whoami() async throws -> Config {
+public extension HubApi {
+ func whoami() async throws -> Config {
guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired }
let url = URL(string: "\(endpoint)/api/whoami-v2")!
@@ -232,8 +233,8 @@ extension HubApi {
}
/// Snaphsot download
-extension HubApi {
- public func localRepoLocation(_ repo: Repo) -> URL {
+public extension HubApi {
+ func localRepoLocation(_ repo: Repo) -> URL {
downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id)
}
@@ -246,7 +247,7 @@ extension HubApi {
/// - filePath: The path of the file for which metadata is being read.
/// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed.
/// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid.
- public func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? {
+ func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? {
if FileManager.default.fileExists(atPath: metadataPath.path) {
do {
let contents = try String(contentsOf: metadataPath, encoding: .utf8)
@@ -290,13 +291,13 @@ extension HubApi {
return nil
}
- public func isValidHash(hash: String, pattern: String) -> Bool {
+ func isValidHash(hash: String, pattern: String) -> Bool {
let regex = try? NSRegularExpression(pattern: pattern)
let range = NSRange(location: 0, length: hash.utf16.count)
return regex?.firstMatch(in: hash, options: [], range: range) != nil
}
- public func computeFileHash(file url: URL) throws -> String {
+ func computeFileHash(file url: URL) throws -> String {
// Open file for reading
guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
throw Hub.HubClientError.fileNotFound(url.lastPathComponent)
@@ -307,13 +308,13 @@ extension HubApi {
}
var hasher = SHA256()
- let chunkSize = 1024 * 1024 // 1MB chunks
+ let chunkSize = 1024 * 1024 // 1MB chunks
while autoreleasepool(invoking: {
let nextChunk = try? fileHandle.read(upToCount: chunkSize)
guard let nextChunk,
- !nextChunk.isEmpty
+ !nextChunk.isEmpty
else {
return false
}
@@ -321,14 +322,14 @@ extension HubApi {
hasher.update(data: nextChunk)
return true
- }) {}
+ }) { }
let digest = hasher.finalize()
return digest.map { String(format: "%02x", $0) }.joined()
}
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
- public func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws {
+ func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws {
let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"
do {
try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true)
@@ -338,7 +339,8 @@ extension HubApi {
}
}
- public struct HubFileDownloader {
+ struct HubFileDownloader {
+ let hub: HubApi
let repo: Repo
let revision: String
let repoDestination: URL
@@ -386,24 +388,24 @@ extension HubApi {
/// (See for example PipelineLoader in swift-coreml-diffusers)
@discardableResult
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
- let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination)
- let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source)
+ let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination)
+ let remoteMetadata = try await hub.getFileMetadata(url: source)
let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""
// Local file exists + metadata exists + commit_hash matches => return file
- if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern), downloaded, localMetadata != nil,
- localCommitHash == remoteCommitHash
+ if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil,
+ localCommitHash == remoteCommitHash
{
return destination
}
// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
- let remoteEtag = remoteMetadata.etag,
- let remoteSize = remoteMetadata.size,
- remoteMetadata.location != ""
+ let remoteEtag = remoteMetadata.etag,
+ let remoteSize = remoteMetadata.size,
+ remoteMetadata.location != ""
else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}
@@ -412,7 +414,7 @@ extension HubApi {
if downloaded {
// etag matches => update metadata and return file
if localMetadata?.etag == remoteEtag {
- try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+ try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
@@ -420,10 +422,10 @@ extension HubApi {
// => means it's an LFS file (large)
// => let's compute local hash and compare
// => if match, update metadata and return file
- if HubApi.shared.isValidHash(hash: remoteEtag, pattern: HubApi.shared.sha256Pattern) {
- let fileHash = try HubApi.shared.computeFileHash(file: destination)
+ if hub.isValidHash(hash: remoteEtag, pattern: hub.sha256Pattern) {
+ let fileHash = try hub.computeFileHash(file: destination)
if fileHash == remoteEtag {
- try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+ try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
}
@@ -441,9 +443,9 @@ extension HubApi {
switch state {
case .notStarted:
continue
- case .downloading(let progress):
+ case let .downloading(progress):
progressHandler(progress)
- case .failed(let error):
+ case let .failed(error):
throw error
case .completed:
break listen
@@ -455,22 +457,22 @@ extension HubApi {
}
}
- try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
+ try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
}
@discardableResult
- public func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in })
+ func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in })
async throws -> URL
{
let repoDestination = localRepoLocation(repo)
let repoMetadataDestination =
repoDestination
- .appendingPathComponent(".cache")
- .appendingPathComponent("huggingface")
- .appendingPathComponent("download")
+ .appendingPathComponent(".cache")
+ .appendingPathComponent("huggingface")
+ .appendingPathComponent("download")
if await NetworkMonitor.shared.state.shouldUseOfflineMode() || useOfflineMode == true {
if !FileManager.default.fileExists(atPath: repoDestination.path) {
@@ -514,6 +516,7 @@ extension HubApi {
for filename in filenames {
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
let downloader = HubFileDownloader(
+ hub: self,
repo: repo,
revision: revision,
repoDestination: repoDestination,
@@ -540,25 +543,25 @@ extension HubApi {
}
@discardableResult
- public func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
+ func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
}
@discardableResult
- public func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await snapshot(from: repo, matching: [glob], progressHandler: progressHandler)
+ func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
}
@discardableResult
- public func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await snapshot(from: Repo(id: repoId), matching: [glob], progressHandler: progressHandler)
+ func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
}
}
/// Metadata
-extension HubApi {
+public extension HubApi {
/// Data structure containing information about a file versioned on the Hub
- public struct FileMetadata {
+ struct FileMetadata {
/// The commit hash related to the file
public let commitHash: String?
@@ -573,7 +576,7 @@ extension HubApi {
}
/// Metadata about a file in the local directory related to a download process
- public struct LocalDownloadFileMetadata {
+ struct LocalDownloadFileMetadata {
/// Commit hash of the file in the repo
public let commitHash: String
@@ -607,7 +610,7 @@ extension HubApi {
)
}
- public func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
+ func getFileMetadata(from repo: Repo, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
let files = try await getFilenames(from: repo, matching: globs)
let url = URL(string: "\(endpoint)/\(repo.id)/resolve/\(revision)")!
var selectedMetadata: [FileMetadata] = []
@@ -618,15 +621,15 @@ extension HubApi {
return selectedMetadata
}
- public func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
+ func getFileMetadata(from repoId: String, revision: String = "main", matching globs: [String] = []) async throws -> [FileMetadata] {
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: globs)
}
- public func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
+ func getFileMetadata(from repo: Repo, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
try await getFileMetadata(from: repo, revision: revision, matching: [glob])
}
- public func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
+ func getFileMetadata(from repoId: String, revision: String = "main", matching glob: String) async throws -> [FileMetadata] {
try await getFileMetadata(from: Repo(id: repoId), revision: revision, matching: [glob])
}
}
@@ -639,9 +642,9 @@ extension HubApi {
public var isConstrained: Bool = false
func update(path: NWPath) {
- self.isConnected = path.status == .satisfied
- self.isExpensive = path.isExpensive
- self.isConstrained = path.isConstrained
+ isConnected = path.status == .satisfied
+ isExpensive = path.isExpensive
+ isConstrained = path.isConstrained
}
func shouldUseOfflineMode() -> Bool {
@@ -668,13 +671,13 @@ extension HubApi {
func startMonitoring() {
monitor.pathUpdateHandler = { [weak self] path in
- guard let self = self else { return }
+ guard let self else { return }
Task {
await self.state.update(path: path)
}
}
- monitor.start(queue: self.queue)
+ monitor.start(queue: queue)
}
func stopMonitoring() {
@@ -688,74 +691,74 @@ extension HubApi {
}
/// Stateless wrappers that use `HubApi` instances
-extension Hub {
- public static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
- return try await HubApi.shared.getFilenames(from: repo, matching: globs)
+public extension Hub {
+ static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
+ try await HubApi.shared.getFilenames(from: repo, matching: globs)
}
- public static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
- return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
+ static func getFilenames(from repoId: String, matching globs: [String] = []) async throws -> [String] {
+ try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: globs)
}
- public static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
- return try await HubApi.shared.getFilenames(from: repo, matching: glob)
+ static func getFilenames(from repo: Repo, matching glob: String) async throws -> [String] {
+ try await HubApi.shared.getFilenames(from: repo, matching: glob)
}
- public static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
- return try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
+ static func getFilenames(from repoId: String, matching glob: String) async throws -> [String] {
+ try await HubApi.shared.getFilenames(from: Repo(id: repoId), matching: glob)
}
- public static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
+ static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
}
- public static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws
+ static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws
-> URL
{
- return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
+ try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
}
- public static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
+ static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
}
- public static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
- return try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
+ static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
+ try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
}
- public static func whoami(token: String) async throws -> Config {
- return try await HubApi(hfToken: token).whoami()
+ static func whoami(token: String) async throws -> Config {
+ try await HubApi(hfToken: token).whoami()
}
- public static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
- return try await HubApi.shared.getFileMetadata(url: fileURL)
+ static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
+ try await HubApi.shared.getFileMetadata(url: fileURL)
}
- public static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
- return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
+ static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+ try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
}
- public static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
- return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
+ static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+ try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
}
- public static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
- return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
+ static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
+ try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
}
- public static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
- return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
+ static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
+ try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
}
}
-extension [String] {
- public func matching(glob: String) -> [String] {
+public extension [String] {
+ func matching(glob: String) -> [String] {
filter { fnmatch(glob, $0, 0) == 0 }
}
}
-extension FileManager {
- public func getFileUrls(at directoryUrl: URL) throws -> [URL] {
+public extension FileManager {
+ func getFileUrls(at directoryUrl: URL) throws -> [URL] {
var fileUrls = [URL]()
// Get all contents including subdirectories
@@ -798,14 +801,13 @@ private final class RedirectDelegate: NSObject, URLSessionTaskDelegate, Sendable
if (300...399).contains(response.statusCode) {
// Get the Location header
if let locationString = response.value(forHTTPHeaderField: "Location"),
- let locationUrl = URL(string: locationString)
+ let locationUrl = URL(string: locationString)
{
-
// Check if it's a relative redirect (no host component)
if locationUrl.host == nil {
// For relative redirects, construct the new URL using the original request's base
if let originalUrl = task.originalRequest?.url,
- var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
+ var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
{
// Update the path component with the relative path
components.path = locationUrl.path
diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift
index 09e779c..d1533d2 100644
--- a/Tests/HubTests/DownloaderTests.swift
+++ b/Tests/HubTests/DownloaderTests.swift
@@ -19,20 +19,19 @@ enum DownloadError: LocalizedError {
var errorDescription: String? {
switch self {
case .invalidDownloadLocation:
- return String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
+ String(localized: "The download location is invalid or inaccessible.", comment: "Error when download destination is invalid")
case .unexpectedError:
- return String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
+ String(localized: "An unexpected error occurred during the download process.", comment: "Generic download error message")
}
}
}
private extension Downloader {
func interruptDownload() async {
- await self.session.get()?.invalidateAndCancel()
+ await session.get()?.invalidateAndCancel()
}
}
-
final class DownloaderTests: XCTestCase {
var tempDir: URL!
@@ -56,18 +55,18 @@ final class DownloaderTests: XCTestCase {
let etag = try await Hub.getFileMetadata(fileURL: url).etag!
let destination = tempDir.appendingPathComponent("config.json")
let fileContent = """
- {
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
+ {
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
- """
+ """
let cacheDir = tempDir.appendingPathComponent("cache")
try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true)
@@ -82,9 +81,9 @@ final class DownloaderTests: XCTestCase {
switch state {
case .notStarted:
continue
- case .downloading(let progress):
+ case .downloading:
continue
- case .failed(let error):
+ case let .failed(error):
throw error
case .completed:
break listen
@@ -110,12 +109,12 @@ final class DownloaderTests: XCTestCase {
let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination)
// Download with incorrect expected size
- let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size
+ let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size
listen: for await state in sub {
switch state {
case .notStarted:
continue
- case .downloading(let progress):
+ case .downloading:
continue
case .failed:
break listen
@@ -149,7 +148,7 @@ final class DownloaderTests: XCTestCase {
FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil)
let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination)
- let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification
+ let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification
// First interruption point at 50%
var threshold = 0.5
@@ -161,19 +160,19 @@ final class DownloaderTests: XCTestCase {
switch state {
case .notStarted:
continue
- case .downloading(let progress):
- if threshold != 1.0 && progress >= threshold {
+ case let .downloading(progress):
+ if threshold != 1.0, progress >= threshold {
// Move to next threshold and interrupt
threshold = threshold == 0.5 ? 0.75 : 1.0
await downloader.interruptDownload()
}
- case .failed(let error):
+ case let .failed(error):
throw error
case .completed:
break listen
}
}
-
+
// Verify the file exists and is complete
if FileManager.default.fileExists(atPath: destination.path) {
let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift
index 5a53452..61d4968 100644
--- a/Tests/HubTests/HubApiTests.swift
+++ b/Tests/HubTests/HubApiTests.swift
@@ -157,7 +157,7 @@ class HubApiTests: XCTestCase {
let url = URL(
string:
- "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"
+ "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"
)
let metadata = try await Hub.getFileMetadata(fileURL: url!)
@@ -266,7 +266,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedFilenames),
Set([
- "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json"
+ "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
])
)
}
@@ -415,7 +415,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedMetadataFilenames),
Set([
- ".cache/huggingface/download/tokenizer.json.metadata"
+ ".cache/huggingface/download/tokenizer.json.metadata",
])
)
@@ -544,7 +544,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(
Set(downloadedMetadataFilenames),
Set([
- ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata"
+ ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata",
])
)
@@ -1009,17 +1009,17 @@ class SnapshotDownloadTests: XCTestCase {
let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
let expected = """
- {
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
- """
+ {
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
+ """
XCTAssertTrue(fileContents.contains(expected))
}
@@ -1048,17 +1048,17 @@ class SnapshotDownloadTests: XCTestCase {
let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
let expected = """
- X
- "architectures": [
- "LlamaForCausalLM"
- ],
- "bos_token_id": 1,
- "eos_token_id": 2,
- "model_type": "llama",
- "pad_token_id": 0,
- "vocab_size": 32000
- }
- """
+ X
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "model_type": "llama",
+ "pad_token_id": 0,
+ "vocab_size": 32000
+ }
+ """
XCTAssertTrue(fileContents.contains(expected))
}
@@ -1100,4 +1100,45 @@ class SnapshotDownloadTests: XCTestCase {
"Downloaded file should exist at \(filePath.path)"
)
}
+
+ func testDownloadWithRevision() async throws {
+ let hubApi = HubApi(downloadBase: downloadDestination)
+ var lastProgress: Progress? = nil
+
+ let commitHash = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb"
+ let downloadedTo = try await hubApi.snapshot(from: repo, revision: commitHash, matching: "*.json") { progress in
+ print("Total Progress: \(progress.fractionCompleted)")
+ print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
+ lastProgress = progress
+ }
+
+ let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo)
+ XCTAssertEqual(lastProgress?.fractionCompleted, 1)
+ XCTAssertEqual(lastProgress?.completedUnitCount, 6)
+ XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
+ XCTAssertEqual(
+ Set(downloadedFilenames),
+ Set([
+ "config.json", "tokenizer.json", "tokenizer_config.json",
+ "llama-2-7b-chat.mlpackage/Manifest.json",
+ "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
+ "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
+ ])
+ )
+
+ do {
+ let revision = "nonexistent-revision"
+ try await hubApi.snapshot(from: repo, revision: revision, matching: "*.json")
+ XCTFail("Expected an error to be thrown")
+ } catch let error as Hub.HubClientError {
+ switch error {
+ case .fileNotFound:
+ break // Error type is correct
+ default:
+ XCTFail("Wrong error type: \(error)")
+ }
+ } catch {
+ XCTFail("Unexpected error: \(error)")
+ }
+ }
}