3
3
4
4
from keras .src .backend import standardize_data_format
5
5
from keras .src .backend import standardize_dtype
6
+ from keras .src .backend .common .backend_utils import (
7
+ compute_conv_transpose_padding_args_for_mlx ,
8
+ )
9
+ from keras .src .backend .common .backend_utils import (
10
+ compute_transpose_padding_args_for_mlx ,
11
+ )
6
12
from keras .src .backend .config import epsilon
7
13
from keras .src .backend .mlx .core import convert_to_tensor
8
14
from keras .src .backend .mlx .core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
148
154
# mlx expects kernel with (out_channels, spatial..., in_channels)
149
155
kernel = kernel .transpose (- 1 , * range (kernel .ndim - 2 ), - 2 )
150
156
151
- if padding == "valid" :
152
- mlx_padding = 0
153
- elif padding == "same" :
154
- kernel_spatial_shape = kernel .shape [1 :- 1 ]
155
- start_paddings = []
156
- end_paddings = []
157
- for dim_size , k_size , d_rate , s in zip (
158
- inputs .shape [1 :- 1 ], kernel_spatial_shape , dilation_rate , strides
159
- ):
160
- out_size = (dim_size + s - 1 ) // s
161
- effective_k_size = (k_size - 1 ) * d_rate + 1
162
- total_pad = max (0 , (out_size - 1 ) * s + effective_k_size - dim_size )
163
- pad_start = total_pad // 2
164
- pad_end = total_pad - pad_start
165
- start_paddings .append (pad_start )
166
- end_paddings .append (pad_end )
167
- mlx_padding = (start_paddings , end_paddings )
168
- else :
169
- raise ValueError (f"Invalid padding value: { padding } " )
157
+ kernel_spatial_shape = kernel .shape [1 :- 1 ]
158
+ input_spatial_shape = inputs .shape [1 :- 1 ]
159
+ mlx_padding = compute_transpose_padding_args_for_mlx (
160
+ padding ,
161
+ input_spatial_shape ,
162
+ kernel_spatial_shape ,
163
+ dilation_rate ,
164
+ strides ,
165
+ )
170
166
171
167
channels = inputs .shape [- 1 ]
172
168
kernel_in_channels = kernel .shape [- 1 ]
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Depthwise convolution implemented with grouped `mx.conv_general`.

    Each input channel is convolved with its own set of filters by setting
    `groups` to the number of input channels.

    Args:
        inputs: Input tensor of rank `num_spatial_dims + 2`.
        kernel: Kernel tensor of shape
            `(*spatial_kernel_shape, in_channels, channel_multiplier)`.
        strides: Int or tuple of ints, stride per spatial dimension.
        padding: `"valid"` or `"same"`.
        data_format: `"channels_last"` or `"channels_first"`; standardized
            via `standardize_data_format`.
        dilation_rate: Int or tuple of ints, kernel dilation per spatial dim.

    Returns:
        Output tensor in the same data format as `inputs`.
    """
    inputs = convert_to_tensor(inputs)
    kernel = convert_to_tensor(kernel)
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2

    strides = standardize_tuple(strides, num_spatial_dims, "strides")
    dilation_rate = standardize_tuple(
        dilation_rate, num_spatial_dims, "dilation_rate"
    )

    if data_format == "channels_first":
        # mlx expects channels_last
        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

    # One group per input channel makes the convolution depthwise.
    feature_group_count = inputs.shape[-1]

    # Reshape first for depthwise conv, then transpose to expected mlx
    # format. (Fix: the shape tuple unpacks directly; the previous
    # `*iter(...)` wrapper was redundant.)
    kernel = kernel.reshape(
        *kernel.shape[:-2], 1, feature_group_count * kernel.shape[-1]
    )
    # mlx expects kernel with (out_channels, spatial..., in_channels)
    kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)

    kernel_spatial_shape = kernel.shape[1:-1]
    input_spatial_shape = inputs.shape[1:-1]
    mlx_padding = compute_transpose_padding_args_for_mlx(
        padding,
        input_spatial_shape,
        kernel_spatial_shape,
        dilation_rate,
        strides,
    )

    result = mx.conv_general(
        inputs,
        kernel,
        stride=strides,
        padding=mlx_padding,
        kernel_dilation=dilation_rate,
        input_dilation=1,
        groups=feature_group_count,
        flip=False,
    )
    if data_format == "channels_first":
        # Restore the caller's channels_first layout.
        result = result.transpose(0, -1, *range(1, result.ndim - 1))

    return result
206
248
207
249
208
250
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Separable convolution: a depthwise pass followed by a pointwise conv.

    The depthwise stage applies `strides`, `padding` and `dilation_rate`;
    the pointwise stage then mixes channels with a 1x1 kernel, so it runs
    with stride 1 and `"valid"` padding.

    Args:
        inputs: Input tensor of rank `num_spatial_dims + 2`.
        depthwise_kernel: Kernel for the per-channel (depthwise) stage.
        pointwise_kernel: 1x1 kernel for the channel-mixing stage.
        strides: Int or tuple of ints, stride per spatial dimension.
        padding: `"valid"` or `"same"`.
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of ints, kernel dilation per spatial dim.

    Returns:
        Output tensor in the same data format as `inputs`.
    """
    data_format = standardize_data_format(data_format)
    spatial_out = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        spatial_out,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
