Skip to content

Add Gemma 3n #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions Libraries/MLXVLM/CheckArrayShape.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import MLX

/// Check if array is in a supported format for conv weights
public func checkArrayShape(_ arr: MLXArray) -> Bool {
let shape = arr.shape
switch shape.count {
case 4:
let outChannels = shape[0]
let kH = shape[1]
let kW = shape[2]
// shape[3] is in_channels, which is ignored
// Check if out_channels is the largest, and kH and kW are the same
return (outChannels >= kH) && (outChannels >= kW) && (kH == kW)
case 3:
let kW = shape[1]
let outChannels = shape[2]
// shape[0] is ignored
// Check if kW is larger than or equal to out_channels
return kW >= outChannels
default:
// Any other number of dimensions is not supported
return false
}
}
13 changes: 0 additions & 13 deletions Libraries/MLXVLM/Models/Gemma3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -756,19 +756,6 @@ private class VisionModel: Module {
visionModel(x, outputHiddenStates: outputHiddenStates)
}

/// Check if array is already in MLX format for conv2d weights
private func checkArrayShape(_ arr: MLXArray) -> Bool {
let shape = arr.shape

// Check if the shape has 4 dimensions
guard shape.count == 4 else { return false }

let (outChannels, kH, kW, _) = (shape[0], shape[1], shape[2], shape[3])

// Check if out_channels is the largest, and kH and kW are the same
return (outChannels >= kH) && (outChannels >= kW) && (kH == kW)
}

func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var sanitizedWeights = [String: MLXArray]()

Expand Down
Loading