Skip to content

Commit c2e1e0f

Browse files
Add option for background session, but use default normally (#64)
1 parent 4c98a16 commit c2e1e0f

File tree

3 files changed

+71
-34
lines changed

3 files changed

+71
-34
lines changed

Sources/Hub/Downloader.swift

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,42 @@ import Combine
1111

1212
class Downloader: NSObject, ObservableObject {
1313
private(set) var destination: URL
14-
14+
1515
enum DownloadState {
1616
case notStarted
1717
case downloading(Double)
1818
case completed(URL)
1919
case failed(Error)
2020
}
21-
21+
2222
enum DownloadError: Error {
2323
case invalidDownloadLocation
2424
case unexpectedError
2525
}
26-
26+
2727
private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)
2828
private var stateSubscriber: Cancellable?
29-
29+
3030
private var urlSession: URLSession? = nil
31-
32-
init(from url: URL, to destination: URL, using authToken: String? = nil) {
31+
32+
init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
3333
self.destination = destination
3434
super.init()
35-
36-
let config = URLSessionConfiguration.background(withIdentifier: url.path)
37-
#if targetEnvironment(simulator)
38-
urlSession = URLSession(configuration: .default, delegate: self, delegateQueue: OperationQueue())
39-
#else
40-
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
41-
#endif
35+
let sessionIdentifier = "swift-transformers.hub.downloader"
36+
37+
var config = URLSessionConfiguration.default
38+
if inBackground {
39+
config = URLSessionConfiguration.background(withIdentifier: sessionIdentifier)
40+
config.isDiscretionary = false
41+
config.sessionSendsLaunchEvents = true
42+
}
43+
44+
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
45+
46+
setupDownload(from: url, with: authToken)
47+
}
48+
49+
private func setupDownload(from url: URL, with authToken: String?) {
4250
downloadState.value = .downloading(0)
4351
urlSession?.getAllTasks { tasks in
4452
// If there's an existing pending background task with the same URL, let it proceed.
@@ -70,7 +78,7 @@ class Downloader: NSObject, ObservableObject {
7078
self.urlSession?.downloadTask(with: request).resume()
7179
}
7280
}
73-
81+
7482
@discardableResult
7583
func waitUntilDone() throws -> URL {
7684
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
@@ -83,31 +91,28 @@ class Downloader: NSObject, ObservableObject {
8391
}
8492
}
8593
semaphore.wait()
86-
94+
8795
switch downloadState.value {
8896
case .completed(let url): return url
8997
case .failed(let error): throw error
9098
default: throw DownloadError.unexpectedError
9199
}
92100
}
93-
101+
94102
func cancel() {
95103
urlSession?.invalidateAndCancel()
96104
}
97105
}
98106

99-
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
107+
extension Downloader: URLSessionDownloadDelegate {
100108
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) {
101109
downloadState.value = .downloading(downloadTask.progress.fractionCompleted)
102110
}
103111

104112
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
105-
guard FileManager.default.fileExists(atPath: location.path) else {
106-
downloadState.value = .failed(DownloadError.invalidDownloadLocation)
107-
return
108-
}
109113
do {
110-
try FileManager.default.moveItem(at: location, to: destination)
114+
// If the downloaded file already exists on the filesystem, overwrite it
115+
try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
111116
downloadState.value = .completed(destination)
112117
} catch {
113118
downloadState.value = .failed(error)
@@ -124,3 +129,12 @@ extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
124129
}
125130
}
126131
}
132+
133+
extension FileManager {
134+
func moveDownloadedFile(from srcURL: URL, to dstURL: URL) throws {
135+
if fileExists(atPath: dstURL.path) {
136+
try removeItem(at: dstURL)
137+
}
138+
try moveItem(at: srcURL, to: dstURL)
139+
}
140+
}

Sources/Hub/HubApi.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,15 @@
88
import Foundation
99

1010
public struct HubApi {
11-
public let downloadBase: URL
12-
public let hfToken: String?
13-
public let endpoint: String
14-
11+
var downloadBase: URL
12+
var hfToken: String?
13+
var endpoint: String
14+
var useBackgroundSession: Bool
15+
1516
public typealias RepoType = Hub.RepoType
1617
public typealias Repo = Hub.Repo
1718

18-
public init(
19-
downloadBase: URL? = nil,
20-
hfToken: String? = nil,
21-
endpoint: String = "https://huggingface.co"
22-
) {
19+
public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) {
2320
self.hfToken = hfToken
2421
if let downloadBase {
2522
self.downloadBase = downloadBase
@@ -28,6 +25,7 @@ public struct HubApi {
2825
self.downloadBase = documents.appending(component: "huggingface")
2926
}
3027
self.endpoint = endpoint
28+
self.useBackgroundSession = useBackgroundSession
3129
}
3230

3331
public static let shared = HubApi()
@@ -129,7 +127,8 @@ public extension HubApi {
129127
let relativeFilename: String
130128
let hfToken: String?
131129
let endpoint: String?
132-
130+
let backgroundSession: Bool
131+
133132
var source: URL {
134133
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
135134
var url = URL(string: endpoint ?? "https://huggingface.co")!
@@ -163,7 +162,7 @@ public extension HubApi {
163162
guard !downloaded else { return destination }
164163

165164
try prepareDestination()
166-
let downloader = Downloader(from: source, to: destination, using: hfToken)
165+
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
167166
let downloadSubscriber = downloader.downloadState.sink { state in
168167
if case .downloading(let progress) = state {
169168
progressHandler(progress)
@@ -188,7 +187,8 @@ public extension HubApi {
188187
repoDestination: repoDestination,
189188
relativeFilename: filename,
190189
hfToken: hfToken,
191-
endpoint: endpoint
190+
endpoint: endpoint,
191+
backgroundSession: useBackgroundSession
192192
)
193193
try await downloader.download { fractionDownloaded in
194194
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)

Tests/HubTests/HubApiTests.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,29 @@ class SnapshotDownloadTests: XCTestCase {
149149
)
150150
}
151151

152+
/// Background sessions get rate limited by the OS, see discussion here: https://github.com/huggingface/swift-transformers/issues/61
153+
/// Test only one file at a time
154+
func testDownloadInBackground() async throws {
155+
let hubApi = HubApi(downloadBase: downloadDestination, useBackgroundSession: true)
156+
var lastProgress: Progress? = nil
157+
let downloadedTo = try await hubApi.snapshot(from: repo, matching: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json") { progress in
158+
print("Total Progress: \(progress.fractionCompleted)")
159+
print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
160+
lastProgress = progress
161+
}
162+
XCTAssertEqual(lastProgress?.fractionCompleted, 1)
163+
XCTAssertEqual(lastProgress?.completedUnitCount, 1)
164+
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
165+
166+
let downloadedFilenames = getRelativeFiles(url: downloadDestination)
167+
XCTAssertEqual(
168+
Set(downloadedFilenames),
169+
Set([
170+
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
171+
])
172+
)
173+
}
174+
152175
func testCustomEndpointDownload() async throws {
153176
let hubApi = HubApi(downloadBase: downloadDestination, endpoint: "https://hf-mirror.com")
154177
var lastProgress: Progress? = nil

0 commit comments

Comments
 (0)