diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift
index 5a73b539..634c05e2 100644
--- a/Libraries/MLXVLM/Models/Idefics3.swift
+++ b/Libraries/MLXVLM/Models/Idefics3.swift
@@ -664,46 +664,41 @@ public class Idefics3: Module, VLMModel, KVCacheDimensionProvider {
     }
 
+    /// Writes the image features into `inputs_embeds` at the image-token positions.
+    ///
+    /// Assumes batch size 1: `inputIds` is flattened and the resulting positions index
+    /// axis 1 of `inputs_embeds`. The assignment is made on `inputs_embeds` itself.
+    ///
+    /// - Parameters:
+    ///   - imageFeatures: Image features; flattened here to rows of `dim(2)` values.
+    ///   - inputs_embeds: Text token embeddings that receive the image features.
+    ///   - inputIds: Token ids searched for `config.imageTokenIndex`.
+    /// - Returns: `inputs_embeds` after the image features have been written into it.
     private func prepareInputsForMultimodal(
-        imageFeatures: MLXArray, inputs_embeds: MLXArray, inputIds: MLXArray
+        imageFeatures: MLXArray,
+        inputs_embeds: MLXArray,
+        inputIds: MLXArray
     ) -> MLXArray {
-        // inputIds shape: (1, seq_len)
-        // asArray(Int.self) -> [[Int]], take [0] to get [Int]
-        let ids: [[Int]] = [inputIds.asArray(Int.self)]
+        let imageTokenIndex = config.imageTokenIndex
 
+        // Get input IDs as array and find image positions
+        let ids: [[Int]] = [inputIds.asArray(Int.self)]
         let inputIdArray: [Int] = ids[0]
-
-        let imageTokenIndex = config.imageTokenIndex
         let imagePositions = inputIdArray.enumerated().compactMap {
             $1 == imageTokenIndex ? $0 : nil
         }
 
-        var segments = [MLXArray]()
-        var start_idx = 0
-
-        for pos in imagePositions {
-            if pos > start_idx {
-                let textSegment = inputs_embeds[0..., start_idx ..< pos, 0...]
-                if textSegment.dim(1) > 0 {
-                    segments.append(textSegment)
-                }
-            }
-            start_idx = pos + 1
-            segments.append(imageFeatures)
-        }
-
-        if start_idx < inputs_embeds.dim(1) {
-            let remain = inputs_embeds[0..., start_idx..., 0...]
-            if remain.dim(1) > 0 {
-                segments.append(remain)
-            }
-        }
-
-        var finalEmbeds = segments[0]
-        for seg in segments.dropFirst() {
-            finalEmbeds = concatenated([finalEmbeds, seg], axis: 1)
-        }
+        // Flatten the image features to rows of `visionHiddenSize` values
+        let visionHiddenSize = imageFeatures.dim(2)
+        let reshapedImageFeatures = imageFeatures.reshaped(-1, visionHiddenSize)
+
+        // Convert to same dtype as inputs_embeds (handling quantized models)
+        let typedImageFeatures = reshapedImageFeatures.asType(inputs_embeds.dtype)
+
+        // Place image features at image token positions.
+        // `[Int]` is not an `MLXArrayIndex`, so the positions are wrapped in an MLXArray.
+        inputs_embeds[0..., MLXArray(imagePositions), 0...] = typedImageFeatures
 
-        return finalEmbeds
+        return inputs_embeds
     }
 
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
@@ -794,7 +789,7 @@ public class Idefics3Processor: UserInputProcessor {
 
-    // From the Python code and default config, we know image_token_id is usually 49153.
+    // Image token id (`image_token_id`); was 49153 — TODO confirm 49190 against the model config.
     // Hardcode this since we can't pass it in or rely on it from the processor config.
-    private let imageTokenId = 49153
+    private let imageTokenId = 49190
 
     public init(
         _ config: Idefics3ProcessorConfiguration,