@@ -49,8 +49,21 @@ public struct Gemma3TextConfiguration: Codable {
49
49
case slidingWindowPattern = " sliding_window_pattern "
50
50
}
51
51
52
+ enum VLMCodingKeys : String , CodingKey {
53
+ case textConfig = " text_config "
54
+ }
55
+
52
56
public init ( from decoder: Decoder ) throws {
53
- let container = try decoder. container ( keyedBy: CodingKeys . self)
57
+ let nestedContainer = try decoder. container ( keyedBy: VLMCodingKeys . self)
58
+
59
+ // in the case of VLM models convertered using mlx_lm.convert
60
+ // the configuration will still match the VLMs and be under text_config
61
+ let container =
62
+ if nestedContainer. contains ( . textConfig) {
63
+ try nestedContainer. nestedContainer ( keyedBy: CodingKeys . self, forKey: . textConfig)
64
+ } else {
65
+ try decoder. container ( keyedBy: CodingKeys . self)
66
+ }
54
67
55
68
modelType = try container. decode ( String . self, forKey: . modelType)
56
69
hiddenSize = try container. decode ( Int . self, forKey: . hiddenSize)
@@ -339,6 +352,14 @@ public class Gemma3TextModel: Module, LLMModel {
339
352
-> [ String : MLXArray ]
340
353
{
341
354
var processedWeights = weights
355
+
356
+ // VLM models converted using mlx_vlm.convert will still have
357
+ // the weights are under a language_model key
358
+ let unflattened = ModuleParameters . unflattened ( weights)
359
+ if let lm = unflattened [ " language_model " ] {
360
+ processedWeights = Dictionary ( uniqueKeysWithValues: lm. flattened ( ) )
361
+ }
362
+
342
363
if processedWeights [ " lm_head.weight " ] == nil {
343
364
if let embedWeight = processedWeights [ " model.embed_tokens.weight " ] {
344
365
processedWeights [ " lm_head.weight " ] = embedWeight
0 commit comments