
Commit d9f78a4

Add configuration and processor

1 parent 91b0a24

2 files changed: +245 −9 lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 228 additions & 9 deletions
@@ -68,7 +68,7 @@ public struct AudioConfig: Codable, Sendable, MultimodalConfig {
         sscpConvEps: Float = 1e-3,
         rmsNormEps: Float = 1e-6,
         gradientClipping: Float = 10000000000.0,
-        vocabOffset: Int = 262272
+        vocabOffset: Int = 262272  // 262_144 + 128 (text vocab size + vision vocab size)
     ) {
         self.inputFeatSize = inputFeatSize
         self.hiddenSize = hiddenSize
@@ -1580,6 +1580,18 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
 
         var inputsEmbeds = languageModel.model.embedTokens(inputIds)
 
+        // Ensure no gaps between text, vision, and audio embeddings, in that order.
+        // This matches the Python assertion.
+        assert(
+            embedAudio.vocabOffset == config.vocabSize - config.audioConfig.vocabSize,
+            "Audio vocab offset mismatch"
+        )
+        assert(
+            embedVision.vocabOffset == config.vocabSize - config.audioConfig.vocabSize
+                - config.visionConfig.vocabSize,
+            "Vision vocab offset mismatch"
+        )
+
         // Handle vision tokens
         if pixelValues != nil {
             let visionMask = logicalAnd(
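
Note: the two assertions pin down the embedding-table layout — text ids occupy [0, 262_144), vision ids the next 128 slots, and audio ids run from 262_272 up to vocabSize. A minimal standalone sketch of that arithmetic (the audio vocab size of 128 is an assumption here; the real value comes from the checkpoint's config):

    // Hypothetical standalone check of the contiguous vocab layout
    let textVocabSize = 262_144
    let visionVocabSize = 128
    let audioVocabSize = 128  // assumption; taken from AudioConfig in a real run
    let vocabSize = textVocabSize + visionVocabSize + audioVocabSize

    let visionVocabOffset = vocabSize - audioVocabSize - visionVocabSize  // 262_144
    let audioVocabOffset = vocabSize - audioVocabSize  // 262_272, the default above
    assert(visionVocabOffset == 262_144 && audioVocabOffset == 262_272)
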
@@ -1701,12 +1713,12 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
         if let inputIds = inputIds {
             specialModalityMask = expandedDimensions(inputIds .== tokenId, axis: -1)
         } else {
-            let tokenEmbedding =
-                if modality == "audio" {
-                    embedAudio(MLXArray([tokenId]))
-                } else {
-                    languageModel.model.embedTokens(MLXArray([tokenId]))
-                }
+            // When inputIds is nil, create mask by comparing embeddings
+            let embedFn: (MLXArray) -> MLXArray =
+                modality == "audio"
+                ? { self.embedAudio($0, inputsEmbeds: nil) }
+                : { self.languageModel.model.embedTokens($0) }
+            let tokenEmbedding = embedFn(MLXArray([tokenId]))
             specialModalityMask = inputsEmbeds .== tokenEmbedding
         }

@@ -1718,8 +1730,8 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
         guard modalityTokensInText == featureTokens else {
             fatalError(
                 """
-                Number of \(modality)s does not match number of special \(modality) tokens in the input text.
-                Got \(modalityTokensInText) \(modality) tokens in the text and \(featureTokens) tokens from \(modality) embeddings.
+                Number of \(modality)s does not match number of special \(modality) tokens in the input text.
+                Got \(modalityTokensInText) \(modality) tokens in the text and \(featureTokens) tokens from \(modality) embeddings.
                """)
         }
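
Note: both sides of this guard are simple counts — special tokens found in the prompt versus feature rows produced by the encoder. An illustrative sketch of the bookkeeping (token ids and counts here are hypothetical, not the model's actual locals):

    // Hypothetical: count special image tokens in a prompt and compare with features
    let promptTokens = [2, 262_145, 262_145, 108]  // made-up ids; 262_145 is the image token
    let imageTokenId = 262_145
    let modalityTokensInText = promptTokens.filter { $0 == imageTokenId }.count
    let featureTokens = 2  // e.g. rows of projected image embeddings
    precondition(modalityTokensInText == featureTokens, "counts must match")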

@@ -3828,3 +3840,210 @@ extension Gemma3n: LoRAModel {
         }
     }
 }
+
+// MARK: - VLM Factory Configuration and Processor
+
+public struct Gemma3nConfiguration: Codable, Sendable {
+    public let textConfig: TextConfig
+    public let visionConfig: VisionConfig
+    public let audioConfig: AudioConfig
+    public let modelType: String
+    public let vocabSize: Int
+    public let ignoreIndex: Int
+    public let imageTokenIndex: Int
+    public let audioTokenId: Int
+    public let imageTokenId: Int
+    public let hiddenSize: Int
+    public let padTokenId: Int
+    public let visionSoftTokensPerImage: Int
+    public let audioSoftTokensPerImage: Int
+    public let eosTokenId: [Int]?
+    public let quantization: QuantizationConfig?
+
+    public var vocabularySize: Int { vocabSize }
+
+    enum CodingKeys: String, CodingKey {
+        case textConfig = "text_config"
+        case visionConfig = "vision_config"
+        case audioConfig = "audio_config"
+        case modelType = "model_type"
+        case vocabSize = "vocab_size"
+        case ignoreIndex = "ignore_index"
+        case imageTokenIndex = "image_token_index"
+        case audioTokenId = "audio_token_id"
+        case imageTokenId = "image_token_id"
+        case hiddenSize = "hidden_size"
+        case padTokenId = "pad_token_id"
+        case visionSoftTokensPerImage = "vision_soft_tokens_per_image"
+        case audioSoftTokensPerImage = "audio_soft_tokens_per_image"
+        case eosTokenId = "eos_token_id"
+        case quantization
+    }
+
+    public init(from modelConfig: ModelConfig, quantization: QuantizationConfig? = nil) {
+        self.textConfig = modelConfig.textConfig
+        self.visionConfig = modelConfig.visionConfig
+        self.audioConfig = modelConfig.audioConfig
+        self.modelType = modelConfig.modelType
+        self.vocabSize = modelConfig.vocabSize
+        self.ignoreIndex = modelConfig.ignoreIndex
+        self.imageTokenIndex = modelConfig.imageTokenIndex
+        self.audioTokenId = modelConfig.audioTokenId
+        self.imageTokenId = modelConfig.imageTokenId
+        self.hiddenSize = modelConfig.hiddenSize
+        self.padTokenId = modelConfig.padTokenId
+        self.visionSoftTokensPerImage = modelConfig.visionSoftTokensPerImage
+        self.audioSoftTokensPerImage = modelConfig.audioSoftTokensPerImage
+        self.eosTokenId = modelConfig.eosTokenId
+        self.quantization = quantization
+    }
+}
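
Note: Gemma3nConfiguration mirrors the existing ModelConfig field for field. The snake_case CodingKeys let the VLM factory decode it straight from a checkpoint's config.json, and the convenience initializer at the end of this file converts it back into the internal ModelConfig.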
+
+public class Gemma3nProcessor: UserInputProcessor {
+    private let config: Gemma3nProcessorConfiguration
+    private let tokenizer: any Tokenizer
+
+    public init(_ config: Gemma3nProcessorConfiguration, tokenizer: any Tokenizer) {
+        self.config = config
+        self.tokenizer = tokenizer
+    }
+
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {
+        var userProcessing = processing ?? UserInput.Processing()
+        let targetSize = CGSize(width: config.imageSize, height: config.imageSize)
+        userProcessing.resize = targetSize
+
+        let processedImages = try images.map { image in
+            let processedImage = MediaProcessing.apply(image, processing: userProcessing)
+            let srgbImage = MediaProcessing.inSRGBToneCurveSpace(processedImage)
+            let resizedImage = try MediaProcessing.resampleBicubic(srgbImage, to: targetSize)
+            let normalizedImage = MediaProcessing.normalize(
+                resizedImage, mean: config.imageMeanTuple, std: config.imageStdTuple)
+            return MediaProcessing.asMLXArray(normalizedImage)
+        }
+
+        let pixelValues = concatenated(processedImages)
+        return (pixelValues, THW(images.count, config.imageSize, config.imageSize))
+    }
+
+    public func prepare(input: UserInput) async throws -> LMInput {
+        // Create structured messages for Gemma3n using LIST_WITH_IMAGE_TYPE_TEXT format
+        var messages: [[String: Any]] = []
+
+        if !input.images.isEmpty {
+            // Add image and text content in the format expected by Gemma3n
+            let content: [[String: Any]] = [
+                ["type": "image"],
+                ["type": "text", "text": input.prompt.description],
+            ]
+            messages.append(["role": "user", "content": content])
+        } else {
+            // Text-only message
+            messages.append(["role": "user", "content": input.prompt.description])
+        }
+
+        let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Process images if any
+        var processedImage: LMInput.ProcessedImage?
+
+        if !input.images.isEmpty {
+            let imagePixelsAndFrames = try input.images.map {
+                try preprocess(images: [$0.asCIImage()], processing: input.processing)
+            }
+            let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
+            processedImage = LMInput.ProcessedImage(
+                pixels: imagePixelsConcatenated,
+                frames: imagePixelsAndFrames.map { $0.1 }
+            )
+
+            // Note: Unlike Gemma3, Gemma3n doesn't expand image tokens in the processor.
+            // The model handles token mapping directly in get_input_embeddings.
+        }
+
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+        return LMInput(
+            text: .init(tokens: promptArray, mask: mask),
+            image: processedImage
+        )
+    }
+}
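
Taken together, prepare tokenizes the chat-templated prompt and attaches the unexpanded pixel values. A hedged usage sketch — loading of the tokenizer and processor configuration is elided, and the variable names are illustrative:

    // Hypothetical call site; assumes `processorConfig`, `tokenizer`, and `image` already exist
    let processor = Gemma3nProcessor(processorConfig, tokenizer: tokenizer)
    let input = UserInput(prompt: "Describe this image.", images: [.ciImage(image)])
    let lmInput = try await processor.prepare(input: input)
    // lmInput.text.tokens has shape [1, promptLength];
    // lmInput.image?.pixels stacks one processed array per input image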
+
+public struct Gemma3nProcessorConfiguration: Codable, Sendable {
+    public let processorClass: String
+    public let imageProcessorType: String?
+    public let doNormalize: Bool
+    public let doRescale: Bool
+    public let doResize: Bool
+    public let imageMean: [CGFloat]
+    public let imageStd: [CGFloat]
+    public let visionSoftTokensPerImage: Int
+    public let resample: Int
+    public let rescaleFactor: Float
+    public let size: ImageSize
+
+    // Optional fields that may be present in some configs
+    public let doConvertRgb: Bool?
+    public let doPanAndScan: Bool?
+
+    // Token identifiers - use default values that match the Python implementation
+    public var imageTokenId: Int { 262145 }  // From Python: image_token_id = 262145
+    public var audioTokenId: Int { 262273 }  // From Python: audio_token_id = 262273
+
+    public struct ImageSize: Codable, Sendable {
+        public let height: Int
+        public let width: Int
+    }
+
+    // Computed properties for convenience
+    public var imageSize: Int { size.height }
+
+    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageMean[0], imageMean[1], imageMean[2])
+    }
+
+    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageStd[0], imageStd[1], imageStd[2])
+    }
+
+    enum CodingKeys: String, CodingKey {
+        case processorClass = "processor_class"
+        case imageProcessorType = "image_processor_type"
+        case doNormalize = "do_normalize"
+        case doRescale = "do_rescale"
+        case doResize = "do_resize"
+        case doConvertRgb = "do_convert_rgb"
+        case doPanAndScan = "do_pan_and_scan"
+        case imageMean = "image_mean"
+        case imageStd = "image_std"
+        case visionSoftTokensPerImage = "vision_soft_tokens_per_image"
+        case resample
+        case rescaleFactor = "rescale_factor"
+        case size
+    }
+}
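
For reference, this struct decodes a preprocessor_config.json shaped roughly like the fragment below; a hedged sketch where the field values are illustrative, not taken from a real checkpoint:

    // Hypothetical decode of an abbreviated preprocessor_config.json
    import Foundation

    let json = """
        {
          "processor_class": "Gemma3nProcessor",
          "do_normalize": true,
          "do_rescale": true,
          "do_resize": true,
          "image_mean": [0.5, 0.5, 0.5],
          "image_std": [0.5, 0.5, 0.5],
          "vision_soft_tokens_per_image": 256,
          "resample": 2,
          "rescale_factor": 0.00392156862745098,
          "size": { "height": 768, "width": 768 }
        }
        """
    let decoded = try JSONDecoder().decode(
        Gemma3nProcessorConfiguration.self, from: Data(json.utf8))
    print(decoded.imageSize)  // 768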
+
+extension Gemma3n {
+    public convenience init(_ config: Gemma3nConfiguration) {
+        let modelConfig = ModelConfig(
+            textConfig: config.textConfig,
+            visionConfig: config.visionConfig,
+            audioConfig: config.audioConfig,
+            modelType: config.modelType,
+            vocabSize: config.vocabSize,
+            ignoreIndex: config.ignoreIndex,
+            imageTokenIndex: config.imageTokenIndex,
+            audioTokenId: config.audioTokenId,
+            imageTokenId: config.imageTokenId,
+            hiddenSize: config.hiddenSize,
+            padTokenId: config.padTokenId,
+            visionSoftTokensPerImage: config.visionSoftTokensPerImage,
+            audioSoftTokensPerImage: config.audioSoftTokensPerImage,
+            eosTokenId: config.eosTokenId
+        )
+        self.init(modelConfig)
+    }
+}
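
Note: this convenience initializer is what lets the factory construct the model generically — the registry entry below maps "gemma3n" to create(Gemma3nConfiguration.self, Gemma3n.init), which decodes the configuration and passes it straight to this initializer.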

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 17 additions & 0 deletions
@@ -86,6 +86,7 @@ public class VLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
         "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init),
         "idefics3": create(Idefics3Configuration.self, Idefics3.init),
         "gemma3": create(Gemma3Configuration.self, Gemma3.init),
+        "gemma3n": create(Gemma3nConfiguration.self, Gemma3n.init),
         "smolvlm": create(SmolVLM2Configuration.self, SmolVLM2.init),
     ]
 }
@@ -111,6 +112,8 @@ public class VLMProcessorTypeRegistry: ProcessorTypeRegistry, @unchecked Sendable {
             Idefics3ProcessorConfiguration.self, Idefics3Processor.init),
         "Gemma3Processor": create(
             Gemma3ProcessorConfiguration.self, Gemma3Processor.init),
+        "Gemma3nProcessor": create(
+            Gemma3nProcessorConfiguration.self, Gemma3nProcessor.init),
         "SmolVLMProcessor": create(
             SmolVLMProcessorConfiguration.self, SmolVLMProcessor.init),
     ]
@@ -166,6 +169,18 @@ public class VLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         extraEOSTokens: ["<end_of_turn>"]
     )
 
+    static public let gemma3n_E2B_instruct = ModelConfiguration(
+        id: "mlx-community/gemma-3n-E2B-it-bf16",
+        defaultPrompt: "Describe this image.",
+        extraEOSTokens: ["<end_of_turn>"]
+    )
+
+    static public let gemma3n_E4B_instruct = ModelConfiguration(
+        id: "mlx-community/gemma-3n-E4B-it-bf16",
+        defaultPrompt: "Describe this image.",
+        extraEOSTokens: ["<end_of_turn>"]
+    )
+
     static public let smolvlm = ModelConfiguration(
         id: "HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx",
         defaultPrompt:
@@ -181,6 +196,8 @@ public class VLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         gemma3_4B_qat_4bit,
         gemma3_12B_qat_4bit,
         gemma3_27B_qat_4bit,
+        gemma3n_E2B_instruct,
+        gemma3n_E4B_instruct,
         smolvlm,
     ]
 }
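
With both registry entries and the model configurations in place, Gemma3n should load through the same factory path as the other VLMs. A hedged end-to-end sketch, assuming the loadContainer / perform / generate APIs used by the sibling models in this repo:

    // Hypothetical usage (API shapes assumed from the existing VLM examples)
    import MLXVLM
    import MLXLMCommon

    let container = try await VLMModelFactory.shared.loadContainer(
        configuration: VLMRegistry.gemma3n_E2B_instruct)
    let result = try await container.perform { context in
        let input = try await context.processor.prepare(
            input: UserInput(prompt: "Describe this image.", images: [.url(imageURL)]))
        return try MLXLMCommon.generate(
            input: input, parameters: GenerateParameters(), context: context
        ) { _ in .more }
    }
    print(result.output)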
