Skip to content

Commit 477e10a

Browse files
committed
add test cases and revise temp file download logic
1 parent abf5b16 commit 477e10a

File tree

2 files changed

+278
-2
lines changed

2 files changed

+278
-2
lines changed

Sources/Hub/Downloader.swift

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,134 @@ class Downloader: NSObject, ObservableObject {
7171
}
7272
}
7373
var request = URLRequest(url: url)
74+
75+
// Use headers from argument else create an empty header dictionary
76+
var requestHeaders = headers ?? [:]
77+
78+
// Populate header auth and range fields
7479
if let authToken = authToken {
75-
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
80+
requestHeaders["Authorization"] = "Bearer \(authToken)"
7681
}
82+
if resumeSize > 0 {
83+
requestHeaders["Range"] = "bytes=\(resumeSize)-"
84+
}
85+
86+
87+
request.timeoutInterval = timeout
88+
request.allHTTPHeaderFields = requestHeaders
7789

78-
self.urlSession?.downloadTask(with: request).resume()
90+
Task {
91+
do {
92+
// Create a temp file to write
93+
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
94+
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
95+
let tempFile = try FileHandle(forWritingTo: tempURL)
96+
97+
defer { tempFile.closeFile() }
98+
try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
99+
100+
// Clean up and move the completed download to its final destination
101+
tempFile.closeFile()
102+
try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)
103+
104+
self.downloadState.value = .completed(self.destination)
105+
} catch {
106+
self.downloadState.value = .failed(error)
107+
}
108+
}
109+
}
110+
}
111+
112+
/// Downloads a file from given URL using chunked transfer and handles retries.
113+
///
114+
/// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
115+
///
116+
/// - Parameters:
117+
/// - request: The URLRequest for the file to download
118+
/// - resumeSize: The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position
119+
/// - numRetries: The number of retry attempts remaining for failed downloads
120+
/// - expectedSize: The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one.
121+
/// - Throws: `DownloadError.unexpectedError` if the response is invalid or file size mismatch occurs
122+
/// `URLError` if the download fails after all retries are exhausted
123+
private func httpGet(
124+
request: URLRequest,
125+
tempFile: FileHandle,
126+
resumeSize: Int,
127+
numRetries: Int,
128+
expectedSize: Int?
129+
) async throws {
130+
guard let session = self.urlSession else {
131+
throw DownloadError.unexpectedError
132+
}
133+
134+
// Create a new request with Range header for resuming
135+
var newRequest = request
136+
if resumeSize > 0 {
137+
newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
138+
}
139+
140+
// Start the download and get the byte stream
141+
let (asyncBytes, response) = try await session.bytes(for: newRequest)
142+
143+
guard let response = response as? HTTPURLResponse else {
144+
throw DownloadError.unexpectedError
145+
}
146+
147+
guard (200..<300).contains(response.statusCode) else {
148+
throw DownloadError.unexpectedError
149+
}
150+
151+
var downloadedSize = resumeSize
152+
153+
// Create a buffer to collect bytes before writing to disk
154+
var buffer = Data(capacity: chunkSize)
155+
156+
var newNumRetries = numRetries
157+
do {
158+
for try await byte in asyncBytes {
159+
buffer.append(byte)
160+
// When buffer is full, write to disk
161+
if buffer.count == chunkSize {
162+
if !buffer.isEmpty { // Filter out keep-alive chunks
163+
try tempFile.write(contentsOf: buffer)
164+
buffer.removeAll(keepingCapacity: true)
165+
downloadedSize += chunkSize
166+
newNumRetries = 5
167+
guard let expectedSize = expectedSize else { continue }
168+
let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
169+
downloadState.value = .downloading(progress)
170+
}
171+
}
172+
}
173+
174+
if !buffer.isEmpty {
175+
try tempFile.write(contentsOf: buffer)
176+
downloadedSize += buffer.count
177+
buffer.removeAll(keepingCapacity: true)
178+
newNumRetries = 5
179+
}
180+
} catch let error as URLError {
181+
if newNumRetries <= 0 {
182+
throw error
183+
}
184+
try await Task.sleep(nanoseconds: 1_000_000_000)
185+
186+
let config = URLSessionConfiguration.default
187+
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
188+
189+
try await httpGet(
190+
request: request,
191+
tempFile: tempFile,
192+
resumeSize: downloadedSize,
193+
numRetries: newNumRetries - 1,
194+
expectedSize: expectedSize
195+
)
196+
}
197+
198+
// Verify the downloaded file size matches the expected size
199+
let actualSize = try tempFile.seekToEnd()
200+
if let expectedSize = expectedSize, expectedSize != actualSize {
201+
throw DownloadError.unexpectedError
79202
}
80203
}
81204

