@@ -1896,53 +1896,12 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1896
1896
sanitizedWeights [ k] = v
1897
1897
}
1898
1898
}
1899
+ sanitizedWeights = visionTower. sanitize ( weights: sanitizedWeights)
1900
+ // TODO: The audio and language sanitization is not done in the Python implementation. Is this needed?
1901
+ sanitizedWeights = audioTower. sanitize ( weights: sanitizedWeights)
1902
+ sanitizedWeights = languageModel. sanitize ( weights: sanitizedWeights)
1899
1903
return sanitizedWeights
1900
1904
}
1901
-
1902
- public static func fromPretrained( pathOrHfRepo: String ) throws -> Gemma3n {
1903
- let path = URL ( fileURLWithPath: pathOrHfRepo)
1904
-
1905
- let configPath = path. appendingPathComponent ( " config.json " )
1906
- let configData = try Data ( contentsOf: configPath)
1907
-
1908
- let decoder = JSONDecoder ( )
1909
- decoder. keyDecodingStrategy = . convertFromSnakeCase
1910
- let modelConfig = try decoder. decode ( ModelConfig . self, from: configData)
1911
-
1912
- let model = Gemma3n ( modelConfig)
1913
-
1914
- // Load all weight files into a single dictionary
1915
- let weightFiles = try FileManager . default. contentsOfDirectory ( atPath: path. path)
1916
- . filter { $0. hasSuffix ( " .safetensors " ) }
1917
- guard !weightFiles. isEmpty else {
1918
- throw NSError (
1919
- domain: " ModelLoading " , code: 1 ,
1920
- userInfo: [ NSLocalizedDescriptionKey: " No safetensors found in \( path. path) " ] )
1921
- }
1922
-
1923
- var weights = [ String: MLXArray] ( )
1924
- for weightFile in weightFiles {
1925
- let fileWeights = try loadArrays ( url: path. appendingPathComponent ( weightFile) )
1926
- weights. merge ( fileWeights) { _, new in new }
1927
- }
1928
-
1929
- var sanitizedWeights = model. sanitize ( weights: weights)
1930
- sanitizedWeights = model. visionTower. sanitize ( weights: sanitizedWeights)
1931
- // The audio and language sanitization is not done in the Python implementation
1932
- // sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933
- // sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
1934
-
1935
- // Handle tied lm_head weights
1936
- if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
1937
- if let embedWeight = sanitizedWeights [ " language_model.model.embed_tokens.weight " ] {
1938
- sanitizedWeights [ " language_model.lm_head.weight " ] = embedWeight
1939
- }
1940
- }
1941
-
1942
- // Load the weights
1943
- try model. update ( parameters: ModuleParameters . unflattened ( sanitizedWeights) , verify: [ . all] )
1944
- return model
1945
- }
1946
1905
}
1947
1906
1948
1907
// MARK: - Audio Model Components
0 commit comments