
Commit 0641a1e

mlx - image.resize add crop_to_aspect_ratio (#19699)
1 parent 5368176 commit 0641a1e

File tree: 2 files changed, +139 −2 lines changed


keras/src/backend/mlx/image.py

Lines changed: 131 additions & 2 deletions
@@ -300,20 +300,28 @@ def resize(
     size,
     interpolation="bilinear",
     antialias=False,
+    crop_to_aspect_ratio=False,
+    pad_to_aspect_ratio=False,
+    fill_mode="constant",
+    fill_value=0.0,
     data_format="channels_last",
 ):
     if antialias:
         raise NotImplementedError(
             "Antialiasing not implemented for the MLX backend"
         )
-
+    if pad_to_aspect_ratio and crop_to_aspect_ratio:
+        raise ValueError(
+            "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` "
+            "can be `True`."
+        )
     if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
         raise ValueError(
             "Invalid value for argument `interpolation`. Expected of one "
             f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
             f"interpolation={interpolation}"
         )
-
+    target_height, target_width = size
     size = tuple(size)
     image = convert_to_tensor(image)

@@ -324,6 +332,127 @@ def resize(
             f"image.shape={image.shape}"
         )

+    if crop_to_aspect_ratio:
+        shape = image.shape
+        if data_format == "channels_last":
+            height, width = shape[-3], shape[-2]
+        else:
+            height, width = shape[-2], shape[-1]
+        crop_height = int(float(width * target_height) / target_width)
+        crop_height = min(height, crop_height)
+        crop_width = int(float(height * target_width) / target_height)
+        crop_width = min(width, crop_width)
+        crop_box_hstart = int(float(height - crop_height) / 2)
+        crop_box_wstart = int(float(width - crop_width) / 2)
+        if data_format == "channels_last":
+            if len(image.shape) == 4:
+                image = image[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+            else:
+                image = image[
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+        else:
+            if len(image.shape) == 4:
+                image = image[
+                    :,
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+            else:
+                image = image[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+    elif pad_to_aspect_ratio:
+        shape = image.shape
+        batch_size = image.shape[0]
+        if data_format == "channels_last":
+            height, width, channels = shape[-3], shape[-2], shape[-1]
+        else:
+            channels, height, width = shape[-3], shape[-2], shape[-1]
+        pad_height = int(float(width * target_height) / target_width)
+        pad_height = max(height, pad_height)
+        pad_width = int(float(height * target_width) / target_height)
+        pad_width = max(width, pad_width)
+        img_box_hstart = int(float(pad_height - height) / 2)
+        img_box_wstart = int(float(pad_width - width) / 2)
+        if data_format == "channels_last":
+            if len(image.shape) == 4:
+                padded_img = (
+                    mx.ones(
+                        (
+                            batch_size,
+                            pad_height + height,
+                            pad_width + width,
+                            channels,
+                        ),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                    :,
+                ] = image
+            else:
+                padded_img = (
+                    mx.ones(
+                        (pad_height + height, pad_width + width, channels),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                    :,
+                ] = image
+        else:
+            if len(image.shape) == 4:
+                padded_img = (
+                    mx.ones(
+                        (
+                            batch_size,
+                            channels,
+                            pad_height + height,
+                            pad_width + width,
+                        ),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                ] = image
+            else:
+                padded_img = (
+                    mx.ones(
+                        (channels, pad_height + height, pad_width + width),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                ] = image
+        image = padded_img
+
     # Change to channels_last
     if data_format == "channels_first":
         image = (
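
For readers following the crop branch above, here is a minimal sketch of the same center-crop box computation, assuming a channels_last input. The helper name and the example sizes are illustrative only and are not part of the commit.

    # Minimal sketch of the center-crop box math used above (illustrative only).
    # Assumes an image of shape (height, width, channels).
    def crop_box_for_aspect_ratio(height, width, target_height, target_width):
        # Largest box with the target aspect ratio that still fits inside the image.
        crop_height = min(height, int(float(width * target_height) / target_width))
        crop_width = min(width, int(float(height * target_width) / target_height))
        # Center the box inside the original image.
        hstart = int(float(height - crop_height) / 2)
        wstart = int(float(width - crop_width) / 2)
        return hstart, wstart, crop_height, crop_width

    # Example: cropping a 100x80 image before resizing to 50x50 keeps an
    # 80x80 center square starting at row 10, column 0.
    print(crop_box_for_aspect_ratio(100, 80, 50, 50))  # -> (10, 0, 80, 80)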

keras/src/ops/image_test.py

Lines changed: 8 additions & 0 deletions
@@ -273,6 +273,14 @@ def test_resize(self, interpolation, antialias, data_format):
                     f"Received: interpolation={interpolation}, "
                     f"antialias={antialias}."
                 )
+        if backend.backend() == "mlx":
+            if interpolation in ["lanczos3", "lanczos5", "bicubic"]:
+                self.skipTest(
+                    f"Resizing with interpolation={interpolation} is "
+                    "not supported by the mlx backend. "
+                )
+            elif antialias:
+                self.skipTest("antialias=True not supported by mlx backend.")
         # Unbatched case
         if data_format == "channels_first":
            x = np.random.random((3, 50, 50)) * 255
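
As a usage illustration, the new flag is reachable through the public keras.ops.image.resize op. This sketch assumes the mlx backend is active (KERAS_BACKEND=mlx); the input shape and values are hypothetical.

    # Hypothetical usage sketch: exercise crop_to_aspect_ratio via the public op.
    import numpy as np
    from keras import ops

    x = np.random.random((2, 100, 80, 3)).astype("float32")  # batched, channels_last
    y = ops.image.resize(
        x,
        size=(50, 50),
        interpolation="bilinear",
        crop_to_aspect_ratio=True,  # center-crop to a 1:1 box before resizing
        data_format="channels_last",
    )
    print(y.shape)  # expected: (2, 50, 50, 3)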