Tests/HubTests/DownloaderTests.swift

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//
2+
// DownloaderTests.swift
3+
// swift-transformers
4+
//
5+
// Created by Arda Atahan Ibis on 1/28/25.
6+
//
7+
8+
import XCTest
9+
@testable import Hub
10+
import Combine
11+
enum DownloadError: Error {
12+
case invalidDownloadLocation
13+
case unexpectedError
14+
}
15+
final class DownloaderTests: XCTestCase {
16+
var tempDir: URL!
17+
18+
override func setUp() {
19+
super.setUp()
20+
tempDir = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
21+
try? FileManager.default.createDirectory(at: tempDir, withIntermediateDirectories: true)
22+
}
23+
24+
override func tearDown() {
25+
try? FileManager.default.removeItem(at: tempDir)
26+
super.tearDown()
27+
}
28+
29+
func testSuccessfulDownload() async throws {
30+
// Create a test file
31+
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
32+
let destination = tempDir.appendingPathComponent("config.json")
33+
let fileContent = """
34+
{
35+
"architectures": [
36+
"LlamaForCausalLM"
37+
],
38+
"bos_token_id": 1,
39+
"eos_token_id": 2,
40+
"model_type": "llama",
41+
"pad_token_id": 0,
42+
"vocab_size": 32000
43+
}
44+
45+
"""
46+
47+
let downloader = Downloader(
48+
from: url,
49+
to: destination,
50+
metadataDirURL: tempDir
51+
)
52+
53+
// Store subscriber outside the continuation to maintain its lifecycle
54+
var subscriber: AnyCancellable?
55+
56+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
57+
subscriber = downloader.downloadState.sink { state in
58+
switch state {
59+
case .completed:
60+
continuation.resume()
61+
case .failed(let error):
62+
continuation.resume(throwing: error)
63+
case .downloading:
64+
break
65+
case .notStarted:
66+
break
67+
}
68+
}
69+
}
70+
71+
// Cancel subscription after continuation completes
72+
subscriber?.cancel()
73+
74+
// Verify download completed successfully
75+
XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path))
76+
XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent)
77+
}
78+
79+
func testDownloadFailsWithIncorrectSize() async throws {
80+
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
81+
let destination = tempDir.appendingPathComponent("config.json")
82+
83+
// Create downloader with incorrect expected size
84+
let downloader = Downloader(
85+
from: url,
86+
to: destination,
87+
metadataDirURL: tempDir,
88+
expectedSize: 999999 // Incorrect size
89+
)
90+
91+
do {
92+
try downloader.waitUntilDone()
93+
XCTFail("Download should have failed due to size mismatch")
94+
} catch {
95+
96+
}
97+
98+
// Verify no file was created at destination
99+
XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
100+
}
101+
102+
func testSuccessfulInterruptedDownload() async throws {
103+
let url = URL(string: "https://huggingface.co/coreml-projects/sam-2-studio/resolve/main/SAM%202%20Studio%201.1.zip")!
104+
let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip")
105+
106+
// Create parent directory if it doesn't exist
107+
try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(),
108+
withIntermediateDirectories: true)
109+
110+
let downloader = Downloader(
111+
from: url,
112+
to: destination,
113+
metadataDirURL: tempDir,
114+
expectedSize: 73194001
115+
)
116+
var threshold = 0.5
117+
118+
var subscriber: AnyCancellable?
119+
120+
do {
121+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
122+
subscriber = downloader.downloadState.sink { state in
123+
switch state {
124+
case .downloading(let progress):
125+
if threshold != 1.0 && progress >= threshold {
126+
threshold = threshold == 0.5 ? 0.75 : 1.0
127+
downloader.cancel()
128+
}
129+
case .completed:
130+
continuation.resume()
131+
case .failed(let error):
132+
continuation.resume(throwing: error)
133+
case .notStarted:
134+
break
135+
}
136+
}
137+
}
138+
139+
subscriber?.cancel()
140+
141+
// Verify the file exists and is complete
142+
if FileManager.default.fileExists(atPath: destination.path) {
143+
let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
144+
let finalSize = attributes[.size] as! Int64
145+
XCTAssertGreaterThan(finalSize, 0, "File should not be empty")
146+
} else {
147+
XCTFail("File was not created at destination")
148+
}
149+
} catch {
150+
throw error
151+
}
152+
}
153+
}

0 commit comments

Comments
 (0)