Skip to content

Commit 96cde3f

Browse files
authored
change method visibility to public. add some center crop / size related methods (#263)
1 parent e1d3222 commit 96cde3f

File tree

1 file changed

+74
-9
lines changed

1 file changed

+74
-9
lines changed

Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,32 @@ public enum MediaProcessing {
3434
/// sRGB primaries, etc.
3535
///
3636
/// See ``inLinearToneCurveSpace(_:)``
37-
static public func inSRGBToneCurveSpace(_ image: CIImage) -> CIImage {
37+
public static func inSRGBToneCurveSpace(_ image: CIImage) -> CIImage {
    // Feed the image through Core Image's built-in linear -> sRGB tone curve filter.
    let toneCurve = CIFilter.linearToSRGBToneCurve()
    toneCurve.inputImage = image

    // The force-unwrap is kept from the original: an input image is always set above.
    return toneCurve.outputImage!
}
4242

4343
/// Inverse of ``inSRGBToneCurveSpace(_:)`` (for completeness).
44-
static public func inLinearToneCurveSpace(_ image: CIImage) -> CIImage {
44+
public static func inLinearToneCurveSpace(_ image: CIImage) -> CIImage {
    // Feed the image through Core Image's built-in sRGB -> linear tone curve filter.
    let toneCurve = CIFilter.sRGBToneCurveToLinear()
    toneCurve.inputImage = image

    // The force-unwrap is kept from the original: an input image is always set above.
    return toneCurve.outputImage!
}
4949

5050
/// Compute the best fit size of one size in another (respecting aspect ratio).
51-
static public func bestFit(_ size: CGSize, in other: CGSize) -> CGSize {
51+
public static func bestFit(_ size: CGSize, in other: CGSize) -> CGSize {
    // One uniform factor for both axes keeps the aspect ratio intact.
    let factor = bestFitScale(size, in: other)

    // Round each scaled dimension to the nearest whole value.
    let fittedWidth = (size.width * factor).rounded()
    let fittedHeight = (size.height * factor).rounded()
    return CGSize(width: fittedWidth, height: fittedHeight)
}
5555

5656
/// Compute the best fit scale of one size in another (respecting aspect ratio).
57-
static public func bestFitScale(_ size: CGSize, in other: CGSize) -> CGFloat {
57+
public static func bestFitScale(_ size: CGSize, in other: CGSize) -> CGFloat {
    // Scale factor needed along each axis for `size` to match `other`.
    let horizontalScale = other.width / size.width
    let verticalScale = other.height / size.height

    // The smaller factor is the one that keeps both dimensions inside `other`.
    return min(horizontalScale, verticalScale)
}
6060

6161
/// Resample the image using bicubic interpolation.
62-
static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
62+
public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
6363
let filter = CIFilter.bicubicScaleTransform()
6464
let extent = image.extent.size
6565

@@ -88,7 +88,7 @@ public enum MediaProcessing {
8888
}
8989

9090
/// Normalize the image using the given mean and standard deviation parameters.
91-
static public func normalize(
91+
public static func normalize(
9292
_ image: CIImage, mean: (CGFloat, CGFloat, CGFloat), std: (CGFloat, CGFloat, CGFloat)
9393
) -> CIImage {
9494
let filter = CIFilter.colorMatrix()
@@ -113,7 +113,7 @@ public enum MediaProcessing {
113113
}
114114

115115
/// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`
116-
static public func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray {
116+
public static func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray {
117117
let size = image.extent.size
118118
let w = Int(size.width.rounded())
119119
let h = Int(size.height.rounded())
@@ -144,8 +144,72 @@ public enum MediaProcessing {
144144
return array
145145
}
146146

147+
/// Return `true` if `extent` fits within `size`, i.e. both the width and the height
/// of `extent` are less than or equal to the corresponding dimension of `size`.
public static func rectSmallerOrEqual(_ extent: CGRect, size: CGSize) -> Bool {
    let fitsHorizontally = extent.width <= size.width
    let fitsVertically = extent.height <= size.height
    return fitsHorizontally && fitsVertically
}
151+
152+
/// Given an `extent` and a target `size` produce the `CGRect` that will be a center crop.
///
/// The crop is never larger than `extent`: each target dimension is clamped to the
/// matching dimension of `extent`. The returned rectangle is centered inside `extent`,
/// including when the origin of `extent` is not `(0, 0)`.
///
/// - Parameters:
///   - extent: The rectangle to crop from.
///   - size: The requested crop size.
/// - Returns: A rectangle of at most `size`, centered within `extent`.
public static func centerCrop(_ extent: CGRect, size: CGSize) -> CGRect {
    let targetWidth = min(extent.width, size.width)
    let targetHeight = min(extent.height, size.height)

    // Offset from the extent's own origin by half of the leftover space on each axis.
    // Using `maxX / 2` here would only be centered for extents anchored at the origin.
    return CGRect(
        x: extent.minX + (extent.width - targetWidth) / 2,
        y: extent.minY + (extent.height - targetHeight) / 2,
        width: targetWidth,
        height: targetHeight
    )
}
163+
164+
/// Given an `image` and a target `size` produce the `CIImage` that will be a center crop.
///
/// The image is returned unchanged when its extent is already no larger than `size`
/// in both dimensions. Otherwise it is cropped to the center rectangle and translated
/// so that the crop rectangle's origin lands at `(0, 0)`.
public static func centerCrop(_ image: CIImage, size: CGSize) -> CIImage {
    let extent = image.extent

    // Nothing to trim when the image already fits inside the requested size.
    guard !rectSmallerOrEqual(extent, size: size) else {
        return image
    }

    let cropRect = centerCrop(extent, size: size)
    let cropped = image.cropped(to: cropRect)

    // Shift the cropped region back so that it starts at the origin.
    let toOrigin = CGAffineTransform(translationX: -cropRect.minX, y: -cropRect.minY)
    return cropped.transformed(by: toOrigin)
}
177+
178+
/// Given a `size` and a target `shortestEdge` compute a new size whose shorter
/// side equals `shortestEdge` and whose longer side is scaled by the same ratio,
/// so the aspect ratio of the original `size` is respected.
///
/// - Parameters:
///   - size: The original size.
///   - shortestEdge: The length the shorter side of the result will have.
/// - Returns: The rescaled size, with width and height in the same orientation as `size`.
public static func fitIn(_ size: CGSize, shortestEdge: Int) -> CGSize {
    let target = CGFloat(shortestEdge)

    if size.width <= size.height {
        // Width is the short side: pin it and scale the height.
        return CGSize(width: target, height: target * size.height / size.width)
    } else {
        // Height is the short side: pin it and scale the width.
        return CGSize(width: target * size.width / size.height, height: target)
    }
}
192+
193+
/// Given a `size` and a target `longestEdge` compute a new size
/// that respects the aspect ratio of the original `size` and is
/// constrained by the `longestEdge`.
///
/// If the longer side of `size` is already no greater than `longestEdge`, the size is
/// returned unchanged. Otherwise the longer side is set to `longestEdge` and the
/// shorter side is scaled down by the same ratio.
///
/// - Parameters:
///   - size: The original size.
///   - longestEdge: The maximum length allowed for the longer side.
/// - Returns: The constrained size, with width and height in the same orientation as `size`.
public static func fitIn(_ size: CGSize, longestEdge: Int) -> CGSize {
    let floatLongestEdge = CGFloat(longestEdge)
    let widthIsShort = size.width <= size.height

    var (newShort, newLong) =
        widthIsShort ? (size.width, size.height) : (size.height, size.width)

    if newLong > floatLongestEdge {
        // Scale the short edge using the ORIGINAL long edge before replacing it;
        // doing it after the assignment cancels the ratio and leaves the short edge unscaled.
        newShort = floatLongestEdge * newShort / newLong
        newLong = floatLongestEdge
    }

    return widthIsShort
        ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
}
210+
147211
/// Apply `UserInput.Processing`, if needed, to the image.
148-
static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
212+
public static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
149213
var image = image
150214

151215
if let resize = processing?.resize {
@@ -156,7 +220,8 @@ public enum MediaProcessing {
156220
return image
157221
}
158222

159-
static func asCIImageSequence(_ asset: AVAsset, samplesPerSecond: Int) async throws -> [CIImage]
223+
public static func asCIImageSequence(_ asset: AVAsset, samplesPerSecond: Int) async throws
224+
-> [CIImage]
160225
{
161226
// Use AVAssetImageGenerator to extract frames
162227
let generator = AVAssetImageGenerator(asset: asset)

0 commit comments

Comments
 (0)