@@ -68,7 +68,7 @@ public struct AudioConfig: Codable, Sendable, MultimodalConfig {
        sscpConvEps: Float = 1e-3,
        rmsNormEps: Float = 1e-6,
        gradientClipping: Float = 10000000000.0,
-        vocabOffset: Int = 262272
+        vocabOffset: Int = 262272  // 262_144 + 128 (text vocab size + vision vocab size)
    ) {
        self.inputFeatSize = inputFeatSize
        self.hiddenSize = hiddenSize
@@ -1580,6 +1580,18 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {

        var inputsEmbeds = languageModel.model.embedTokens(inputIds)

+        // Ensure no gaps between text, vision, and audio embeddings, in that order
+        // This matches the Python assertion
+        assert(
+            embedAudio.vocabOffset == config.vocabSize - config.audioConfig.vocabSize,
+            "Audio vocab offset mismatch"
+        )
+        assert(
+            embedVision.vocabOffset == config.vocabSize - config.audioConfig.vocabSize
+                - config.visionConfig.vocabSize,
+            "Vision vocab offset mismatch"
+        )
+
        // Handle vision tokens
        if pixelValues != nil {
            let visionMask = logicalAnd(
@@ -1701,12 +1713,12 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
        if let inputIds = inputIds {
            specialModalityMask = expandedDimensions(inputIds .== tokenId, axis: -1)
        } else {
-            let tokenEmbedding =
-                if modality == "audio" {
-                    embedAudio(MLXArray([tokenId]))
-                } else {
-                    languageModel.model.embedTokens(MLXArray([tokenId]))
-                }
+            // When inputIds is nil, create mask by comparing embeddings
+            let embedFn: (MLXArray) -> MLXArray =
+                modality == "audio"
+                ? { self.embedAudio($0, inputsEmbeds: nil) }
+                : { self.languageModel.model.embedTokens($0) }
+            let tokenEmbedding = embedFn(MLXArray([tokenId]))
            specialModalityMask = inputsEmbeds .== tokenEmbedding
        }

@@ -1718,8 +1730,8 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
        guard modalityTokensInText == featureTokens else {
            fatalError(
                """
-                Number of \(modality)s does not match number of special \(modality) tokens in the input text.
-                Got \(modalityTokensInText) \(modality) tokens in the text and \(featureTokens) tokens from \(modality) embeddings.
+                Number of \(modality)s does not match number of special \(modality) tokens in the input text.
+                Got \(modalityTokensInText) \(modality) tokens in the text and \(featureTokens) tokens from \(modality) embeddings.
                """)
        }

@@ -3828,3 +3840,210 @@ extension Gemma3n: LoRAModel {
        }
    }
}
+
+// MARK: - VLM Factory Configuration and Processor
+
+public struct Gemma3nConfiguration: Codable, Sendable {
+    public let textConfig: TextConfig
+    public let visionConfig: VisionConfig
+    public let audioConfig: AudioConfig
+    public let modelType: String
+    public let vocabSize: Int
+    public let ignoreIndex: Int
+    public let imageTokenIndex: Int
+    public let audioTokenId: Int
+    public let imageTokenId: Int
+    public let hiddenSize: Int
+    public let padTokenId: Int
+    public let visionSoftTokensPerImage: Int
+    public let audioSoftTokensPerImage: Int
+    public let eosTokenId: [Int]?
+    public let quantization: QuantizationConfig?
+
+    public var vocabularySize: Int { vocabSize }
+
+    enum CodingKeys: String, CodingKey {
+        case textConfig = "text_config"
+        case visionConfig = "vision_config"
+        case audioConfig = "audio_config"
+        case modelType = "model_type"
+        case vocabSize = "vocab_size"
+        case ignoreIndex = "ignore_index"
+        case imageTokenIndex = "image_token_index"
+        case audioTokenId = "audio_token_id"
+        case imageTokenId = "image_token_id"
+        case hiddenSize = "hidden_size"
+        case padTokenId = "pad_token_id"
+        case visionSoftTokensPerImage = "vision_soft_tokens_per_image"
+        case audioSoftTokensPerImage = "audio_soft_tokens_per_image"
+        case eosTokenId = "eos_token_id"
+        case quantization
+    }
+
+    public init(from modelConfig: ModelConfig, quantization: QuantizationConfig? = nil) {
+        self.textConfig = modelConfig.textConfig
+        self.visionConfig = modelConfig.visionConfig
+        self.audioConfig = modelConfig.audioConfig
+        self.modelType = modelConfig.modelType
+        self.vocabSize = modelConfig.vocabSize
+        self.ignoreIndex = modelConfig.ignoreIndex
+        self.imageTokenIndex = modelConfig.imageTokenIndex
+        self.audioTokenId = modelConfig.audioTokenId
+        self.imageTokenId = modelConfig.imageTokenId
+        self.hiddenSize = modelConfig.hiddenSize
+        self.padTokenId = modelConfig.padTokenId
+        self.visionSoftTokensPerImage = modelConfig.visionSoftTokensPerImage
+        self.audioSoftTokensPerImage = modelConfig.audioSoftTokensPerImage
+        self.eosTokenId = modelConfig.eosTokenId
+        self.quantization = quantization
+    }
+}
+
+public class Gemma3nProcessor: UserInputProcessor {
+    private let config: Gemma3nProcessorConfiguration
+    private let tokenizer: any Tokenizer
+
+    public init(_ config: Gemma3nProcessorConfiguration, tokenizer: any Tokenizer) {
+        self.config = config
+        self.tokenizer = tokenizer
+    }
+
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {
+        var userProcessing = processing ?? UserInput.Processing()
+        let targetSize = CGSize(width: config.imageSize, height: config.imageSize)
+        userProcessing.resize = targetSize
+
+        let processedImages = try images.map { image in
+            let processedImage = MediaProcessing.apply(image, processing: userProcessing)
+            let srgbImage = MediaProcessing.inSRGBToneCurveSpace(processedImage)
+            let resizedImage = try MediaProcessing.resampleBicubic(srgbImage, to: targetSize)
+            let normalizedImage = MediaProcessing.normalize(
+                resizedImage, mean: config.imageMeanTuple, std: config.imageStdTuple)
+            return MediaProcessing.asMLXArray(normalizedImage)
+        }
+
+        let pixelValues = concatenated(processedImages)
+        return (pixelValues, THW(images.count, config.imageSize, config.imageSize))
+    }
+
+    public func prepare(input: UserInput) async throws -> LMInput {
+        // Create structured messages for Gemma3n using LIST_WITH_IMAGE_TYPE_TEXT format
+        var messages: [[String: Any]] = []
+
+        if !input.images.isEmpty {
+            // Add image and text content in the format expected by Gemma3n
+            let content: [[String: Any]] = [
+                ["type": "image"],
+                ["type": "text", "text": input.prompt.description],
+            ]
+            messages.append(["role": "user", "content": content])
+        } else {
+            // Text-only message
+            messages.append(["role": "user", "content": input.prompt.description])
+        }
+
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Process images if any
+        var processedImage: LMInput.ProcessedImage?
+
+        if !input.images.isEmpty {
+            let imagePixelsAndFrames = try input.images.map {
+                try preprocess(images: [$0.asCIImage()], processing: input.processing)
+            }
+            let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
+            processedImage = LMInput.ProcessedImage(
+                pixels: imagePixelsConcatenated,
+                frames: imagePixelsAndFrames.map { $0.1 }
+            )
+
+            // Note: Unlike Gemma3, Gemma3n doesn't expand image tokens in the processor
+            // The model handles token mapping directly in get_input_embeddings
+        }
+
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+        return LMInput(
+            text: .init(tokens: promptArray, mask: mask),
+            image: processedImage
+        )
+    }
+}
+
+public struct Gemma3nProcessorConfiguration: Codable, Sendable {
+    public let processorClass: String
+    public let imageProcessorType: String?
+    public let doNormalize: Bool
+    public let doRescale: Bool
+    public let doResize: Bool
+    public let imageMean: [CGFloat]
+    public let imageStd: [CGFloat]
+    public let visionSoftTokensPerImage: Int
+    public let resample: Int
+    public let rescaleFactor: Float
+    public let size: ImageSize
+
+    // Optional fields that may be present in some configs
+    public let doConvertRgb: Bool?
+    public let doPanAndScan: Bool?
+
+    // Token identifiers - use default values that match the Python implementation
+    public var imageTokenId: Int { 262145 }  // From Python: image_token_id = 262145
+    public var audioTokenId: Int { 262273 }  // From Python: audio_token_id = 262273
+
+    public struct ImageSize: Codable, Sendable {
+        public let height: Int
+        public let width: Int
+    }
+
+    // Computed properties for convenience
+    public var imageSize: Int { size.height }
+
+    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageMean[0], imageMean[1], imageMean[2])
+    }
+
+    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageStd[0], imageStd[1], imageStd[2])
+    }
+
+    enum CodingKeys: String, CodingKey {
+        case processorClass = "processor_class"
+        case imageProcessorType = "image_processor_type"
+        case doNormalize = "do_normalize"
+        case doRescale = "do_rescale"
+        case doResize = "do_resize"
+        case doConvertRgb = "do_convert_rgb"
+        case doPanAndScan = "do_pan_and_scan"
+        case imageMean = "image_mean"
+        case imageStd = "image_std"
+        case visionSoftTokensPerImage = "vision_soft_tokens_per_image"
+        case resample
+        case rescaleFactor = "rescale_factor"
+        case size
+    }
+}
+
+extension Gemma3n {
+    public convenience init(_ config: Gemma3nConfiguration) {
+        let modelConfig = ModelConfig(
+            textConfig: config.textConfig,
+            visionConfig: config.visionConfig,
+            audioConfig: config.audioConfig,
+            modelType: config.modelType,
+            vocabSize: config.vocabSize,
+            ignoreIndex: config.ignoreIndex,
+            imageTokenIndex: config.imageTokenIndex,
+            audioTokenId: config.audioTokenId,
+            imageTokenId: config.imageTokenId,
+            hiddenSize: config.hiddenSize,
+            padTokenId: config.padTokenId,
+            visionSoftTokensPerImage: config.visionSoftTokensPerImage,
+            audioSoftTokensPerImage: config.audioSoftTokensPerImage,
+            eosTokenId: config.eosTokenId
+        )
+        self.init(modelConfig)
+    }
+}
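
A note on the arithmetic behind the new `vocabOffset` default and the assertions added in the embedding hunk above: the embedding table is laid out as text ids, then vision ids, then audio ids, with no gaps. The standalone Swift sketch below illustrates that layout. The per-modality sizes (262_144 text tokens, 128-entry vision and audio vocabularies, 262_400 total) are assumptions for illustration, taken from the `262_144 + 128` comment and the hard-coded token ids in this diff rather than from any config struct shown here.

// Standalone sketch of the embedding-table layout the assertions check.
// The sizes below are illustrative assumptions; the real values come from
// text_config / vision_config / audio_config.
let textVocabSize = 262_144
let visionVocabSize = 128
let audioVocabSize = 128
let totalVocabSize = textVocabSize + visionVocabSize + audioVocabSize  // 262_400

// Vision rows start right after the text rows, audio rows right after vision.
let visionVocabOffset = totalVocabSize - audioVocabSize - visionVocabSize  // 262_144
let audioVocabOffset = totalVocabSize - audioVocabSize  // 262_272

// Same conditions as the asserts in the diff: no gaps between the three ranges.
assert(visionVocabOffset == textVocabSize)
assert(audioVocabOffset == textVocabSize + visionVocabSize)
print(visionVocabOffset, audioVocabOffset)  // 262144 262272

With this layout, the processor's hard-coded placeholder ids (262145 for images, 262273 for audio) each land one slot inside the vision and audio ranges respectively.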
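And a rough sketch of how the new `Gemma3nProcessorConfiguration` is meant to be populated: it is plain `Codable` with explicit snake_case `CodingKeys`, so a `preprocessor_config.json` decodes directly with `JSONDecoder`. The sketch assumes the type defined above is in scope; the concrete values (768×768 size, resample mode, mean/std, rescale factor, soft token count) are illustrative placeholders, not values taken from this PR.

import Foundation

// Illustrative preprocessor_config.json fragment; key names follow the
// CodingKeys on Gemma3nProcessorConfiguration above, values are placeholders.
let json = """
    {
      "processor_class": "Gemma3nProcessor",
      "image_processor_type": "Gemma3nImageProcessor",
      "do_normalize": false,
      "do_rescale": true,
      "do_resize": true,
      "image_mean": [0.5, 0.5, 0.5],
      "image_std": [0.5, 0.5, 0.5],
      "vision_soft_tokens_per_image": 256,
      "resample": 2,
      "rescale_factor": 0.00392156862745098,
      "size": { "height": 768, "width": 768 }
    }
    """

do {
    let config = try JSONDecoder().decode(
        Gemma3nProcessorConfiguration.self, from: Data(json.utf8))
    print(config.imageSize)  // 768, derived from size.height
    print(config.imageMeanTuple)  // (0.5, 0.5, 0.5)
    print(config.imageTokenId)  // 262145, hard-coded to match the Python processor
} catch {
    print("Failed to decode processor configuration: \(error)")
}

Because `imageTokenId` and `audioTokenId` are computed properties rather than decoded fields, they stay fixed at the Python defaults regardless of what the JSON contains.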