Skip to content

Commit 4f97f98

Browse files
authored
Merge pull request #175 from argmaxinc/download
Add resumable download support with tests
2 parents abf5b16 + a808140 commit 4f97f98

File tree

4 files changed

+324
-8
lines changed

4 files changed

+324
-8
lines changed

Sources/Hub/Downloader.swift

Lines changed: 158 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import Combine
1212
class Downloader: NSObject, ObservableObject {
1313
private(set) var destination: URL
1414

15+
private let chunkSize = 10 * 1024 * 1024 // 10MB
16+
1517
enum DownloadState {
1618
case notStarted
1719
case downloading(Double)
@@ -29,7 +31,17 @@ class Downloader: NSObject, ObservableObject {
2931

3032
private var urlSession: URLSession? = nil
3133

32-
init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
34+
init(
35+
from url: URL,
36+
to destination: URL,
37+
using authToken: String? = nil,
38+
inBackground: Bool = false,
39+
resumeSize: Int = 0,
40+
headers: [String: String]? = nil,
41+
expectedSize: Int? = nil,
42+
timeout: TimeInterval = 10,
43+
numRetries: Int = 5
44+
) {
3345
self.destination = destination
3446
super.init()
3547
let sessionIdentifier = "swift-transformers.hub.downloader"
@@ -43,10 +55,28 @@ class Downloader: NSObject, ObservableObject {
4355

4456
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
4557

46-
setupDownload(from: url, with: authToken)
58+
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
4759
}
4860

49-
private func setupDownload(from url: URL, with authToken: String?) {
61+
/// Sets up and initiates a file download operation
62+
///
63+
/// - Parameters:
64+
/// - url: Source URL to download from
65+
/// - authToken: Bearer token for authentication with Hugging Face
66+
/// - resumeSize: Number of bytes already downloaded for resuming interrupted downloads
67+
/// - headers: Additional HTTP headers to include in the request
68+
/// - expectedSize: Expected file size in bytes for validation
69+
/// - timeout: Time interval before the request times out
70+
/// - numRetries: Number of retry attempts for failed downloads
71+
private func setupDownload(
72+
from url: URL,
73+
with authToken: String?,
74+
resumeSize: Int,
75+
headers: [String: String]?,
76+
expectedSize: Int?,
77+
timeout: TimeInterval,
78+
numRetries: Int
79+
) {
5080
downloadState.value = .downloading(0)
5181
urlSession?.getAllTasks { tasks in
5282
// If there's an existing pending background task with the same URL, let it proceed.
@@ -71,14 +101,137 @@ class Downloader: NSObject, ObservableObject {
71101
}
72102
}
73103
var request = URLRequest(url: url)
104+
105+
// Use headers from argument else create an empty header dictionary
106+
var requestHeaders = headers ?? [:]
107+
108+
// Populate header auth and range fields
74109
if let authToken = authToken {
75-
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
110+
requestHeaders["Authorization"] = "Bearer \(authToken)"
111+
}
112+
if resumeSize > 0 {
113+
requestHeaders["Range"] = "bytes=\(resumeSize)-"
76114
}
115+
116+
117+
request.timeoutInterval = timeout
118+
request.allHTTPHeaderFields = requestHeaders
77119

78-
self.urlSession?.downloadTask(with: request).resume()
120+
Task {
121+
do {
122+
// Create a temp file to write
123+
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
124+
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
125+
let tempFile = try FileHandle(forWritingTo: tempURL)
126+
127+
defer { tempFile.closeFile() }
128+
try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
129+
130+
// Clean up and move the completed download to its final destination
131+
tempFile.closeFile()
132+
try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)
133+
134+
self.downloadState.value = .completed(self.destination)
135+
} catch {
136+
self.downloadState.value = .failed(error)
137+
}
138+
}
79139
}
80140
}
81141

142+
/// Streams the response for `request` into `tempFile`, retrying when the byte stream is interrupted.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
///
/// - Parameters:
///   - request: The URLRequest for the file to download
///   - tempFile: File handle the received bytes are appended to
///   - resumeSize: The number of bytes already downloaded. If set to 0, the whole file is downloaded. If set to a positive number, the download resumes at the given position via a `Range` header
///   - numRetries: The number of retry attempts remaining for failed downloads
///   - expectedSize: The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one.
///   - maxRetries: Value the retry counter is reset to after every successfully written chunk. Defaults to `numRetries` of the outermost call, so the configured retry budget is honoured instead of a fixed constant.
/// - Throws: `DownloadError.unexpectedError` if there is no session, the response is not a 2xx HTTP response, or the final file size does not match `expectedSize`
///           `URLError` if the download fails after all retries are exhausted
private func httpGet(
    request: URLRequest,
    tempFile: FileHandle,
    resumeSize: Int,
    numRetries: Int,
    expectedSize: Int?,
    maxRetries: Int? = nil
) async throws {
    guard let session = self.urlSession else {
        throw DownloadError.unexpectedError
    }

    // On the first call `numRetries` is the configured budget; recursive calls pass it along explicitly.
    let retryBudget = maxRetries ?? numRetries

    // Create a new request with Range header for resuming
    var newRequest = request
    if resumeSize > 0 {
        newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
    }

    // Start the download and get the byte stream
    let (asyncBytes, response) = try await session.bytes(for: newRequest)

    guard let response = response as? HTTPURLResponse else {
        throw DownloadError.unexpectedError
    }

    guard (200..<300).contains(response.statusCode) else {
        throw DownloadError.unexpectedError
    }

    var downloadedSize = resumeSize

    // A server may ignore the Range header and answer with the complete body instead of
    // 206 Partial Content. Appending that body after the bytes already written would
    // corrupt the file, so start over from an empty file in that case.
    if resumeSize > 0 && response.statusCode != 206 {
        try tempFile.truncate(atOffset: 0)
        downloadedSize = 0
    }

    // Create a buffer to collect bytes before writing to disk
    var buffer = Data(capacity: chunkSize)

    var newNumRetries = numRetries
    do {
        for try await byte in asyncBytes {
            buffer.append(byte)

            // Keep collecting until a full chunk is available
            guard buffer.count >= chunkSize else { continue }

            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)

            // A chunk made it to disk, so the connection is healthy again: restore the retry budget
            newNumRetries = retryBudget

            // Progress can only be computed when the total size is known
            if let expectedSize = expectedSize {
                let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
                downloadState.value = .downloading(progress)
            }
        }

        // Flush the trailing partial chunk
        if !buffer.isEmpty {
            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
        }
    } catch let error as URLError {
        if newNumRetries <= 0 {
            throw error
        }
        try await Task.sleep(nanoseconds: 1_000_000_000)

        // NOTE(review): the previous session is replaced without being invalidated, and the
        // replacement always uses the default configuration — confirm this is intended.
        let config = URLSessionConfiguration.default
        self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

        // Resume from the last byte written to disk; bytes still in `buffer` are requested again
        try await httpGet(
            request: request,
            tempFile: tempFile,
            resumeSize: downloadedSize,
            numRetries: newNumRetries - 1,
            expectedSize: expectedSize,
            maxRetries: retryBudget
        )
    }

    // Verify the downloaded file size matches the expected size
    let actualSize = try tempFile.seekToEnd()
    if let expectedSize = expectedSize, expectedSize != actualSize {
        throw DownloadError.unexpectedError
    }
}
234+
82235
@discardableResult
83236
func waitUntilDone() throws -> URL {
84237
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)

Sources/Hub/HubApi.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ public extension HubApi {
367367
// From now on, etag, commit_hash, url and size are not empty
368368
guard let remoteCommitHash = remoteMetadata.commitHash,
369369
let remoteEtag = remoteMetadata.etag,
370+
let remoteSize = remoteMetadata.size,
370371
remoteMetadata.location != "" else {
371372
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
372373
}
@@ -396,7 +397,7 @@ public extension HubApi {
396397
try prepareDestination()
397398
try prepareMetadataDestination()
398399

399-
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
400+
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize)
400401
let downloadSubscriber = downloader.downloadState.sink { state in
401402
if case .downloading(let progress) = state {
402403
progressHandler(progress)

Tests/HubTests/DownloaderTests.swift

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
//
2+
// DownloaderTests.swift
3+
// swift-transformers
4+
//
5+
// Created by Arda Atahan Ibis on 1/28/25.
6+
//
7+
8+
import XCTest
9+
import Combine
10+
@testable import Hub
11+
12+
/// Failure cases surfaced by the download process.
///
/// Both cases carry no payload, so the enum is a plain value type conforming to `Error`.
enum DownloadError: Error {
    case invalidDownloadLocation, unexpectedError
}
17+
18+
/// Integration tests for `Downloader`. They download real files from huggingface.co into a
/// per-test temporary directory.
final class DownloaderTests: XCTestCase {
    /// Scratch directory created in `setUp` and removed in `tearDown`.
    var tempDir: URL!

    override func setUp() {
        super.setUp()
        tempDir = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
        try? FileManager.default.createDirectory(at: tempDir, withIntermediateDirectories: true)
    }

    override func tearDown() {
        try? FileManager.default.removeItem(at: tempDir)
        super.tearDown()
    }

    /// This test downloads a known config file, verifies the download completes, checks the content matches expected value
    func testSuccessfulDownload() async throws {
        // Remote file to fetch and the content it is expected to have
        let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
        let destination = tempDir.appendingPathComponent("config.json")
        // NOTE(review): whitespace inside this literal is significant for the equality check
        // below — confirm the indentation matches the remote file byte-for-byte.
        let fileContent = """
        {
        "architectures": [
        "LlamaForCausalLM"
        ],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "model_type": "llama",
        "pad_token_id": 0,
        "vocab_size": 32000
        }

        """

        let downloader = Downloader(
            from: url,
            to: destination
        )

        // Store subscriber outside the continuation to maintain its lifecycle
        var subscriber: AnyCancellable?

        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            subscriber = downloader.downloadState.sink { state in
                switch state {
                case .completed:
                    continuation.resume()
                case .failed(let error):
                    continuation.resume(throwing: error)
                case .downloading:
                    break
                case .notStarted:
                    break
                }
            }
        }

        // Cancel subscription after continuation completes
        subscriber?.cancel()

        // Verify download completed successfully
        XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path))
        XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent)
    }

