Skip to content

Commit 18b62e5

Browse files
authored
Support loading tokenizer from local folder #76 (#81)
* Support loading tokenizer from local folder #76 * PR feedback & minor optimisations
1 parent 508c540 commit 18b62e5

File tree

4 files changed

+59
-6
lines changed

4 files changed

+59
-6
lines changed

Sources/Hub/Hub.swift

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ public class LanguageModelConfigurationFromHub {
120120
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
121121
}
122122
}
123+
124+
public init(
125+
modelFolder: URL,
126+
hubApi: HubApi = .shared
127+
) {
128+
self.configPromise = Task {
129+
return try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
130+
}
131+
}
123132

124133
public var modelConfig: Config {
125134
get async throws {
@@ -170,12 +179,19 @@ public class LanguageModelConfigurationFromHub {
170179
) async throws -> Configurations {
171180
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
172181
let repo = Hub.Repo(id: modelName)
173-
try await hubApi.snapshot(from: repo, matching: filesToDownload)
182+
let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
174183

184+
return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
185+
}
186+
187+
func loadConfig(
188+
modelFolder: URL,
189+
hubApi: HubApi = .shared
190+
) async throws -> Configurations {
175191
// Note tokenizerConfig may be nil (does not exist in all models)
176-
let modelConfig = try hubApi.configuration(from: "config.json", in: repo)
177-
let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo)
178-
let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo)
192+
let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json"))
193+
let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
194+
let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
179195

180196
let configs = Configurations(
181197
modelConfig: modelConfig,

Sources/Hub/HubApi.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,14 @@ public extension HubApi {
9393
/// Assumes the file has already been downloaded.
9494
/// `filename` is relative to the download base.
9595
func configuration(from filename: String, in repo: Repo) throws -> Config {
96-
let url = localRepoLocation(repo).appending(path: filename)
97-
let data = try Data(contentsOf: url)
96+
let fileURL = localRepoLocation(repo).appending(path: filename)
97+
return try configuration(fileURL: fileURL)
98+
}
99+
100+
/// Assumes the file is already present at local url.
101+
/// `fileURL` is a complete local file path for the given model
102+
func configuration(fileURL: URL) throws -> Config {
103+
let data = try Data(contentsOf: fileURL)
98104
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
99105
guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
100106
return Config(dictionary)

Sources/Tokenizers/Tokenizer.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,17 @@ extension AutoTokenizer {
256256

257257
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
258258
}
259+
260+
public static func from(
261+
modelFolder: URL,
262+
hubApi: HubApi = .shared
263+
) async throws -> Tokenizer {
264+
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
265+
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
266+
let tokenizerData = try await config.tokenizerData
267+
268+
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
269+
}
259270
}
260271

261272
// MARK: - Tokenizer model classes

Tests/TokenizersTests/FactoryTests.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,24 @@ class FactoryTests: TestWithCustomHubDownloadLocation {
4242
let inputIds = tokenizer("Today she took a train to the West")
4343
XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257])
4444
}
45+
46+
func testFromModelFolder() async throws {
47+
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
48+
let repo = Hub.Repo(id: "coreml-projects/Llama-2-7b-chat-coreml")
49+
let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
50+
51+
let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi)
52+
let inputIds = tokenizer("Today she took a train to the West")
53+
XCTAssertEqual(inputIds, [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122])
54+
}
55+
56+
func testWhisperFromModelFolder() async throws {
57+
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
58+
let repo = Hub.Repo(id: "openai/whisper-large-v2")
59+
let localModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
60+
61+
let tokenizer = try await AutoTokenizer.from(modelFolder: localModelFolder, hubApi: hubApi)
62+
let inputIds = tokenizer("Today she took a train to the West")
63+
XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257])
64+
}
4565
}

0 commit comments

Comments
 (0)