218
276
219
277
220
278
def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """Transposed convolution via a fractionally-strided `mx.conv_general`.

    The transpose is realized by running a regular convolution with
    `input_dilation=strides` (zero-insertion between input elements),
    `stride=1`, and a flipped kernel (`flip=True`), with padding computed by
    `compute_conv_transpose_padding_args_for_mlx`.

    NOTE(review): the parameters before `output_padding` are reconstructed
    from the body and the standard Keras backend signature — confirm
    against the full file.

    Args:
        inputs: Input tensor of rank `num_spatial_dims + 2`.
        kernel: Kernel tensor; transposed below into mlx's
            `(out_channels, spatial..., in_channels)` layout.
        strides: Int or tuple of ints, upsampling stride per spatial dim.
        padding: `"valid"` or `"same"`.
        output_padding: Optional int or tuple of ints, extra size added to
            each spatial output dimension.
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of ints, kernel dilation per spatial dim.

    Returns:
        Output tensor in the same data format as `inputs`.

    Raises:
        ValueError: If the input channel count is not divisible by the
            kernel's `in_channels`.
    """
    inputs = convert_to_tensor(inputs)
    kernel = convert_to_tensor(kernel)
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2

    strides = standardize_tuple(strides, num_spatial_dims, "strides")
    dilation_rate = standardize_tuple(
        dilation_rate, num_spatial_dims, "dilation_rate"
    )
    if output_padding is not None:
        output_padding = standardize_tuple(
            output_padding, num_spatial_dims, "output_padding"
        )

    if data_format == "channels_first":
        # mlx expects channels_last
        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

    # mlx expects kernel with (out_channels, spatial..., in_channels)
    kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
    kernel_spatial_shape = kernel.shape[1:-1]

    mlx_padding = compute_conv_transpose_padding_args_for_mlx(
        padding,
        num_spatial_dims,
        kernel_spatial_shape,
        dilation_rate,
        strides,
        output_padding,
    )

    # Grouped-convolution support: channels may be a multiple of the
    # kernel's in_channels, but must divide evenly.
    channels = inputs.shape[-1]
    kernel_in_channels = kernel.shape[-1]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}. "
        )
    groups = channels // kernel_in_channels

    result = mx.conv_general(
        inputs,
        kernel,
        stride=1,  # stride is handled by input_dilation
        padding=mlx_padding,
        kernel_dilation=dilation_rate,
        input_dilation=strides,
        groups=groups,
        flip=True,
    )

    if data_format == "channels_first":
        # Restore the caller's channels_first layout.
        result = result.transpose(0, -1, *range(1, result.ndim - 1))

    return result
230
343
231
344
232
345
def one_hot (x , num_classes , axis = - 1 , dtype = "float32" , sparse = False ):
0 commit comments