Skip to content

Commit 4970a00

Browse files
authored
mlx - new image ops implemented (#21202)
* image ops implemented * quick patch
1 parent fe4ba17 commit 4970a00

File tree

2 files changed

+358
-16
lines changed

2 files changed

+358
-16
lines changed

keras/src/backend/mlx/image.py

Lines changed: 355 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src import backend
88
from keras.src.backend.mlx.core import convert_to_tensor
99
from keras.src.backend.mlx.core import to_mlx_dtype
10+
from keras.src.backend.mlx.random import mlx_draw_seed
1011

1112

1213
def rgb_to_grayscale(images, data_format=None):
@@ -657,17 +658,55 @@ def _compute_weight_mat(
657658
)
658659

659660

660-
def elastic_transform(
661-
images,
662-
alpha=20.0,
663-
sigma=5.0,
664-
interpolation="bilinear",
665-
fill_mode="reflect",
666-
fill_value=0.0,
667-
seed=None,
668-
data_format=None,
669-
):
670-
raise NotImplementedError("elastic_transform not yet implemented in mlx.")
661+
def compute_homography_matrix(start_points, end_points):
    """Solve for the eight free coefficients of a homography.

    Follows the approach used for the jax backend: every point pair
    contributes one "x" equation and one "y" equation, and the stacked
    linear system `coefficient_matrix @ h = target_vector` is solved for `h`.

    Args:
        start_points: Array-like whose last axis holds `(x, y)` pairs, shaped
            `(..., num_points, 2)`. Converted to `float32`.
        end_points: Array-like with the same layout as `start_points`.
            Converted to `float32`.

    Returns:
        The solution of the linear system with its trailing singleton axis
        removed, i.e. shape `(..., 8)`. `mx.linalg.solve` needs a square
        system, so in practice `num_points` must be 4.
    """
    start_points = convert_to_tensor(start_points, dtype=mx.float32)
    end_points = convert_to_tensor(end_points, dtype=mx.float32)

    start_x, start_y = start_points[..., 0], start_points[..., 1]
    end_x, end_y = end_points[..., 0], end_points[..., 1]

    zeros = mx.zeros_like(end_x)
    ones = mx.ones_like(end_x)

    # One row per point pair for the x-coordinate equation.
    x_rows = mx.stack(
        [
            end_x,
            end_y,
            ones,
            zeros,
            zeros,
            zeros,
            -start_x * end_x,
            -start_x * end_y,
        ],
        axis=-1,
    )
    # One row per point pair for the y-coordinate equation.
    y_rows = mx.stack(
        [
            zeros,
            zeros,
            zeros,
            end_x,
            end_y,
            ones,
            -start_y * end_x,
            -start_y * end_y,
        ],
        axis=-1,
    )

    # Join along the row axis, counted from the end. For batched
    # `(N, 4, 2)` input this is the same as the previous `axis=1`, but it
    # also assembles a correct `(8, 8)` system for an unbatched `(4, 2)`
    # input instead of an unusable `(4, 16)` matrix.
    coefficient_matrix = mx.concatenate([x_rows, y_rows], axis=-2)

    target_vector = mx.expand_dims(
        mx.concatenate([start_x, start_y], axis=-1), axis=-1
    )

    # solve the linear system: coefficient_matrix * homography = target_vector
    # The solve is pinned to the CPU stream.
    # NOTE(review): presumably because `mx.linalg.solve` has no GPU
    # implementation -- confirm against the installed MLX version.
    with mx.stream(mx.cpu):
        homography_matrix = mx.linalg.solve(coefficient_matrix, target_vector)

    return homography_matrix.squeeze(-1)
671710

672711

673712
def perspective_transform(
@@ -678,12 +717,314 @@ def perspective_transform(
678717
fill_value=0,
679718
data_format=None,
680719
):
681-
raise NotImplementedError(
682-
"perspective_transform not yet implemented in mlx."
720+
# perspective_transform based on implementation in jax backend
721+
data_format = backend.standardize_data_format(data_format)
722+
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
723+
raise ValueError(
724+
"Invalid value for argument `interpolation`. Expected one of "
725+
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
726+
f"interpolation={interpolation}"
727+
)
728+
729+
if len(images.shape) not in (3, 4):
730+
raise ValueError(
731+
"Invalid images rank: expected rank 3 (single image) "
732+
"or rank 4 (batch of images). Received input with shape: "
733+
f"images.shape={images.shape}"
734+
)
735+
736+
if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3):
737+
raise ValueError(
738+
"Invalid start_points shape: expected (4,2) for a single image"
739+
f" or (N,4,2) for a batch. Received shape: {start_points.shape}"
740+
)
741+
if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3):
742+
raise ValueError(
743+
"Invalid end_points shape: expected (4,2) for a single image"
744+
f" or (N,4,2) for a batch. Received shape: {end_points.shape}"
745+
)
746+
if start_points.shape != end_points.shape:
747+
raise ValueError(
748+
"start_points and end_points must have the same shape."
749+
f" Received start_points.shape={start_points.shape}, "
750+
f"end_points.shape={end_points.shape}"
751+
)
752+
753+
images = convert_to_tensor(images)
754+
start_points = convert_to_tensor(start_points)
755+
end_points = convert_to_tensor(end_points)
756+
757+
need_squeeze = False
758+
if len(images.shape) == 3:
759+
images = mx.expand_dims(images, axis=0)
760+
need_squeeze = True
761+
762+
if len(start_points.shape) == 2:
763+
start_points = mx.expand_dims(start_points, axis=0)
764+
if len(end_points.shape) == 2:
765+
end_points = mx.expand_dims(end_points, axis=0)
766+
767+
if data_format == "channels_first":
768+
images = mx.transpose(images, (0, 2, 3, 1))
769+
770+
batch_size, height, width, channels = images.shape
771+
772+
transforms = compute_homography_matrix(
773+
mx.array(start_points, dtype=mx.float32),
774+
mx.array(end_points, dtype=mx.float32),
775+
)
776+
777+
x, y = mx.meshgrid(mx.arange(width), mx.arange(height), indexing="xy")
778+
grid = mx.stack(
779+
[x.flatten(), y.flatten(), mx.ones_like(x).flatten()], axis=0
683780
)
684781

782+
outputs = []
783+
for b in range(batch_size):
784+
transform = transforms[b]
785+
786+
# apply homography to grid coordinates
787+
denom = transform[6] * grid[0] + transform[7] * grid[1] + 1.0
788+
x_in = (
789+
transform[0] * grid[0] + transform[1] * grid[1] + transform[2]
790+
) / denom
791+
y_in = (
792+
transform[3] * grid[0] + transform[4] * grid[1] + transform[5]
793+
) / denom
794+
795+
coords = mx.stack([y_in, x_in], axis=0)
796+
797+
transformed = mx.zeros((height, width, channels), dtype=images.dtype)
798+
for c in range(channels):
799+
transformed_channel = map_coordinates(
800+
images[b, :, :, c],
801+
coords,
802+
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
803+
fill_mode="constant",
804+
fill_value=fill_value,
805+
).reshape(height, width)
806+
807+
transformed = transformed.at[:, :, c].add(transformed_channel)
808+
809+
outputs.append(transformed)
810+
811+
output = mx.stack(outputs, axis=0)
812+
813+
if data_format == "channels_first":
814+
output = mx.transpose(output, (0, 3, 1, 2))
815+
if need_squeeze:
816+
output = mx.squeeze(output, axis=0)
817+
818+
return output
819+
685820

686821
def gaussian_blur(
    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
):
    """Blur images with a per-channel (depthwise) Gaussian kernel.

    Written to be similar to the jax backend implementation. The 2-D kernel
    is the outer product of two normalised 1-D Gaussians and is applied to
    every channel independently with zero padding chosen so that the output
    keeps the input's spatial size.

    Args:
        images: Rank-3 (single image) or rank-4 (batch) array-like.
        kernel_size: Pair of kernel extents; each entry is cast to `int`.
            Entry 0 (with `sigma[0]`) builds the 1-D kernel that runs along
            the columns of the 2-D kernel, entry 1 (with `sigma[1]`) the one
            that runs along its rows.
            NOTE(review): confirm this axis order against the public
            docstring / other backends.
        sigma: Pair of standard deviations matching `kernel_size`.
        data_format: `"channels_last"`, `"channels_first"` or `None`;
            resolved through `backend.standardize_data_format`.

    Returns:
        The blurred images in the same layout and rank as the input.

    Raises:
        ValueError: If `images` is not rank 3 or rank 4.
    """

    def _gaussian_kernel1d(size, std, dtype):
        # Sample positions centred on zero, then normalise to unit sum.
        offsets = mx.arange(size, dtype=dtype) - (size - 1) / 2
        weights = mx.exp(-0.5 * (offsets / std) ** 2)
        return weights / mx.sum(weights)

    def _create_gaussian_kernel(kernel_size, sigma, dtype, num_channels):
        kernel1d_x = _gaussian_kernel1d(kernel_size[0], sigma[0], dtype)
        kernel1d_y = _gaussian_kernel1d(kernel_size[1], sigma[1], dtype)
        # Shape is (kernel_size[1], kernel_size[0]): rows follow the "y"
        # kernel, columns follow the "x" kernel.
        kernel2d = mx.outer(kernel1d_y, kernel1d_x)

        # mlx expects kernel with shape (C_out, spatial..., C_in)
        # for depthwise convolution with groups=C, we need (C, H, W, 1).
        # Reshape with the outer product's own extents: swapping them (as a
        # reshape to (kernel_size[0], kernel_size[1]) would) scrambles the
        # taps whenever the kernel is not square.
        rows, cols = kernel2d.shape
        kernel = kernel2d.reshape(1, rows, cols, 1)
        return mx.tile(kernel, (num_channels, 1, 1, 1))

    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    data_format = backend.standardize_data_format(data_format)
    images = convert_to_tensor(images)
    sigma = convert_to_tensor(sigma)
    dtype = images.dtype

    need_squeeze = False
    if images.ndim == 3:
        images = images[mx.newaxis, ...]
        need_squeeze = True

    # The convolution below works on channels-last data.
    if data_format == "channels_first":
        images = mx.transpose(images, (0, 2, 3, 1))

    num_channels = images.shape[-1]

    # mx.arange can only take integer input values
    kernel_size = tuple(int(k) for k in kernel_size)
    kernel = _create_gaussian_kernel(kernel_size, sigma, dtype, num_channels)
    kernel_h, kernel_w = kernel.shape[1], kernel.shape[2]

    # get padding for 'same' behavior: the total padding per axis is
    # (extent - 1), split low/high so even extents also keep the input size.
    # `mx.conv_general` takes asymmetric padding as a pair
    # (low_per_spatial_dim, high_per_spatial_dim), not per-dim (low, high).
    pad_low = [max(0, (kernel_h - 1) // 2), max(0, (kernel_w - 1) // 2)]
    pad_high = [kernel_h // 2, kernel_w // 2]

    blurred_images = mx.conv_general(
        images,
        kernel,
        stride=1,
        padding=(pad_low, pad_high),
        kernel_dilation=1,
        input_dilation=1,
        groups=num_channels,
        flip=False,
    )

    if data_format == "channels_first":
        blurred_images = mx.transpose(blurred_images, (0, 3, 1, 2))

    if need_squeeze:
        blurred_images = mx.squeeze(blurred_images, axis=0)

    return blurred_images
894+
895+
896+
def elastic_transform(
    images,
    alpha=20.0,
    sigma=5.0,
    interpolation="bilinear",
    fill_mode="reflect",
    fill_value=0.0,
    seed=None,
    data_format=None,
):
    """Warp images with a smoothed random displacement field.

    Based on the jax backend implementation. Two normal noise fields (one
    per axis) are drawn, blurred with `gaussian_blur`, scaled by `alpha` and
    added to the pixel grid; every channel of every image is then resampled
    at the displaced coordinates with `map_coordinates`.

    Args:
        images: Rank-3 (single image) or rank-4 (batch) array-like.
        alpha: Multiplier applied to the blurred displacement fields.
        sigma: Used as the `scale` of the normal noise and as the standard
            deviation of the blur; the blur kernel extent is
            `int(6 * sigma) | 1` (the `| 1` forces an odd value).
        interpolation: Key of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: Member of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Forwarded to `map_coordinates`.
        seed: Passed to `mlx_draw_seed`; a non-`None` result is split into
            one key per displacement axis, otherwise two further draws are
            requested.
        data_format: `"channels_last"`, `"channels_first"` or `None`;
            resolved through `backend.standardize_data_format`.

    Returns:
        The warped images, cast back to the input dtype, with the same rank
        and layout as the input.

    Raises:
        ValueError: On an unknown `interpolation` or `fill_mode`, or when
            `images` is not rank 3 or rank 4.
    """
    data_format = backend.standardize_data_format(data_format)
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected one of "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected one of "
            f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
        )
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    images = convert_to_tensor(images)
    alpha = convert_to_tensor(alpha)
    sigma = convert_to_tensor(sigma)
    input_dtype = images.dtype
    # Odd blur extent spanning roughly six standard deviations.
    blur_extent = int(6 * sigma) | 1
    kernel_size = (blur_extent, blur_extent)

    # Work on a batch internally; remember to drop the axis again at the end.
    was_unbatched = len(images.shape) == 3
    if was_unbatched:
        images = mx.expand_dims(images, axis=0)

    channels_last = data_format == "channels_last"
    if channels_last:
        batch_size, height, width, channels = images.shape
        channel_axis = -1
    else:
        batch_size, channels, height, width = images.shape
        channel_axis = 1

    # One random key per displacement axis.
    base_key = mlx_draw_seed(seed)
    if base_key is None:
        key_x, key_y = mlx_draw_seed(None), mlx_draw_seed(None)
    else:
        key_x, key_y = mx.random.split(base_key)

    field_shape = (batch_size, height, width)
    dx, dy = (
        mx.random.normal(
            shape=field_shape,
            loc=0.0,
            scale=sigma,
            dtype=input_dtype,
            key=key,
        )
        for key in (key_x, key_y)
    )

    def _smooth(field):
        # Blur as a one-channel image, then drop the channel axis again.
        blurred = gaussian_blur(
            mx.expand_dims(field, axis=channel_axis),
            kernel_size=kernel_size,
            sigma=(sigma, sigma),
            data_format=data_format,
        )
        return mx.squeeze(blurred, axis=channel_axis)

    dx = _smooth(dx)
    dy = _smooth(dy)

    grid_x, grid_y = mx.meshgrid(
        mx.arange(width), mx.arange(height), indexing="xy"
    )
    # Leading axis lets the grid broadcast against the per-image fields.
    distorted_x = mx.expand_dims(grid_x, axis=0) + alpha * dx
    distorted_y = mx.expand_dims(grid_y, axis=0) + alpha * dy

    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]
    transformed_images = mx.zeros_like(images)
    for c in range(channels):
        resampled = [
            map_coordinates(
                images[b, :, :, c] if channels_last else images[b, c, :, :],
                [distorted_y[b], distorted_x[b]],
                order=order,
                fill_mode=fill_mode,
                fill_value=fill_value,
            )
            for b in range(batch_size)
        ]
        channel_stack = mx.stack(resampled)
        # Adding into the zero-initialised output writes channel `c`.
        if channels_last:
            transformed_images = transformed_images.at[:, :, :, c].add(
                channel_stack
            )
        else:  # channels_first
            transformed_images = transformed_images.at[:, c, :, :].add(
                channel_stack
            )

    if was_unbatched:
        transformed_images = mx.squeeze(transformed_images, axis=0)

    return transformed_images.astype(input_dtype)

0 commit comments

Comments
 (0)