    /// This test attempts to download with an incorrect expected file size, verifies the download fails, ensures no partial file is left behind
    func testDownloadFailsWithIncorrectSize() async throws {
        let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
        let destination = tempDir.appendingPathComponent("config.json")

        // Create downloader with incorrect expected size
        let downloader = Downloader(
            from: url,
            to: destination,
            expectedSize: 999999 // Incorrect size
        )

        // The size mismatch must surface as a thrown error rather than being silently swallowed
        XCTAssertThrowsError(try downloader.waitUntilDone(), "Download should have failed due to size mismatch")

        // Verify no file was created at destination
        XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
    }

    /// This test downloads an LFS file, interrupts the download at 50% and 75% progress,
    /// verifies the download can resume and complete successfully, checks the final file exists and has the expected size
    func testSuccessfulInterruptedDownload() async throws {
        let url = URL(string: "https://huggingface.co/coreml-projects/sam-2-studio/resolve/main/SAM%202%20Studio%201.1.zip")!
        let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip")
        let expectedSize = 73194001 // Correct size for verification

        // Create parent directory if it doesn't exist
        try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(),
                                                withIntermediateDirectories: true)

        let downloader = Downloader(
            from: url,
            to: destination,
            expectedSize: expectedSize
        )

        // First interruption point at 50%
        var threshold = 0.5

