diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index f52c596..bbdb2f4 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -24,25 +24,39 @@ class Downloader: NSObject, ObservableObject { enum DownloadError: Error { case invalidDownloadLocation case unexpectedError + case tempFileNotFound } private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted) private var stateSubscriber: Cancellable? + + private(set) var tempFilePath: URL + private(set) var expectedSize: Int? + private(set) var downloadedSize: Int = 0 - private var urlSession: URLSession? = nil - + var session: URLSession? = nil + var downloadTask: Task? = nil + init( from url: URL, to destination: URL, + incompleteDestination: URL, using authToken: String? = nil, inBackground: Bool = false, - resumeSize: Int = 0, headers: [String: String]? = nil, expectedSize: Int? = nil, timeout: TimeInterval = 10, numRetries: Int = 5 ) { self.destination = destination + self.expectedSize = expectedSize + + // Create incomplete file path based on destination + tempFilePath = incompleteDestination + + // If resume size wasn't specified, check for an existing incomplete file + let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + super.init() let sessionIdentifier = "swift-transformers.hub.downloader" @@ -53,9 +67,22 @@ class Downloader: NSObject, ObservableObject { config.sessionSendsLaunchEvents = true } - urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil) + session = URLSession(configuration: config, delegate: self, delegateQueue: nil) - setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries) + setUpDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries) + } + + /// Check if an incomplete file exists for 
the destination and returns its size + /// - Parameter destination: The destination URL for the download + /// - Returns: Size of the incomplete file if it exists, otherwise 0 + static func incompleteFileSize(at incompletePath: URL) -> Int { + if FileManager.default.fileExists(atPath: incompletePath.path) { + if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int { + return fileSize + } + } + + return 0 } /// Sets up and initiates a file download operation @@ -68,7 +95,7 @@ class Downloader: NSObject, ObservableObject { /// - expectedSize: Expected file size in bytes for validation /// - timeout: Time interval before the request times out /// - numRetries: Number of retry attempts for failed downloads - private func setupDownload( + private func setUpDownload( from url: URL, with authToken: String?, resumeSize: Int, @@ -77,59 +104,67 @@ class Downloader: NSObject, ObservableObject { timeout: TimeInterval, numRetries: Int ) { - downloadState.value = .downloading(0) - urlSession?.getAllTasks { tasks in + session?.getAllTasks { tasks in // If there's an existing pending background task with the same URL, let it proceed. 
if let existing = tasks.filter({ $0.originalRequest?.url == url }).first { switch existing.state { case .running: - // print("Already downloading \(url)") return case .suspended: - // print("Resuming suspended download task for \(url)") existing.resume() return - case .canceling: - // print("Starting new download task for \(url), previous was canceling") - break - case .completed: - // print("Starting new download task for \(url), previous is complete but the file is no longer present (I think it's cached)") - break + case .canceling, .completed: + existing.cancel() @unknown default: - // print("Unknown state for running task; cancelling and creating a new one") existing.cancel() } } - var request = URLRequest(url: url) - // Use headers from argument else create an empty header dictionary - var requestHeaders = headers ?? [:] - - // Populate header auth and range fields - if let authToken { - requestHeaders["Authorization"] = "Bearer \(authToken)" - } - if resumeSize > 0 { - requestHeaders["Range"] = "bytes=\(resumeSize)-" - } - - request.timeoutInterval = timeout - request.allHTTPHeaderFields = requestHeaders - - Task { + self.downloadTask = Task { do { - // Create a temp file to write - let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) - FileManager.default.createFile(atPath: tempURL.path, contents: nil) - let tempFile = try FileHandle(forWritingTo: tempURL) + // Set up the request with appropriate headers + var request = URLRequest(url: url) + var requestHeaders = headers ?? 
[:] + + if let authToken { + requestHeaders["Authorization"] = "Bearer \(authToken)" + } + + self.downloadedSize = resumeSize + + // Set Range header if we're resuming + if resumeSize > 0 { + requestHeaders["Range"] = "bytes=\(resumeSize)-" + + // Calculate and show initial progress + if let expectedSize, expectedSize > 0 { + let initialProgress = Double(resumeSize) / Double(expectedSize) + self.downloadState.value = .downloading(initialProgress) + } else { + self.downloadState.value = .downloading(0) + } + } else { + self.downloadState.value = .downloading(0) + } - defer { tempFile.closeFile() } - try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize) + request.timeoutInterval = timeout + request.allHTTPHeaderFields = requestHeaders + + // Open the incomplete file for writing + let tempFile = try FileHandle(forWritingTo: self.tempFilePath) + + // If resuming, seek to end of file + if resumeSize > 0 { + try tempFile.seekToEnd() + } + + try await self.httpGet(request: request, tempFile: tempFile, resumeSize: self.downloadedSize, numRetries: numRetries, expectedSize: expectedSize) // Clean up and move the completed download to its final destination tempFile.closeFile() - try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination) + try Task.checkCancellation() + try FileManager.default.moveDownloadedFile(from: self.tempFilePath, to: self.destination) self.downloadState.value = .completed(self.destination) } catch { self.downloadState.value = .failed(error) @@ -156,7 +191,7 @@ class Downloader: NSObject, ObservableObject { numRetries: Int, expectedSize: Int? 
) async throws { - guard let session = urlSession else { + guard let session else { throw DownloadError.unexpectedError } @@ -169,16 +204,13 @@ class Downloader: NSObject, ObservableObject { // Start the download and get the byte stream let (asyncBytes, response) = try await session.bytes(for: newRequest) - guard let response = response as? HTTPURLResponse else { + guard let httpResponse = response as? HTTPURLResponse else { throw DownloadError.unexpectedError } - - guard (200..<300).contains(response.statusCode) else { + guard (200..<300).contains(httpResponse.statusCode) else { throw DownloadError.unexpectedError } - var downloadedSize = resumeSize - // Create a buffer to collect bytes before writing to disk var buffer = Data(capacity: chunkSize) @@ -213,12 +245,12 @@ class Downloader: NSObject, ObservableObject { try await Task.sleep(nanoseconds: 1_000_000_000) let config = URLSessionConfiguration.default - self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil) + self.session = URLSession(configuration: config, delegate: self, delegateQueue: nil) try await httpGet( request: request, tempFile: tempFile, - resumeSize: downloadedSize, + resumeSize: self.downloadedSize, numRetries: newNumRetries - 1, expectedSize: expectedSize ) @@ -252,7 +284,9 @@ class Downloader: NSObject, ObservableObject { } func cancel() { - urlSession?.invalidateAndCancel() + session?.invalidateAndCancel() + downloadTask?.cancel() + downloadState.value = .failed(URLError(.cancelled)) } } @@ -284,9 +318,13 @@ extension Downloader: URLSessionDownloadDelegate { extension FileManager { func moveDownloadedFile(from srcURL: URL, to dstURL: URL) throws { - if fileExists(atPath: dstURL.path) { + if fileExists(atPath: dstURL.path()) { try removeItem(at: dstURL) } + + let directoryURL = dstURL.deletingLastPathComponent() + try createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) + try moveItem(at: srcURL, to: dstURL) } } diff --git 
a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index fe8f461..0634b69 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -51,13 +51,13 @@ public extension Hub { } } - enum RepoType: String { + enum RepoType: String, Codable { case models case datasets case spaces } - - struct Repo { + + struct Repo: Codable { public let id: String public let type: RepoType diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 368b805..a9d5020 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -366,14 +366,13 @@ public extension HubApi { FileManager.default.fileExists(atPath: destination.path) } - func prepareDestination() throws { - let directoryURL = destination.deletingLastPathComponent() - try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) - } - - func prepareMetadataDestination() throws { - let directoryURL = metadataDestination.deletingLastPathComponent() + /// We're using incomplete destination to prepare cache destination because incomplete files include lfs + non-lfs files (vs only lfs for metadata files) + func prepareCacheDestination(_ incompleteDestination: URL) throws { + let directoryURL = incompleteDestination.deletingLastPathComponent() try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) + if !FileManager.default.fileExists(atPath: incompleteDestination.path) { + try "".write(to: incompleteDestination, atomically: true, encoding: .utf8) + } } /// Note we go from Combine in Downloader to callback-based progress reporting @@ -423,22 +422,42 @@ public extension HubApi { } // Otherwise, let's download the file! 
- try prepareDestination() - try prepareMetadataDestination() - - let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize) - let downloadSubscriber = downloader.downloadState.sink { state in - if case let .downloading(progress) = state { - progressHandler(progress) + let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete") + try prepareCacheDestination(incompleteDestination) + + let downloader = Downloader( + from: source, + to: destination, + incompleteDestination: incompleteDestination, + using: hfToken, + inBackground: backgroundSession, + expectedSize: remoteSize + ) + + return try await withTaskCancellationHandler { + let downloadSubscriber = downloader.downloadState.sink { state in + switch state { + case let .downloading(progress): + progressHandler(progress) + case .completed, .failed, .notStarted: + break + } } + do { + _ = try withExtendedLifetime(downloadSubscriber) { + try downloader.waitUntilDone() + } + + try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + + return destination + } catch { + // If download fails, leave the incomplete file in place for future resume + throw error + } + } onCancel: { + downloader.cancel() } - _ = try withExtendedLifetime(downloadSubscriber) { - try downloader.waitUntilDone() - } - - try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) - - return destination } } diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift index 2c6ed02..0f8d106 100644 --- a/Tests/HubTests/DownloaderTests.swift +++ b/Tests/HubTests/DownloaderTests.swift @@ -24,6 +24,12 @@ enum DownloadError: LocalizedError { } } +private extension Downloader { + func interruptDownload() { + session?.invalidateAndCancel() + } +} + final class DownloaderTests: XCTestCase { var 
tempDir: URL! @@ -44,6 +50,7 @@ final class DownloaderTests: XCTestCase { func testSuccessfulDownload() async throws { // Create a test file let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")! + let etag = try await Hub.getFileMetadata(fileURL: url).etag! let destination = tempDir.appendingPathComponent("config.json") let fileContent = """ { @@ -59,9 +66,16 @@ final class DownloaderTests: XCTestCase { """ + let cacheDir = tempDir.appendingPathComponent("cache") + try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) + + let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") + FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) + let downloader = Downloader( from: url, - to: destination + to: destination, + incompleteDestination: incompleteDestination ) // Store subscriber outside the continuation to maintain its lifecycle @@ -93,12 +107,20 @@ final class DownloaderTests: XCTestCase { /// This test attempts to download with incorrect expected file, verifies the download fails, ensures no partial file is left behind func testDownloadFailsWithIncorrectSize() async throws { let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")! + let etag = try await Hub.getFileMetadata(fileURL: url).etag! let destination = tempDir.appendingPathComponent("config.json") + let cacheDir = tempDir.appendingPathComponent("cache") + try? 
FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) + + let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") + FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) + // Create downloader with incorrect expected size let downloader = Downloader( from: url, to: destination, + incompleteDestination: incompleteDestination, expectedSize: 999999 // Incorrect size ) @@ -115,15 +137,23 @@ final class DownloaderTests: XCTestCase { /// verifies the download can resume and complete successfully, checks the final file exists and has content func testSuccessfulInterruptedDownload() async throws { let url = URL(string: "https://huggingface.co/coreml-projects/sam-2-studio/resolve/main/SAM%202%20Studio%201.1.zip")! + let etag = try await Hub.getFileMetadata(fileURL: url).etag! let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip") // Create parent directory if it doesn't exist try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(), withIntermediateDirectories: true) - + + let cacheDir = tempDir.appendingPathComponent("cache") + try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) + + let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") + FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) + let downloader = Downloader( from: url, to: destination, + incompleteDestination: incompleteDestination, expectedSize: 73194001 // Correct size for verification ) @@ -142,7 +172,7 @@ final class DownloaderTests: XCTestCase { if threshold != 1.0, progress >= threshold { // Move to next threshold and interrupt threshold = threshold == 0.5 ? 
0.75 : 1.0 - downloader.cancel() + downloader.interruptDownload() } case .completed: continuation.resume() diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 451816f..6b0d150 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -173,9 +173,9 @@ class SnapshotDownloadTests: XCTestCase { let base = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first! return base.appending(component: "huggingface-tests") }() - + override func setUp() { } - + override func tearDown() { do { try FileManager.default.removeItem(at: downloadDestination) @@ -183,11 +183,11 @@ class SnapshotDownloadTests: XCTestCase { print("Can't remove test download destination \(downloadDestination), error: \(error)") } } - + func getRelativeFiles(url: URL, repo: String) -> [String] { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") - + if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) { for case let fileURL as URL in enumerator { do { @@ -202,7 +202,7 @@ class SnapshotDownloadTests: XCTestCase { } return filenames } - + func testDownload() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil @@ -226,7 +226,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + XCTAssertEqual( Set(downloadedFilenames), Set([ @@ -237,7 +237,7 @@ class SnapshotDownloadTests: XCTestCase { ]) ) } - + /// Background sessions get rate limited by the OS, see discussion here: https://github.com/huggingface/swift-transformers/issues/61 /// Test only one file at a time func testDownloadInBackground() async throws { @@ -251,7 +251,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), @@ -260,7 +260,7 @@ class SnapshotDownloadTests: XCTestCase { ]) ) } - + func testCustomEndpointDownload() async throws { let hubApi = HubApi(downloadBase: downloadDestination, endpoint: "https://hf-mirror.com") var lastProgress: Progress? = nil @@ -272,7 +272,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), @@ -326,7 +326,7 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadFileMetadataExists() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -336,7 +336,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), @@ -375,7 +375,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) let secondDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will not be downloaded again thus last modified date will remain unchanged XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) } @@ -383,7 +383,7 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadFileMetadataSame() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "tokenizer.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -393,7 +393,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual(Set(downloadedFilenames), Set(["tokenizer.json"])) @@ -430,7 +430,7 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadFileMetadataCorrupted() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -440,7 +440,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), @@ -483,7 +483,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) let secondDownloadTimestamp = attributes[.modificationDate] as! 
Date - + // File will be downloaded again thus last modified date will change XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) @@ -499,7 +499,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) let thirdDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will be downloaded again thus last modified date will change XCTAssertTrue(originalTimestamp != thirdDownloadTimestamp) } @@ -507,7 +507,7 @@ class SnapshotDownloadTests: XCTestCase { func testDownloadLargeFileMetadataCorrupted() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -517,7 +517,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), @@ -552,7 +552,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: modelPath.path) let thirdDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will not be downloaded again because this is an LFS file. // While downloading LFS files, we first check if local file ETag is the same as remote ETag. // If that's the case we just update the metadata and keep the local file. 
@@ -577,7 +577,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual(Set(downloadedFilenames), Set(["llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"])) @@ -590,7 +590,7 @@ class SnapshotDownloadTests: XCTestCase { let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata") let metadataString = try String(contentsOfFile: metadataFile.path) - + let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" XCTAssertTrue(metadataString.contains(expected)) } @@ -607,7 +607,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) @@ -637,7 +637,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) @@ -665,7 +665,7 @@ class SnapshotDownloadTests: XCTestCase { func testLFSFileNoMetadata() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -675,7 +675,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) @@ -702,7 +702,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) let secondDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will not be downloaded again thus last modified date will remain unchanged XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) @@ -716,7 +716,7 @@ class SnapshotDownloadTests: XCTestCase { func testLFSFileCorruptedMetadata() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -726,7 +726,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) @@ -744,7 +744,7 @@ class SnapshotDownloadTests: XCTestCase { let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") try "a".write(to: metadataFile, atomically: true, encoding: .utf8) - + let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -753,7 +753,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) let secondDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will not be downloaded again thus last modified date will remain unchanged XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) @@ -767,7 +767,7 @@ class SnapshotDownloadTests: XCTestCase { func testNonLFSFileRedownload() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -777,7 +777,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual(Set(downloadedFilenames), Set(["config.json"])) @@ -804,7 +804,7 @@ class SnapshotDownloadTests: XCTestCase { attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) let secondDownloadTimestamp = attributes[.modificationDate] as! Date - + // File will be downloaded again thus last modified date will change XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) @@ -818,7 +818,7 @@ class SnapshotDownloadTests: XCTestCase { func testOfflineModeReturnsDestination() async throws { var hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + var downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -831,7 +831,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) - + downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -845,7 +845,7 @@ class SnapshotDownloadTests: XCTestCase { func testOfflineModeThrowsError() async throws { let hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) - + do { try await hubApi.snapshot(from: repo, matching: "*.json") XCTFail("Expected an error to be thrown") @@ -901,7 +901,7 @@ class SnapshotDownloadTests: XCTestCase { func testOfflineModeWithCorruptedLFSMetadata() async throws { var hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil - + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "*") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -918,7 +918,7 @@ class SnapshotDownloadTests: XCTestCase { try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(to: metadataDestination, atomically: true, encoding: .utf8) hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) - + do { try await hubApi.snapshot(from: lfsRepo, matching: "*") XCTFail("Expected an error to be thrown") @@ -937,7 +937,7 @@ class SnapshotDownloadTests: XCTestCase { func testOfflineModeWithNoFiles() async throws { var hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") @@ -953,7 +953,7 @@ class SnapshotDownloadTests: XCTestCase { try FileManager.default.removeItem(at: fileDestination) hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) - + do { try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") XCTFail("Expected an error to be thrown") @@ -968,4 +968,117 @@ class SnapshotDownloadTests: XCTestCase { XCTFail("Unexpected error: \(error)") } } + + func testResumeDownloadFromEmptyIncomplete() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml") + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")! + let etag = try await Hub.getFileMetadata(fileURL: url).etag! + + try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) + try "".write(to: metadataDestination.appendingPathComponent("config.json.\(etag).incomplete"), atomically: true, encoding: .utf8) + downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + + let expected = """ + { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ + XCTAssertTrue(fileContents.contains(expected)) + } + + func testResumeDownloadFromNonEmptyIncomplete() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + var downloadedTo = FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent("Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml") + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")! + let etag = try await Hub.getFileMetadata(fileURL: url).etag! + + try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) + try "X".write(to: metadataDestination.appendingPathComponent("config.json.\(etag).incomplete"), atomically: true, encoding: .utf8) + downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + + let expected = """ + X + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "model_type": "llama", + "pad_token_id": 0, + "vocab_size": 32000 + } + """ + XCTAssertTrue(fileContents.contains(expected)) + } + + func testRealDownloadInterruptionAndResumption() async throws { + // Use the SAM 2 Studio zip archive as the download target + let targetFile = "SAM 2 Studio 1.1.zip" + let repo = "coreml-projects/sam-2-studio" + let hubApi = HubApi(downloadBase: downloadDestination) + + // Create expectation for first progress update + let progressExpectation = expectation(description: "First progress update received") + + // Create a task for the download + let downloadTask = Task { + try await 
hubApi.snapshot(from: repo, matching: targetFile) { progress in + print("Progress reached 1 \(progress.fractionCompleted * 100)%") + if progress.fractionCompleted > 0 { + progressExpectation.fulfill() + } + } + } + + // Wait for the first progress update + await fulfillment(of: [progressExpectation], timeout: 30.0) + + // Cancel the download once we've seen progress + downloadTask.cancel() + try await Task.sleep(nanoseconds: 5_000_000_000) + + // Resume download with a new task + let downloadedTo = try await hubApi.snapshot(from: repo, matching: targetFile) { progress in + print("Progress reached 2 \(progress.fractionCompleted * 100)%") + } + + let filePath = downloadedTo.appendingPathComponent(targetFile) + XCTAssertTrue(FileManager.default.fileExists(atPath: filePath.path), + "Downloaded file should exist at \(filePath.path)") + } }