7
7
from keras .src import backend
8
8
from keras .src .backend .mlx .core import convert_to_tensor
9
9
from keras .src .backend .mlx .core import to_mlx_dtype
10
+ from keras .src .backend .mlx .random import mlx_draw_seed
10
11
11
12
12
13
def rgb_to_grayscale (images , data_format = None ):
@@ -657,17 +658,55 @@ def _compute_weight_mat(
657
658
)
658
659
659
660
660
- def elastic_transform (
661
- images ,
662
- alpha = 20.0 ,
663
- sigma = 5.0 ,
664
- interpolation = "bilinear" ,
665
- fill_mode = "reflect" ,
666
- fill_value = 0.0 ,
667
- seed = None ,
668
- data_format = None ,
669
- ):
670
- raise NotImplementedError ("elastic_transform not yet implemented in mlx." )
661
def compute_homography_matrix(start_points, end_points):
    """Solve for the 8 homography coefficients mapping end to start points.

    Mirrors the implementation used for the jax backend. For every point
    pair two linear equations are built (one for x, one for y) and the
    resulting 8x8 system is solved for the coefficients `h0..h7`, where

        start_x = (h0 * end_x + h1 * end_y + h2) / (h6 * end_x + h7 * end_y + 1)
        start_y = (h3 * end_x + h4 * end_y + h5) / (h6 * end_x + h7 * end_y + 1)

    Args:
        start_points: Array-like of shape `(N, 4, 2)` (batched) or `(4, 2)`
            (single set of points); converted to float32.
        end_points: Array-like with the same shape as `start_points`.

    Returns:
        Array of shape `(N, 8)` (or `(8,)` for unbatched input) holding the
        homography coefficients.
    """
    start_points = convert_to_tensor(start_points, dtype=mx.float32)
    end_points = convert_to_tensor(end_points, dtype=mx.float32)

    start_x, start_y = start_points[..., 0], start_points[..., 1]
    end_x, end_y = end_points[..., 0], end_points[..., 1]

    zeros = mx.zeros_like(end_x)
    ones = mx.ones_like(end_x)

    # One equation per point constraining the x coordinate ...
    x_rows = mx.stack(
        [
            end_x,
            end_y,
            ones,
            zeros,
            zeros,
            zeros,
            -start_x * end_x,
            -start_x * end_y,
        ],
        axis=-1,
    )
    # ... and one per point constraining the y coordinate.
    y_rows = mx.stack(
        [
            zeros,
            zeros,
            zeros,
            end_x,
            end_y,
            ones,
            -start_y * end_x,
            -start_y * end_y,
        ],
        axis=-1,
    )

    # Concatenate along the row axis. Using -2 instead of a hard-coded 1
    # gives the same result for batched (N, 4, 8) rows and also yields a
    # valid (8, 8) system for unbatched (4, 8) rows.
    coefficient_matrix = mx.concatenate([x_rows, y_rows], axis=-2)

    target_vector = mx.expand_dims(
        mx.concatenate([start_x, start_y], axis=-1), axis=-1
    )

    # solve the linear system: coefficient_matrix * homography = target_vector
    # NOTE(review): the solve is pinned to the CPU stream, presumably because
    # mx.linalg.solve is not available on the GPU stream -- confirm.
    with mx.stream(mx.cpu):
        homography_matrix = mx.linalg.solve(coefficient_matrix, target_vector)

    return homography_matrix.squeeze(-1)
