Skip to content

Commit d31123f

Browse files
Add repo revision option (#329)
* Add repo revision option * Update ModelConfiguration inits * Add tokenizer revision option
1 parent 8e41311 commit d31123f

File tree

4 files changed

+27
-22
lines changed

4 files changed

+27
-22
lines changed

Libraries/MLXLMCommon/Load.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@ public func downloadModel(
2424
) async throws -> URL {
2525
do {
2626
switch configuration.id {
27-
case .id(let id):
27+
case .id(let id, let revision):
2828
// download the model weights
2929
let repo = Hub.Repo(id: id)
3030
let modelFiles = ["*.safetensors", "*.json"]
3131
return try await hub.snapshot(
32-
from: repo, matching: modelFiles, progressHandler: progressHandler)
33-
32+
from: repo,
33+
revision: revision,
34+
matching: modelFiles,
35+
progressHandler: progressHandler
36+
)
3437
case .directory(let directory):
3538
return directory
3639
}

Libraries/MLXLMCommon/ModelConfiguration.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ import Hub
99
public struct ModelConfiguration: Sendable {
1010

1111
public enum Identifier: Sendable {
12-
case id(String)
12+
case id(String, revision: String = "main")
1313
case directory(URL)
1414
}
1515

1616
public var id: Identifier
1717

1818
public var name: String {
1919
switch id {
20-
case .id(let string):
21-
string
20+
case .id(let id, _):
21+
id
2222
case .directory(let url):
2323
url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
2424
}
@@ -37,20 +37,22 @@ public struct ModelConfiguration: Sendable {
3737
public var extraEOSTokens: Set<String>
3838

3939
public init(
40-
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
40+
id: String, revision: String = "main",
41+
tokenizerId: String? = nil, overrideTokenizer: String? = nil,
4142
defaultPrompt: String = "hello",
4243
extraEOSTokens: Set<String> = [],
4344
preparePrompt: (@Sendable (String) -> String)? = nil
4445
) {
45-
self.id = .id(id)
46+
self.id = .id(id, revision: revision)
4647
self.tokenizerId = tokenizerId
4748
self.overrideTokenizer = overrideTokenizer
4849
self.defaultPrompt = defaultPrompt
4950
self.extraEOSTokens = extraEOSTokens
5051
}
5152

5253
public init(
53-
directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
54+
directory: URL,
55+
tokenizerId: String? = nil, overrideTokenizer: String? = nil,
5456
defaultPrompt: String = "hello",
5557
extraEOSTokens: Set<String> = []
5658
) {
@@ -63,7 +65,7 @@ public struct ModelConfiguration: Sendable {
6365

6466
public func modelDirectory(hub: HubApi = HubApi()) -> URL {
6567
switch id {
66-
case .id(let id):
68+
case .id(let id, _):
6769
// download the model weights and config
6870
let repo = Hub.Repo(id: id)
6971
return hub.localRepoLocation(repo)
@@ -84,8 +86,8 @@ extension ModelConfiguration.Identifier: Equatable {
8486
-> Bool
8587
{
8688
switch (lhs, rhs) {
87-
case (.id(let lhsID), .id(let rhsID)):
88-
lhsID == rhsID
89+
case (.id(let lhsID, let lhsRevision), .id(let rhsID, let rhsRevision)):
90+
lhsID == rhsID && lhsRevision == rhsRevision
8991
case (.directory(let lhsURL), .directory(let rhsURL)):
9092
lhsURL == rhsURL
9193
default:

Libraries/MLXLMCommon/Tokenizer.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi)
2424
let config: LanguageModelConfigurationFromHub
2525

2626
switch configuration.id {
27-
case .id(let id):
27+
case .id(let id, let revision):
2828
do {
2929
// the load can fail (async when we try to use it)
3030
let loaded = LanguageModelConfigurationFromHub(
31-
modelName: configuration.tokenizerId ?? id, hubApi: hub)
31+
modelName: configuration.tokenizerId ?? id, revision: revision, hubApi: hub)
3232
_ = try await loaded.tokenizerConfig
3333
config = loaded
3434
} catch {

Package.resolved

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)