Skip to content

Commit 7482e98

Browse files
authored
load gemma3 dwq VLM models as LLM models (#343)
* load gemma3 dwq VLM models as LLM models
1 parent f85e122 commit 7482e98

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
4343
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
4444
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
4545
"gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
46+
"gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
4647
"granite": create(GraniteConfiguration.self, GraniteModel.init),
4748
"mimo": create(MiMoConfiguration.self, MiMoModel.init),
4849
"glm4": create(GLM4Configuration.self, GLM4Model.init),

Libraries/MLXLLM/Models/Gemma3Text.swift

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,21 @@ public struct Gemma3TextConfiguration: Codable {
4949
case slidingWindowPattern = "sliding_window_pattern"
5050
}
5151

52+
enum VLMCodingKeys: String, CodingKey {
53+
case textConfig = "text_config"
54+
}
55+
5256
public init(from decoder: Decoder) throws {
53-
let container = try decoder.container(keyedBy: CodingKeys.self)
57+
let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self)
58+
59+
// in the case of VLM models convertered using mlx_lm.convert
60+
// the configuration will still match the VLMs and be under text_config
61+
let container =
62+
if nestedContainer.contains(.textConfig) {
63+
try nestedContainer.nestedContainer(keyedBy: CodingKeys.self, forKey: .textConfig)
64+
} else {
65+
try decoder.container(keyedBy: CodingKeys.self)
66+
}
5467

5568
modelType = try container.decode(String.self, forKey: .modelType)
5669
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
@@ -339,6 +352,14 @@ public class Gemma3TextModel: Module, LLMModel {
339352
-> [String: MLXArray]
340353
{
341354
var processedWeights = weights
355+
356+
// VLM models converted using mlx_vlm.convert will still have
357+
// the weights are under a language_model key
358+
let unflattened = ModuleParameters.unflattened(weights)
359+
if let lm = unflattened["language_model"] {
360+
processedWeights = Dictionary(uniqueKeysWithValues: lm.flattened())
361+
}
362+
342363
if processedWeights["lm_head.weight"] == nil {
343364
if let embedWeight = processedWeights["model.embed_tokens.weight"] {
344365
processedWeights["lm_head.weight"] = embedWeight

0 commit comments

Comments
 (0)