671
710
672
711
673
712
def perspective_transform (
@@ -678,12 +717,314 @@ def perspective_transform(
678
717
fill_value = 0 ,
679
718
data_format = None ,
680
719
):
681
- raise NotImplementedError (
682
- "perspective_transform not yet implemented in mlx."
720
+ # perspective_transform based on implementation in jax backend
721
+ data_format = backend .standardize_data_format (data_format )
722
+ if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS .keys ():
723
+ raise ValueError (
724
+ "Invalid value for argument `interpolation`. Expected one of "
725
+ f"{ set (AFFINE_TRANSFORM_INTERPOLATIONS .keys ())} . Received: "
726
+ f"interpolation={ interpolation } "
727
+ )
728
+
729
+ if len (images .shape ) not in (3 , 4 ):
730
+ raise ValueError (
731
+ "Invalid images rank: expected rank 3 (single image) "
732
+ "or rank 4 (batch of images). Received input with shape: "
733
+ f"images.shape={ images .shape } "
734
+ )
735
+
736
+ if start_points .shape [- 2 :] != (4 , 2 ) or start_points .ndim not in (2 , 3 ):
737
+ raise ValueError (
738
+ "Invalid start_points shape: expected (4,2) for a single image"
739
+ f" or (N,4,2) for a batch. Received shape: { start_points .shape } "
740
+ )
741
+ if end_points .shape [- 2 :] != (4 , 2 ) or end_points .ndim not in (2 , 3 ):
742
+ raise ValueError (
743
+ "Invalid end_points shape: expected (4,2) for a single image"
744
+ f" or (N,4,2) for a batch. Received shape: { end_points .shape } "
745
+ )
746
+ if start_points .shape != end_points .shape :
747
+ raise ValueError (
748
+ "start_points and end_points must have the same shape."
749
+ f" Received start_points.shape={ start_points .shape } , "
750
+ f"end_points.shape={ end_points .shape } "
751
+ )
752
+
753
+ images = convert_to_tensor (images )
754
+ start_points = convert_to_tensor (start_points )
755
+ end_points = convert_to_tensor (end_points )
756
+
757
+ need_squeeze = False
758
+ if len (images .shape ) == 3 :
759
+ images = mx .expand_dims (images , axis = 0 )
760
+ need_squeeze = True
761
+
762
+ if len (start_points .shape ) == 2 :
763
+ start_points = mx .expand_dims (start_points , axis = 0 )
764
+ if len (end_points .shape ) == 2 :
765
+ end_points = mx .expand_dims (end_points , axis = 0 )
766
+
767
+ if data_format == "channels_first" :
768
+ images = mx .transpose (images , (0 , 2 , 3 , 1 ))
769
+
770
+ batch_size , height , width , channels = images .shape
771
+
772
+ transforms = compute_homography_matrix (
773
+ mx .array (start_points , dtype = mx .float32 ),
774
+ mx .array (end_points , dtype = mx .float32 ),
775
+ )
776
+
777
+ x , y = mx .meshgrid (mx .arange (width ), mx .arange (height ), indexing = "xy" )
778
+ grid = mx .stack (
779
+ [x .flatten (), y .flatten (), mx .ones_like (x ).flatten ()], axis = 0
683
780
)
684
781
782
+ outputs = []
783
+ for b in range (batch_size ):
784
+ transform = transforms [b ]
785
+
786
+ # apply homography to grid coordinates
787
+ denom = transform [6 ] * grid [0 ] + transform [7 ] * grid [1 ] + 1.0
788
+ x_in = (
789
+ transform [0 ] * grid [0 ] + transform [1 ] * grid [1 ] + transform [2 ]
790
+ ) / denom
791
+ y_in = (
792
+ transform [3 ] * grid [0 ] + transform [4 ] * grid [1 ] + transform [5 ]
793
+ ) / denom
794
+
795
+ coords = mx .stack ([y_in , x_in ], axis = 0 )
796
+
797
+ transformed = mx .zeros ((height , width , channels ), dtype = images .dtype )
798
+ for c in range (channels ):
799
+ transformed_channel = map_coordinates (
800
+ images [b , :, :, c ],
801
+ coords ,
802
+ order = AFFINE_TRANSFORM_INTERPOLATIONS [interpolation ],
803
+ fill_mode = "constant" ,
804
+ fill_value = fill_value ,
805
+ ).reshape (height , width )
806
+
807
+ transformed = transformed .at [:, :, c ].add (transformed_channel )
808
+
809
+ outputs .append (transformed )
810
+
811
+ output = mx .stack (outputs , axis = 0 )
812
+
813
+ if data_format == "channels_first" :
814
+ output = mx .transpose (output , (0 , 3 , 1 , 2 ))
815
+ if need_squeeze :
816
+ output = mx .squeeze (output , axis = 0 )
817
+
818
+ return output
819
+
685
820
686
821
def gaussian_blur(
    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
):
    """Blur images channel-by-channel with a normalized 2D Gaussian kernel.

    Similar to the jax backend implementation: two 1D Gaussian kernels are
    combined with an outer product and applied as a depthwise convolution
    with "same"-style zero padding.

    Args:
        images: Rank-3 (single image) or rank-4 (batch of images) input.
        kernel_size: Pair `(size_x, size_y)`; entries are cast to `int`.
            `size_x` spans the width axis and `size_y` the height axis.
        sigma: Pair `(sigma_x, sigma_y)` of Gaussian standard deviations.
        data_format: `"channels_last"`, `"channels_first"` or `None`
            (resolved through `backend.standardize_data_format`).

    Returns:
        The blurred images, with the same rank, layout and spatial size as
        `images`.

    Raises:
        ValueError: If `images` is not rank 3 or rank 4.
    """
    # Convert first so that array-likes without a `.shape` can be validated.
    images = convert_to_tensor(images)
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    data_format = backend.standardize_data_format(data_format)
    sigma = convert_to_tensor(sigma)
    dtype = images.dtype

    need_squeeze = False
    if images.ndim == 3:
        images = images[mx.newaxis, ...]
        need_squeeze = True

    # The convolution below works on (N, H, W, C).
    if data_format == "channels_first":
        images = mx.transpose(images, (0, 2, 3, 1))

    num_channels = images.shape[-1]

    def _gaussian_kernel1d(size, std):
        # Sample the Gaussian on a grid centred on zero and normalize to 1.
        offsets = mx.arange(size, dtype=dtype) - (size - 1) / 2
        weights = mx.exp(-0.5 * (offsets / std) ** 2)
        return weights / mx.sum(weights)

    # mx.arange can only take integer input values
    kernel_size = tuple(int(k) for k in kernel_size)
    kernel_x = _gaussian_kernel1d(kernel_size[0], sigma[0])
    kernel_y = _gaussian_kernel1d(kernel_size[1], sigma[1])

    # Rows follow y (height), columns follow x (width).
    kernel2d = mx.outer(kernel_y, kernel_x)
    kernel_h, kernel_w = kernel2d.shape

    # mlx expects kernel with shape (C_out, spatial..., C_in)
    # for depthwise convolution with groups=C, we need (C, H, W, 1).
    # Reshape with the kernel's own (H, W) extents so that non-square
    # kernels are not scrambled.
    kernel = kernel2d.reshape(1, kernel_h, kernel_w, 1)
    kernel = mx.tile(kernel, (num_channels, 1, 1, 1))

    # Padding for 'same' behavior at stride 1: a total of (k - 1) per axis,
    # with the extra element on the high side when k is even.
    total_h = max(0, kernel_h - 1)
    total_w = max(0, kernel_w - 1)
    pad_lo = [total_h // 2, total_w // 2]
    pad_hi = [total_h - pad_lo[0], total_w - pad_lo[1]]

    blurred_images = mx.conv_general(
        images,
        kernel,
        stride=1,
        # conv_general takes (low_per_spatial_dim, high_per_spatial_dim).
        padding=(pad_lo, pad_hi),
        kernel_dilation=1,
        input_dilation=1,
        groups=num_channels,
        flip=False,
    )

    if data_format == "channels_first":
        blurred_images = mx.transpose(blurred_images, (0, 3, 1, 2))

    if need_squeeze:
        blurred_images = mx.squeeze(blurred_images, axis=0)

    return blurred_images
894
+
895
+
896
def elastic_transform(
    images,
    alpha=20.0,
    sigma=5.0,
    interpolation="bilinear",
    fill_mode="reflect",
    fill_value=0.0,
    seed=None,
    data_format=None,
):
    """Apply a random elastic deformation to images.

    Based on the implementation in the jax backend: random per-pixel
    displacement fields are drawn, smoothed with `gaussian_blur`, scaled by
    `alpha`, and used to resample every channel via `map_coordinates`.

    Args:
        images: Rank-3 (single image) or rank-4 (batch of images) input.
        alpha: Scale applied to the smoothed displacement fields.
        sigma: Standard deviation of the random displacements and of the
            Gaussian used to smooth them.
        interpolation: Key of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Fill value forwarded to `map_coordinates`.
        seed: Forwarded to `mlx_draw_seed` to obtain the random key.
        data_format: `"channels_last"`, `"channels_first"` or `None`
            (resolved through `backend.standardize_data_format`).

    Returns:
        The transformed images, with the same shape and dtype as `images`.

    Raises:
        ValueError: If `interpolation` or `fill_mode` is not supported, or
            if `images` is not rank 3 or rank 4.
    """
    data_format = backend.standardize_data_format(data_format)
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected one of "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected one of "
            f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
        )

    # Convert first so that array-likes without a `.shape` can be validated.
    images = convert_to_tensor(images)
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    alpha = convert_to_tensor(alpha)
    sigma = convert_to_tensor(sigma)
    input_dtype = images.dtype
    # `| 1` forces an odd kernel extent of roughly 6 * sigma.
    blur_extent = int(6 * sigma) | 1
    kernel_size = (blur_extent, blur_extent)

    need_squeeze = False
    if len(images.shape) == 3:
        images = mx.expand_dims(images, axis=0)
        need_squeeze = True

    channels_last = data_format == "channels_last"
    if channels_last:
        batch_size, height, width, channels = images.shape
        channel_axis = -1
    else:
        batch_size, channels, height, width = images.shape
        channel_axis = 1

    # Use independent keys for the x and y displacement fields.
    mlx_seed = mlx_draw_seed(seed)
    if mlx_seed is not None:
        seed_dx, seed_dy = mx.random.split(mlx_seed)
    else:
        seed_dx, seed_dy = mlx_draw_seed(None), mlx_draw_seed(None)

    def _smoothed_displacement(key):
        # Draw a random field and blur it as a single-channel image batch.
        field = mx.random.normal(
            shape=(batch_size, height, width),
            loc=0.0,
            scale=sigma,
            dtype=input_dtype,
            key=key,
        )
        field = gaussian_blur(
            mx.expand_dims(field, axis=channel_axis),
            kernel_size=kernel_size,
            sigma=(sigma, sigma),
            data_format=data_format,
        )
        return mx.squeeze(field, axis=channel_axis)

    dx = _smoothed_displacement(seed_dx)
    dy = _smoothed_displacement(seed_dy)

    x, y = mx.meshgrid(mx.arange(width), mx.arange(height), indexing="xy")
    distorted_x = mx.expand_dims(x, axis=0) + alpha * dx
    distorted_y = mx.expand_dims(y, axis=0) + alpha * dy

    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]
    warped_channels = []
    for c in range(channels):
        # (batch, height, width) view of channel `c` for either layout.
        channel_planes = images[..., c] if channels_last else images[:, c]
        warped_channels.append(
            mx.stack(
                [
                    map_coordinates(
                        channel_planes[b],
                        [distorted_y[b], distorted_x[b]],
                        order=order,
                        fill_mode=fill_mode,
                        fill_value=fill_value,
                    )
                    for b in range(batch_size)
                ]
            )
        )
    # Reassemble the channels on the axis they were taken from.
    transformed_images = mx.stack(warped_channels, axis=channel_axis)

    if need_squeeze:
        transformed_images = mx.squeeze(transformed_images, axis=0)

    return transformed_images.astype(input_dtype)
0 commit comments