        var subscriber: AnyCancellable?

        // Monitor download progress and interrupt at thresholds to test if
        // download continues from where it left off
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            subscriber = downloader.downloadState.sink { state in
                switch state {
                case .downloading(let progress):
                    if threshold != 1.0 && progress >= threshold {
                        // Move to next threshold and interrupt
                        threshold = threshold == 0.5 ? 0.75 : 1.0
                        downloader.cancel()
                    }
                case .completed:
                    continuation.resume()
                case .failed(let error):
                    continuation.resume(throwing: error)
                case .notStarted:
                    break
                }
            }
        }

        subscriber?.cancel()

        // Verify the file exists and is complete
        XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path), "File was not created at destination")
        let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
        let finalSize = try XCTUnwrap(attributes[.size] as? NSNumber, "File size attribute is missing")
        XCTAssertEqual(finalSize.int64Value, Int64(expectedSize), "Downloaded file size should match the expected size")
    }
}

Tests/TokenizersTests/ChatTemplateTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ class ChatTemplateTests: XCTestCase {
2727
/// Applies the chat template of the DeepSeek-R1-Distill-Qwen-7B tokenizer to `messages` and checks
/// both the resulting token IDs and their decoded string against fixed expectations.
func testDeepSeekQwenChatTemplate() async throws {
    let tokenizer = try await AutoTokenizer.from(pretrained: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
    let encoded = try tokenizer.applyChatTemplate(messages: messages)
    // Expected IDs include the two trailing tokens that decode to "<think>\n" in the target below
    let encodedTarget = [151646, 151644, 74785, 279, 23670, 15473, 4128, 13, 151645, 151648, 198]
    XCTAssertEqual(encoded, encodedTarget)

    let decoded = tokenizer.decode(tokens: encoded)
    let decodedTarget = "<|begin▁of▁sentence|><|User|>Describe the Swift programming language.<|Assistant|><think>\n"
    XCTAssertEqual(decoded, decodedTarget)
}
3737

0 commit comments

Comments
 (0)