Skip to content

Commit 72fc1f7

Browse files
authored
feat: Support custom endpoint (#57)
1 parent ae3ce32 commit 72fc1f7

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

Sources/Hub/HubApi.swift

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,22 @@
88
import Foundation
99

1010
public struct HubApi {
11-
let endpoint = "https://huggingface.co/api"
1211
var downloadBase: URL
13-
var hfToken: String? = nil
12+
var hfToken: String?
13+
var endpoint: String
1414

1515
public typealias RepoType = Hub.RepoType
1616
public typealias Repo = Hub.Repo
1717

18-
public init(downloadBase: URL? = nil, hfToken: String? = nil) {
18+
public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co") {
1919
if downloadBase == nil {
2020
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
2121
self.downloadBase = documents.appending(component: "huggingface")
2222
} else {
2323
self.downloadBase = downloadBase!
2424
}
2525
self.hfToken = hfToken
26+
self.endpoint = endpoint
2627
}
2728

2829
static let shared = HubApi()
@@ -51,15 +52,15 @@ public extension HubApi {
5152
switch response.statusCode {
5253
case 200..<300: break
5354
case 400..<500: throw Hub.HubClientError.authorizationRequired
54-
default : throw Hub.HubClientError.httpStatusCode(response.statusCode)
55+
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
5556
}
5657

5758
return (data, response)
5859
}
5960

6061
func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] {
6162
// Read repo info and only parse "siblings"
62-
let url = URL(string: "\(endpoint)/\(repo.type)/\(repo.id)")!
63+
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")!
6364
let (data, _) = try await httpGet(for: url)
6465
let response = try JSONDecoder().decode(SiblingsResponse.self, from: data)
6566
let filenames = response.siblings.map { $0.rfilename }
@@ -103,7 +104,7 @@ public extension HubApi {
103104
func whoami() async throws -> Config {
104105
guard hfToken != nil else { throw Hub.HubClientError.authorizationRequired }
105106

106-
let url = URL(string: "\(endpoint)/whoami-v2")!
107+
let url = URL(string: "\(endpoint)/api/whoami-v2")!
107108
let (data, _) = try await httpGet(for: url)
108109

109110
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
@@ -123,15 +124,16 @@ public extension HubApi {
123124
let repoDestination: URL
124125
let relativeFilename: String
125126
let hfToken: String?
127+
let endpoint: String?
126128

127129
var source: URL {
128130
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
129-
var url = URL(string: "https://huggingface.co")!
131+
var url = URL(string: endpoint ?? "https://huggingface.co")!
130132
if repo.type != .models {
131133
url = url.appending(component: repo.type.rawValue)
132134
}
133135
url = url.appending(path: repo.id)
134-
url = url.appending(path: "resolve/main") // TODO: revisions
136+
url = url.appending(path: "resolve/main") // TODO: revisions
135137
url = url.appending(path: relativeFilename)
136138
return url
137139
}
@@ -177,7 +179,7 @@ public extension HubApi {
177179
let repoDestination = localRepoLocation(repo)
178180
for filename in filenames {
179181
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
180-
let downloader = HubFileDownloader(repo: repo, repoDestination: repoDestination, relativeFilename: filename, hfToken: hfToken)
182+
let downloader = HubFileDownloader(repo: repo, repoDestination: repoDestination, relativeFilename: filename, hfToken: hfToken, endpoint: endpoint)
181183
try await downloader.download { fractionDownloaded in
182184
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
183185
progressHandler(progress)
@@ -243,9 +245,8 @@ public extension Hub {
243245
}
244246
}
245247

246-
public extension Array<String> {
248+
public extension [String] {
247249
func matching(glob: String) -> [String] {
248-
self.filter { fnmatch(glob, $0, 0) == 0 }
250+
filter { fnmatch(glob, $0, 0) == 0 }
249251
}
250252
}
251-

Tests/HubTests/HubApiTests.swift

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
// Created by Pedro Cuenca on 20231230.
55
//
66

7-
import XCTest
87
@testable import Hub
9-
8+
import XCTest
109

1110
class HubApiTests: XCTestCase {
1211
override func setUp() {
@@ -18,7 +17,7 @@ class HubApiTests: XCTestCase {
1817
}
1918

2019
// MARK: use a specific revision for these tests
21-
20+
2221
func testFilenameRetrieval() async {
2322
do {
2423
let filenames = try await Hub.getFilenames(from: "coreml-projects/Llama-2-7b-chat-coreml")
@@ -27,7 +26,7 @@ class HubApiTests: XCTestCase {
2726
XCTFail("\(error)")
2827
}
2928
}
30-
29+
3130
func testFilenameRetrievalWithGlob() async {
3231
do {
3332
try await {
@@ -75,7 +74,7 @@ class HubApiTests: XCTestCase {
7574
XCTFail("\(error)")
7675
}
7776
}
78-
77+
7978
func testFilenameRetrievalWithMultiplePatterns() async {
8079
do {
8180
let patterns = ["config.json", "tokenizer.json", "tokenizer_*.json"]
@@ -96,9 +95,8 @@ class SnapshotDownloadTests: XCTestCase {
9695
let base = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!
9796
return base.appending(component: "huggingface-tests")
9897
}()
99-
100-
override func setUp() {
101-
}
98+
99+
override func setUp() {}
102100

103101
override func tearDown() {
104102
do {
@@ -126,7 +124,7 @@ class SnapshotDownloadTests: XCTestCase {
126124
}
127125
return filenames
128126
}
129-
127+
130128
func testDownload() async throws {
131129
let hubApi = HubApi(downloadBase: downloadDestination)
132130
var lastProgress: Progress? = nil
@@ -138,7 +136,31 @@ class SnapshotDownloadTests: XCTestCase {
138136
XCTAssertEqual(lastProgress?.fractionCompleted, 1)
139137
XCTAssertEqual(lastProgress?.completedUnitCount, 6)
140138
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
141-
139+
140+
let downloadedFilenames = getRelativeFiles(url: downloadDestination)
141+
XCTAssertEqual(
142+
Set(downloadedFilenames),
143+
Set([
144+
"config.json", "tokenizer.json", "tokenizer_config.json",
145+
"llama-2-7b-chat.mlpackage/Manifest.json",
146+
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
147+
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
148+
])
149+
)
150+
}
151+
152+
func testCustomEndpointDownload() async throws {
153+
let hubApi = HubApi(downloadBase: downloadDestination, endpoint: "https://hf-mirror.com")
154+
var lastProgress: Progress? = nil
155+
let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in
156+
print("Total Progress: \(progress.fractionCompleted)")
157+
print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
158+
lastProgress = progress
159+
}
160+
XCTAssertEqual(lastProgress?.fractionCompleted, 1)
161+
XCTAssertEqual(lastProgress?.completedUnitCount, 6)
162+
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
163+
142164
let downloadedFilenames = getRelativeFiles(url: downloadDestination)
143165
XCTAssertEqual(
144166
Set(downloadedFilenames),

0 commit comments

Comments
 (0)