
Commit 788fffa

Fix media downsampling
1 parent e2ef119 commit 788fffa


2 files changed: +41 -16 lines changed


Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 15 additions & 13 deletions
@@ -15,7 +15,7 @@ private let context = CIContext()
 /// var image: CIImage
 /// image = MediaProcessing.inSRGBToneCurveSpace(image)
 ///
-/// // apply user instructions
+/// // Apply user instructions
 /// image = MediaProcessing.apply(image, processing: processing)
 ///
 /// image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
@@ -59,6 +59,10 @@ public enum MediaProcessing {
     }
 
     /// Resample the image using bicubic interpolation.
+    /// - Parameters:
+    ///   - image: The image to resample
+    ///   - size: The target size
+    /// - Returns: The resampled image
     static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
         let filter = CIFilter.bicubicScaleTransform()
         let extent = image.extent.size
@@ -70,19 +74,13 @@ public enum MediaProcessing {
         let desiredAspectRatio = size.width / size.height
         filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio)
 
-        // that image is now the aspect ratio of the target and the size
-        // of the shorter dimension
-        let scale: CGFloat
-        if extent.width < extent.height {
-            scale = size.width / extent.width
-        } else {
-            scale = size.height / extent.height
-        }
+        // Use the same scaling approach regardless of orientation
+        let scale = min(size.width / extent.width, size.height / extent.height)
         filter.scale = Float(scale)
 
         let rescaled = filter.outputImage!
 
-        // the image has a DoD larger than the requested size so crop
+        // The image has a DoD larger than the requested size, so crop
         // it to the desired size
         return rescaled.cropped(to: CGRect(origin: .zero, size: size))
     }
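
The core of this hunk is the scale selection. A minimal standalone sketch of the before/after behavior (the oldScale/newScale helper names are hypothetical, not part of MediaProcessing) shows where the orientation branch went wrong: for a portrait input and a landscape target it produced an upscaling factor, while taking the min of the two per-axis ratios always fits the target.

import CoreGraphics

// Pre-fix: branch on orientation and scale by the shorter input dimension only
func oldScale(extent: CGSize, target: CGSize) -> CGFloat {
    extent.width < extent.height
        ? target.width / extent.width
        : target.height / extent.height
}

// Post-fix: a uniform "fit" scale, independent of orientation
func newScale(extent: CGSize, target: CGSize) -> CGFloat {
    min(target.width / extent.width, target.height / extent.height)
}

let extent = CGSize(width: 100, height: 200)  // portrait input
let target = CGSize(width: 224, height: 112)  // landscape target

print(oldScale(extent: extent, target: target))  // 2.24: upscales a portrait image
print(newScale(extent: extent, target: target))  // 0.56: fits within the target
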
@@ -94,7 +92,7 @@ public enum MediaProcessing {
         let filter = CIFilter.colorMatrix()
         filter.inputImage = image
 
-        // this should match
+        // This should match
         // https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html
         //
         // output[channel] = (input[channel] - mean[channel]) / std[channel]
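
The linked torchvision formula is affine per channel, which is why a color matrix can express it: (x - mean) / std = x * (1/std) + (-mean/std). A sketch of that mapping follows; the mean/std values are the common ImageNet defaults and serve only as an example, since the real values come from each model's preprocessor configuration.

import CoreImage
import CoreImage.CIFilterBuiltins

let mean: [CGFloat] = [0.485, 0.456, 0.406]  // example values, ImageNet defaults
let std: [CGFloat] = [0.229, 0.224, 0.225]

let filter = CIFilter.colorMatrix()
// Diagonal terms scale each channel by 1/std
filter.rVector = CIVector(x: 1 / std[0], y: 0, z: 0, w: 0)
filter.gVector = CIVector(x: 0, y: 1 / std[1], z: 0, w: 0)
filter.bVector = CIVector(x: 0, y: 0, z: 1 / std[2], w: 0)
// Bias term subtracts mean/std, completing (x - mean) / std
filter.biasVector = CIVector(
    x: -mean[0] / std[0], y: -mean[1] / std[1], z: -mean[2] / std[2])
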
@@ -113,6 +111,10 @@ public enum MediaProcessing {
     }
 
     /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`
+    /// - Parameters:
+    ///   - image: The image to convert
+    ///   - colorSpace: Optional color space for rendering
+    /// - Returns: The MLXArray representation of the image
     static public func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray {
         let size = image.extent.size
         let w = Int(size.width.rounded())
@@ -135,10 +137,10 @@ public enum MediaProcessing {
 
         var array = MLXArray(data, [h, w, 4], type: Float32.self)
 
-        // drop 4th channel
+        // Drop 4th channel
         array = array[0..., 0..., ..<3]
 
-        // convert to 1, C, H, W
+        // Convert to 1, C, H, W
         array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2)
 
         return array
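
The tail of asMLXArray is a pure layout change. A toy walk-through on synthetic data makes the shapes explicit (this assumes the array-literal MLXArray initializer from MLX Swift):

import MLX

let h = 2
let w = 2
// 2 × 2 RGBA pixels in interleaved [H, W, 4] layout
var array = MLXArray((0 ..< h * w * 4).map { Float($0) }, [h, w, 4])

array = array[0..., 0..., ..<3]  // [H, W, 3]: drop the alpha channel
array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2)  // [1, C, H, W]

print(array.shape)  // [1, 3, 2, 2]
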

Libraries/MLXVLM/Models/QwenVL.swift

Lines changed: 26 additions & 3 deletions
@@ -123,20 +123,25 @@ public struct QwenVL {
     {
         if height < factor {
             throw VLMError.imageProcessingFailure(
-                "height: \(height) must be larger than factor: \(factor)")
+                "Height: \(height) must be larger than factor: \(factor)")
         }
         if width < factor {
             throw VLMError.imageProcessingFailure(
-                "width: \(width) must be larger than factor: \(factor)")
+                "Width: \(width) must be larger than factor: \(factor)")
         }
         if max(height, width) / min(height, width) > 200 {
             throw VLMError.imageProcessingFailure(
-                "absolute aspect ratio must be smaller than 200: \(width)x\(height)")
+                "Absolute aspect ratio must be smaller than 200: \(width) × \(height)")
         }
 
+        // Maximum allowed dimension for any single side to prevent buffer overflows
+        // This is important for portrait/landscape images with extreme aspect ratios
+        let maxDimension = 224
+
         var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
         var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
 
+        // Start by scaling based on total pixel count
         if hBar * wBar > maxPixels {
             let beta = sqrt(Float(height * width) / Float(maxPixels))
             hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
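
Here beta is the uniform shrink factor sqrt(h * w / maxPixels), so dividing both sides by beta brings the rounded product under the pixel budget. A worked check, where the maxPixels value (1280 patches of 28 × 28) is illustrative rather than taken from this diff:

import Foundation

let (height, width, factor) = (1960, 2940, 28)
let maxPixels = 1_003_520  // 1280 * 28 * 28, an example budget

let beta = (Float(height * width) / Float(maxPixels)).squareRoot()
let hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
let wBar = Int(floor(Float(width) / beta / Float(factor))) * factor

print(beta, hBar, wBar, hBar * wBar <= maxPixels)
// beta ≈ 2.40, giving 812 × 1204 = 977,648 pixels: under the budget
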
@@ -146,6 +151,24 @@ public struct QwenVL {
             hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
             wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
         }
+
+        // Additionally check if either dimension exceeds the maximum allowed
+        if hBar > maxDimension {
+            // Calculate how much we need to scale down height
+            let scale = Float(maxDimension) / Float(hBar)
+            // Apply that scale to both dimensions to maintain aspect ratio
+            hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor
+            wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor
+        }
+
+        if wBar > maxDimension {
+            // Calculate how much we need to scale down width
+            let scale = Float(maxDimension) / Float(wBar)
+            // Apply that scale to both dimensions to maintain aspect ratio
+            hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor
+            wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor
+        }
+
         return (hBar, wBar)
     }
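
To see the two clamps interact, here is a standalone sketch of just the rounding-and-clamping arithmetic from this hunk (the maxPixels/minPixels branches are omitted), run on a 1:3 image: the height cap fires first and rescales (448, 1344) to (224, 672), then the width cap rescales again to (84, 224), leaving both sides at or under maxDimension while roughly preserving the aspect ratio.

import Foundation

func clampedSize(height: Int, width: Int, factor: Int, maxDimension: Int) -> (Int, Int) {
    // Round each side to the nearest multiple of factor
    var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
    var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)

    // Rescale both sides together whenever one exceeds the cap
    if hBar > maxDimension {
        let scale = Float(maxDimension) / Float(hBar)
        hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor
        wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor
    }
    if wBar > maxDimension {
        let scale = Float(maxDimension) / Float(wBar)
        hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor
        wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor
    }
    return (hBar, wBar)
}

print(clampedSize(height: 448, width: 1344, factor: 28, maxDimension: 224))
// (84, 224)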
