Skip to content

Commit 02767d5

Browse files
committed
Do all sanitization steps on load
1 parent 8820938 commit 02767d5

File tree

1 file changed

+4
-45
lines changed

1 file changed

+4
-45
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,53 +1896,12 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18961896
sanitizedWeights[k] = v
18971897
}
18981898
}
1899+
sanitizedWeights = visionTower.sanitize(weights: sanitizedWeights)
1900+
// TODO: The audio and language sanitization is not done in the Python implementation. Is this needed?
1901+
sanitizedWeights = audioTower.sanitize(weights: sanitizedWeights)
1902+
sanitizedWeights = languageModel.sanitize(weights: sanitizedWeights)
18991903
return sanitizedWeights
19001904
}
1901-
1902-
public static func fromPretrained(pathOrHfRepo: String) throws -> Gemma3n {
1903-
let path = URL(fileURLWithPath: pathOrHfRepo)
1904-
1905-
let configPath = path.appendingPathComponent("config.json")
1906-
let configData = try Data(contentsOf: configPath)
1907-
1908-
let decoder = JSONDecoder()
1909-
decoder.keyDecodingStrategy = .convertFromSnakeCase
1910-
let modelConfig = try decoder.decode(ModelConfig.self, from: configData)
1911-
1912-
let model = Gemma3n(modelConfig)
1913-
1914-
// Load all weight files into a single dictionary
1915-
let weightFiles = try FileManager.default.contentsOfDirectory(atPath: path.path)
1916-
.filter { $0.hasSuffix(".safetensors") }
1917-
guard !weightFiles.isEmpty else {
1918-
throw NSError(
1919-
domain: "ModelLoading", code: 1,
1920-
userInfo: [NSLocalizedDescriptionKey: "No safetensors found in \(path.path)"])
1921-
}
1922-
1923-
var weights = [String: MLXArray]()
1924-
for weightFile in weightFiles {
1925-
let fileWeights = try loadArrays(url: path.appendingPathComponent(weightFile))
1926-
weights.merge(fileWeights) { _, new in new }
1927-
}
1928-
1929-
var sanitizedWeights = model.sanitize(weights: weights)
1930-
sanitizedWeights = model.visionTower.sanitize(weights: sanitizedWeights)
1931-
// The audio and language sanitization is not done in the Python implementation
1932-
// sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933-
// sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
1934-
1935-
// Handle tied lm_head weights
1936-
if sanitizedWeights["language_model.lm_head.weight"] == nil {
1937-
if let embedWeight = sanitizedWeights["language_model.model.embed_tokens.weight"] {
1938-
sanitizedWeights["language_model.lm_head.weight"] = embedWeight
1939-
}
1940-
}
1941-
1942-
// Load the weights
1943-
try model.update(parameters: ModuleParameters.unflattened(sanitizedWeights), verify: [.all])
1944-
return model
1945-
}
19461905
}
19471906

19481907
// MARK: - Audio Model Components

0 commit comments

Comments
 (0)