@@ -12,6 +12,8 @@ import Combine
12
12
class Downloader : NSObject , ObservableObject {
13
13
private( set) var destination : URL
14
14
15
+ private let chunkSize = 10 * 1024 * 1024 // 10MB
16
+
15
17
enum DownloadState {
16
18
case notStarted
17
19
case downloading( Double )
@@ -29,7 +31,17 @@ class Downloader: NSObject, ObservableObject {
29
31
30
32
private var urlSession : URLSession ? = nil
31
33
32
- init ( from url: URL , to destination: URL , using authToken: String ? = nil , inBackground: Bool = false ) {
34
+ init (
35
+ from url: URL ,
36
+ to destination: URL ,
37
+ using authToken: String ? = nil ,
38
+ inBackground: Bool = false ,
39
+ resumeSize: Int = 0 ,
40
+ headers: [ String : String ] ? = nil ,
41
+ expectedSize: Int ? = nil ,
42
+ timeout: TimeInterval = 10 ,
43
+ numRetries: Int = 5
44
+ ) {
33
45
self . destination = destination
34
46
super. init ( )
35
47
let sessionIdentifier = " swift-transformers.hub.downloader "
@@ -43,10 +55,28 @@ class Downloader: NSObject, ObservableObject {
43
55
44
56
self . urlSession = URLSession ( configuration: config, delegate: self , delegateQueue: nil )
45
57
46
- setupDownload ( from: url, with: authToken)
58
+ setupDownload ( from: url, with: authToken, resumeSize : resumeSize , headers : headers , expectedSize : expectedSize , timeout : timeout , numRetries : numRetries )
47
59
}
48
60
49
- private func setupDownload( from url: URL , with authToken: String ? ) {
61
+ /// Sets up and initiates a file download operation
62
+ ///
63
+ /// - Parameters:
64
+ /// - url: Source URL to download from
65
+ /// - authToken: Bearer token for authentication with Hugging Face
66
+ /// - resumeSize: Number of bytes already downloaded for resuming interrupted downloads
67
+ /// - headers: Additional HTTP headers to include in the request
68
+ /// - expectedSize: Expected file size in bytes for validation
69
+ /// - timeout: Time interval before the request times out
70
+ /// - numRetries: Number of retry attempts for failed downloads
71
+ private func setupDownload(
72
+ from url: URL ,
73
+ with authToken: String ? ,
74
+ resumeSize: Int ,
75
+ headers: [ String : String ] ? ,
76
+ expectedSize: Int ? ,
77
+ timeout: TimeInterval ,
78
+ numRetries: Int
79
+ ) {
50
80
downloadState. value = . downloading( 0 )
51
81
urlSession? . getAllTasks { tasks in
52
82
// If there's an existing pending background task with the same URL, let it proceed.
@@ -71,14 +101,137 @@ class Downloader: NSObject, ObservableObject {
71
101
}
72
102
}
73
103
var request = URLRequest ( url: url)
104
+
105
+ // Use headers from argument else create an empty header dictionary
106
+ var requestHeaders = headers ?? [ : ]
107
+
108
+ // Populate header auth and range fields
74
109
if let authToken = authToken {
75
- request. setValue ( " Bearer \( authToken) " , forHTTPHeaderField: " Authorization " )
110
+ requestHeaders [ " Authorization " ] = " Bearer \( authToken) "
111
+ }
112
+ if resumeSize > 0 {
113
+ requestHeaders [ " Range " ] = " bytes= \( resumeSize) - "
76
114
}
115
+
116
+
117
+ request. timeoutInterval = timeout
118
+ request. allHTTPHeaderFields = requestHeaders
77
119
78
- self . urlSession? . downloadTask ( with: request) . resume ( )
120
+ Task {
121
+ do {
122
+ // Create a temp file to write
123
+ let tempURL = FileManager . default. temporaryDirectory. appendingPathComponent ( UUID ( ) . uuidString)
124
+ FileManager . default. createFile ( atPath: tempURL. path, contents: nil )
125
+ let tempFile = try FileHandle ( forWritingTo: tempURL)
126
+
127
+ defer { tempFile. closeFile ( ) }
128
+ try await self . httpGet ( request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
129
+
130
+ // Clean up and move the completed download to its final destination
131
+ tempFile. closeFile ( )
132
+ try FileManager . default. moveDownloadedFile ( from: tempURL, to: self . destination)
133
+
134
+ self . downloadState. value = . completed( self . destination)
135
+ } catch {
136
+ self . downloadState. value = . failed( error)
137
+ }
138
+ }
79
139
}
80
140
}
81
141
142
+ /// Downloads a file from given URL using chunked transfer and handles retries.
143
+ ///
144
+ /// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
145
+ ///
146
+ /// - Parameters:
147
+ /// - request: The URLRequest for the file to download
148
+ /// - resumeSize: The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position
149
+ /// - numRetries: The number of retry attempts remaining for failed downloads
150
+ /// - expectedSize: The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one.
151
+ /// - Throws: `DownloadError.unexpectedError` if the response is invalid or file size mismatch occurs
152
+ /// `URLError` if the download fails after all retries are exhausted
153
+ private func httpGet(
154
+ request: URLRequest ,
155
+ tempFile: FileHandle ,
156
+ resumeSize: Int ,
157
+ numRetries: Int ,
158
+ expectedSize: Int ?
159
+ ) async throws {
160
+ guard let session = self . urlSession else {
161
+ throw DownloadError . unexpectedError
162
+ }
163
+
164
+ // Create a new request with Range header for resuming
165
+ var newRequest = request
166
+ if resumeSize > 0 {
167
+ newRequest. setValue ( " bytes= \( resumeSize) - " , forHTTPHeaderField: " Range " )
168
+ }
169
+
170
+ // Start the download and get the byte stream
171
+ let ( asyncBytes, response) = try await session. bytes ( for: newRequest)
172
+
173
+ guard let response = response as? HTTPURLResponse else {
174
+ throw DownloadError . unexpectedError
175
+ }
176
+
177
+ guard ( 200 ..< 300 ) . contains ( response. statusCode) else {
178
+ throw DownloadError . unexpectedError
179
+ }
180
+
181
+ var downloadedSize = resumeSize
182
+
183
+ // Create a buffer to collect bytes before writing to disk
184
+ var buffer = Data ( capacity: chunkSize)
185
+
186
+ var newNumRetries = numRetries
187
+ do {
188
+ for try await byte in asyncBytes {
189
+ buffer. append ( byte)
190
+ // When buffer is full, write to disk
191
+ if buffer. count == chunkSize {
192
+ if !buffer. isEmpty { // Filter out keep-alive chunks
193
+ try tempFile. write ( contentsOf: buffer)
194
+ buffer. removeAll ( keepingCapacity: true )
195
+ downloadedSize += chunkSize
196
+ newNumRetries = 5
197
+ guard let expectedSize = expectedSize else { continue }
198
+ let progress = expectedSize != 0 ? Double ( downloadedSize) / Double( expectedSize) : 0
199
+ downloadState. value = . downloading( progress)
200
+ }
201
+ }
202
+ }
203
+
204
+ if !buffer. isEmpty {
205
+ try tempFile. write ( contentsOf: buffer)
206
+ downloadedSize += buffer. count
207
+ buffer. removeAll ( keepingCapacity: true )
208
+ newNumRetries = 5
209
+ }
210
+ } catch let error as URLError {
211
+ if newNumRetries <= 0 {
212
+ throw error
213
+ }
214
+ try await Task . sleep ( nanoseconds: 1_000_000_000 )
215
+
216
+ let config = URLSessionConfiguration . default
217
+ self . urlSession = URLSession ( configuration: config, delegate: self , delegateQueue: nil )
218
+
219
+ try await httpGet (
220
+ request: request,
221
+ tempFile: tempFile,
222
+ resumeSize: downloadedSize,
223
+ numRetries: newNumRetries - 1 ,
224
+ expectedSize: expectedSize
225
+ )
226
+ }
227
+
228
+ // Verify the downloaded file size matches the expected size
229
+ let actualSize = try tempFile. seekToEnd ( )
230
+ if let expectedSize = expectedSize, expectedSize != actualSize {
231
+ throw DownloadError . unexpectedError
232
+ }
233
+ }
234
+
82
235
@discardableResult
83
236
func waitUntilDone( ) throws -> URL {
84
237
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
0 commit comments