@@ -300,20 +300,28 @@ def resize(
     size,
     interpolation="bilinear",
     antialias=False,
+    crop_to_aspect_ratio=False,
+    pad_to_aspect_ratio=False,
+    fill_mode="constant",
+    fill_value=0.0,
     data_format="channels_last",
 ):
     if antialias:
         raise NotImplementedError(
             "Antialiasing not implemented for the MLX backend"
         )
-
+    if pad_to_aspect_ratio and crop_to_aspect_ratio:
+        raise ValueError(
+            "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` "
+            "can be `True`."
+        )
     if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
         raise ValueError(
             "Invalid value for argument `interpolation`. Expected of one "
             f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
             f"interpolation={interpolation}"
         )
-
+    target_height, target_width = size
     size = tuple(size)
     image = convert_to_tensor(image)

@@ -324,6 +332,127 @@ def resize(
             f"image.shape={image.shape}"
         )

+    if crop_to_aspect_ratio:
+        shape = image.shape
+        if data_format == "channels_last":
+            height, width = shape[-3], shape[-2]
+        else:
+            height, width = shape[-2], shape[-1]
+        crop_height = int(float(width * target_height) / target_width)
+        crop_height = min(height, crop_height)
+        crop_width = int(float(height * target_width) / target_height)
+        crop_width = min(width, crop_width)
+        crop_box_hstart = int(float(height - crop_height) / 2)
+        crop_box_wstart = int(float(width - crop_width) / 2)
+        if data_format == "channels_last":
+            if len(image.shape) == 4:
+                image = image[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+            else:
+                image = image[
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+        else:
+            if len(image.shape) == 4:
+                image = image[
+                    :,
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+            else:
+                image = image[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+    elif pad_to_aspect_ratio:
+        shape = image.shape
+        batch_size = image.shape[0]
+        if data_format == "channels_last":
+            height, width, channels = shape[-3], shape[-2], shape[-1]
+        else:
+            channels, height, width = shape[-3], shape[-2], shape[-1]
+        pad_height = int(float(width * target_height) / target_width)
+        pad_height = max(height, pad_height)
+        pad_width = int(float(height * target_width) / target_height)
+        pad_width = max(width, pad_width)
+        img_box_hstart = int(float(pad_height - height) / 2)
+        img_box_wstart = int(float(pad_width - width) / 2)
+        if data_format == "channels_last":
+            if len(image.shape) == 4:
+                padded_img = (
+                    mx.ones(
+                        (
+                            batch_size,
+                            pad_height + height,
+                            pad_width + width,
+                            channels,
+                        ),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                    :,
+                ] = image
+            else:
+                padded_img = (
+                    mx.ones(
+                        (pad_height + height, pad_width + width, channels),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                    :,
+                ] = image
+        else:
+            if len(image.shape) == 4:
+                padded_img = (
+                    mx.ones(
+                        (
+                            batch_size,
+                            channels,
+                            pad_height + height,
+                            pad_width + width,
+                        ),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                ] = image
+            else:
+                padded_img = (
+                    mx.ones(
+                        (channels, pad_height + height, pad_width + width),
+                        dtype=image.dtype,
+                    )
+                    * fill_value
+                )
+                padded_img[
+                    :,
+                    img_box_hstart : img_box_hstart + height,
+                    img_box_wstart : img_box_wstart + width,
+                ] = image
+        image = padded_img
+
     # Change to channels_last
     if data_format == "channels_first":
         image = (
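For reference, the centered-crop geometry computed in the new `crop_to_aspect_ratio` branch can be read in isolation. The snippet below is a plain-Python sketch for illustration only; the helper name `centered_crop_box` is hypothetical and not part of the diff, but it mirrors the `crop_height`/`crop_width`/`crop_box_hstart`/`crop_box_wstart` arithmetic above.

    # Plain-Python sketch of the centered-crop geometry used by the
    # crop_to_aspect_ratio branch; `centered_crop_box` is a hypothetical name.
    def centered_crop_box(height, width, target_height, target_width):
        # Largest box with the target aspect ratio that fits inside the image.
        crop_height = min(height, int(width * target_height / target_width))
        crop_width = min(width, int(height * target_width / target_height))
        # Center that box inside the original image.
        h_start = (height - crop_height) // 2
        w_start = (width - crop_width) // 2
        return h_start, w_start, crop_height, crop_width

    # A 480x640 image cropped for a square 224x224 target keeps a centered
    # 480x480 box starting at column 80.
    print(centered_crop_box(480, 640, 224, 224))  # (0, 80, 480, 480)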
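As a usage sketch (the import path and the random test image below are assumptions for illustration, not part of the diff; only the `resize` signature shown above is relied on), the new flag would be exercised roughly like this:

    import mlx.core as mx
    from keras.src.backend.mlx.image import resize  # assumed module path

    image = mx.random.uniform(shape=(480, 640, 3))  # H x W x C test image
    out = resize(
        image,
        size=(224, 224),
        interpolation="bilinear",
        crop_to_aspect_ratio=True,  # center-crop to a square before resizing
        data_format="channels_last",
    )
    print(out.shape)  # (224, 224, 3)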