Skip to content

Commit 24605a8

Browse files
authored
Propagate HubApi configuration (#62)
* propagated HubApi * review changes
1 parent 95194e6 commit 24605a8

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

Sources/Hub/Hub.swift

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,12 @@ public class LanguageModelConfigurationFromHub {
112112

113113
private var configPromise: Task<Configurations, Error>? = nil
114114

115-
public init(modelName: String) {
115+
public init(
116+
modelName: String,
117+
hubApi: HubApi = .shared
118+
) {
116119
self.configPromise = Task.init {
117-
return try await self.loadConfig(modelName: modelName)
120+
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
118121
}
119122
}
120123

@@ -161,8 +164,10 @@ public class LanguageModelConfigurationFromHub {
161164
}
162165
}
163166

164-
func loadConfig(modelName: String, hfToken: String? = nil) async throws -> Configurations {
165-
let hubApi = HubApi(hfToken: hfToken)
167+
func loadConfig(
168+
modelName: String,
169+
hubApi: HubApi = .shared
170+
) async throws -> Configurations {
166171
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
167172
let repo = Hub.Repo(id: modelName)
168173
try await hubApi.snapshot(from: repo, matching: filesToDownload)
@@ -172,7 +177,11 @@ public class LanguageModelConfigurationFromHub {
172177
let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo)
173178
let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo)
174179

175-
let configs = Configurations(modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerVocab)
180+
let configs = Configurations(
181+
modelConfig: modelConfig,
182+
tokenizerConfig: tokenizerConfig,
183+
tokenizerData: tokenizerVocab
184+
)
176185
return configs
177186
}
178187

Sources/Hub/HubApi.swift

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,29 @@
88
import Foundation
99

1010
public struct HubApi {
11-
var downloadBase: URL
12-
var hfToken: String?
13-
var endpoint: String
11+
public let downloadBase: URL
12+
public let hfToken: String?
13+
public let endpoint: String
1414

1515
public typealias RepoType = Hub.RepoType
1616
public typealias Repo = Hub.Repo
1717

18-
public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co") {
19-
if downloadBase == nil {
18+
public init(
19+
downloadBase: URL? = nil,
20+
hfToken: String? = nil,
21+
endpoint: String = "https://huggingface.co"
22+
) {
23+
self.hfToken = hfToken
24+
if let downloadBase {
25+
self.downloadBase = downloadBase
26+
} else {
2027
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
2128
self.downloadBase = documents.appending(component: "huggingface")
22-
} else {
23-
self.downloadBase = downloadBase!
2429
}
25-
self.hfToken = hfToken
2630
self.endpoint = endpoint
2731
}
2832

29-
static let shared = HubApi()
33+
public static let shared = HubApi()
3034
}
3135

3236
/// File retrieval
@@ -179,7 +183,13 @@ public extension HubApi {
179183
let repoDestination = localRepoLocation(repo)
180184
for filename in filenames {
181185
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
182-
let downloader = HubFileDownloader(repo: repo, repoDestination: repoDestination, relativeFilename: filename, hfToken: hfToken, endpoint: endpoint)
186+
let downloader = HubFileDownloader(
187+
repo: repo,
188+
repoDestination: repoDestination,
189+
relativeFilename: filename,
190+
hfToken: hfToken,
191+
endpoint: endpoint
192+
)
183193
try await downloader.download { fractionDownloaded in
184194
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
185195
progressHandler(progress)

Sources/Tokenizers/Tokenizer.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,11 @@ extension AutoTokenizer {
234234
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
235235
}
236236

237-
public static func from(pretrained model: String) async throws -> Tokenizer {
238-
let config = LanguageModelConfigurationFromHub(modelName: model)
237+
public static func from(
238+
pretrained model: String,
239+
hubApi: HubApi = .shared
240+
) async throws -> Tokenizer {
241+
let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
239242
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
240243
let tokenizerData = try await config.tokenizerData
241244

0 commit comments

Comments
 (0)