Skip to content

Commit 82df136

Browse files
committed
Fix and factor out checkArrayShape
1 parent befbc5f commit 82df136

File tree

4 files changed

+35
-41
lines changed

4 files changed

+35
-41
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import MLX
2+
3+
/// Check if array is in a supported format for conv weights
4+
public func checkArrayShape(_ arr: MLXArray) -> Bool {
5+
let shape = arr.shape
6+
switch shape.count {
7+
case 4:
8+
let outChannels = shape[0]
9+
let kH = shape[1]
10+
let kW = shape[2]
11+
// shape[3] is in_channels, which is ignored
12+
// Check if out_channels is the largest, and kH and kW are the same
13+
return (outChannels >= kH) && (outChannels >= kW) && (kH == kW)
14+
case 3:
15+
let kW = shape[1]
16+
let outChannels = shape[2]
17+
// shape[0] is ignored
18+
// Check if kW is larger than or equal to out_channels
19+
return kW >= outChannels
20+
default:
21+
// Any other number of dimensions is not supported
22+
return false
23+
}
24+
}

Libraries/MLXVLM/Models/Gemma3.swift

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -756,19 +756,6 @@ private class VisionModel: Module {
756756
visionModel(x, outputHiddenStates: outputHiddenStates)
757757
}
758758

759-
/// Check if array is already in MLX format for conv2d weights
760-
private func checkArrayShape(_ arr: MLXArray) -> Bool {
761-
let shape = arr.shape
762-
763-
// Check if the shape has 4 dimensions
764-
guard shape.count == 4 else { return false }
765-
766-
let (outChannels, kH, kW, _) = (shape[0], shape[1], shape[2], shape[3])
767-
768-
// Check if out_channels is the largest, and kH and kW are the same
769-
return (outChannels >= kH) && (outChannels >= kW) && (kH == kW)
770-
}
771-
772759
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
773760
var sanitizedWeights = [String: MLXArray]()
774761

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,20 +1652,6 @@ private func maskedScatter(
16521652
return resultFlat.reshaped(inputShape)
16531653
}
16541654

1655-
private func checkArrayShape(_ arr: MLXArray) -> Bool {
1656-
let shape = arr.shape
1657-
guard shape.count == 4 else {
1658-
return false
1659-
}
1660-
1661-
let (outChannels, kH, kW, _) = (shape[0], shape[1], shape[2], shape[3])
1662-
let result = (outChannels >= kH) && (outChannels >= kW) && (kH == kW)
1663-
print(
1664-
"🔍 checkArrayShape: shape=\(shape), outChannels=\(outChannels), kH=\(kH), kW=\(kW), result=\(result)"
1665-
)
1666-
return result
1667-
}
1668-
16691655
// MARK: - Main Model
16701656

16711657
public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
@@ -3925,27 +3911,31 @@ private class Gemma3nAudioModel: Module {
39253911
return (audioencodings, currentMask)
39263912
}
39273913

3914+
/// Sanitizes weights by transposing convolution layers if they are not
3915+
/// already in the expected MLX format.
39283916
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
39293917
var sanitizedWeights = [String: MLXArray]()
39303918

39313919
for (k, v) in weights {
39323920
if k.contains("conv.weight") {
3933-
// The checkArrayShape function is not robust.
3934-
// The Python implementation doesn't use it. It's safer to just transpose.
3935-
// Assuming NCHW -> NHWC for Conv2d
3936-
if v.ndim == 4 {
3921+
// A Conv2D weight should be 4D.
3922+
// If it is, check if it needs transposing from NCHW to NHWC.
3923+
// If checkArrayShape is true, it's already in the correct format.
3924+
if v.ndim == 4 && !checkArrayShape(v) {
39373925
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
39383926
} else {
39393927
sanitizedWeights[k] = v
39403928
}
39413929
} else if k.contains("conv1d.weight") {
3942-
// Assuming NCL -> NLC for Conv1d
3943-
if v.ndim == 3 {
3930+
// A Conv1D weight should be 3D.
3931+
// If it is, check if it needs transposing from NCL to NLC.
3932+
if v.ndim == 3 && !checkArrayShape(v) {
39443933
sanitizedWeights[k] = v.transposed(0, 2, 1)
39453934
} else {
39463935
sanitizedWeights[k] = v
39473936
}
39483937
} else {
3938+
// For all other weights, keep them as they are.
39493939
sanitizedWeights[k] = v
39503940
}
39513941
}
@@ -4149,7 +4139,6 @@ public struct Gemma3nProcessorConfiguration: Codable, Sendable {
41494139
public let doConvertRgb: Bool?
41504140
public let doPanAndScan: Bool?
41514141

4152-
// Token identifiers - use default values that match Python implementation
41534142
public var imageTokenId: Int { 262145 }
41544143
public var audioTokenId: Int { 262273 }
41554144

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -386,12 +386,6 @@ private enum Language {
386386
// MARK: - Vision
387387

388388
private enum Vision {
389-
static func checkArrayShape(_ arr: MLXArray) -> Bool {
390-
if arr.ndim != 4 { return false }
391-
let (o, h, w, _) = (arr.dim(0), arr.dim(1), arr.dim(2), arr.dim(3))
392-
return (o >= h && o >= w && h == w)
393-
}
394-
395389
fileprivate class Attention: Module {
396390
let numHeads: Int
397391
let scale: Float
@@ -602,7 +596,7 @@ private enum Vision {
602596
if k.contains("position_ids") {
603597
continue
604598
} else if k.contains("patch_embedding.weight") {
605-
if Vision.checkArrayShape(v) {
599+
if checkArrayShape(v) {
606600
sanitizedWeights[k] = v
607601
} else {
608602
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)

0 commit comments

Comments
 (0)