
Commit e2e2288

mlx - more conv (#20807)
* depthwise_conv implementation
* implementation
* clean and implemented
1 parent b8338f7 commit e2e2288

2 files changed: +217 -22 lines changed

keras/src/backend/common/backend_utils.py

Lines changed: 82 additions & 0 deletions
@@ -187,6 +187,88 @@ def compute_conv_transpose_padding_args_for_torch(
     return torch_paddings, torch_output_paddings
 
 
+def _convert_conv_tranpose_padding_args_from_keras_to_mlx(
+    kernel_size, stride, dilation_rate, padding, output_padding
+):
+    effective_k_size = (kernel_size - 1) * dilation_rate + 1
+    if padding == "valid":
+        output_padding = (
+            max(effective_k_size, stride) - effective_k_size
+            if output_padding is None
+            else output_padding
+        )
+        pad_left = effective_k_size - 1
+        pad_right = effective_k_size - 1 + output_padding
+    elif padding == "same":
+        if output_padding is None:
+            total_pad = stride + effective_k_size - 2
+        else:
+            total_pad = (
+                effective_k_size + effective_k_size % 2 - 2 + output_padding
+            )
+        pad_left = min(total_pad // 2 + total_pad % 2, effective_k_size - 1)
+        pad_right = total_pad - pad_left
+    else:
+        raise ValueError(f"Invalid padding value: {padding}")
+    return pad_left, pad_right
+
+
+def compute_conv_transpose_padding_args_for_mlx(
+    padding,
+    num_spatial_dims,
+    kernel_spatial_shape,
+    dilation_rate,
+    strides,
+    output_padding,
+):
+    start_paddings = []
+    end_paddings = []
+    for i in range(num_spatial_dims):
+        kernel_size_i = kernel_spatial_shape[i]
+        stride_i = strides[i]
+        dilation_rate_i = dilation_rate[i]
+        output_padding_i = None if output_padding is None else output_padding[i]
+        pad_left, pad_right = (
+            _convert_conv_tranpose_padding_args_from_keras_to_mlx(
+                kernel_size_i,
+                stride_i,
+                dilation_rate_i,
+                padding,
+                output_padding_i,
+            )
+        )
+        start_paddings.append(pad_left)
+        end_paddings.append(pad_right)
+    return (start_paddings, end_paddings)
+
+
+def compute_transpose_padding_args_for_mlx(
+    padding,
+    input_spatial_shape,
+    kernel_spatial_shape,
+    dilation_rate,
+    strides,
+):
+    if padding == "valid":
+        return 0
+    elif padding == "same":
+        start_paddings = []
+        end_paddings = []
+        for dim_size, k_size, d_rate, s in zip(
+            input_spatial_shape, kernel_spatial_shape, dilation_rate, strides
+        ):
+            out_size = (dim_size + s - 1) // s
+            effective_k_size = (k_size - 1) * d_rate + 1
+            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
+            pad_start = total_pad // 2
+            pad_end = total_pad - pad_start
+            start_paddings.append(pad_start)
+            end_paddings.append(pad_end)
+        return (start_paddings, end_paddings)
+    else:
+        raise ValueError(f"Invalid padding value: {padding}")
+
+
 def _get_output_shape_given_tf_padding(
     input_size, kernel_size, strides, padding, output_padding, dilation_rate
 ):
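For intuition, here is a hand-worked sketch of the new "same" padding helper (not part of the commit; the numbers are checked against the formulas above). With input length 10, kernel size 3, dilation 1, and stride 2, the target output size is ceil(10 / 2) = 5, the effective kernel size is (3 - 1) * 1 + 1 = 3, and the total padding is max(0, (5 - 1) * 2 + 3 - 10) = 1, split as 0 on the left and 1 on the right:

    from keras.src.backend.common.backend_utils import (
        compute_transpose_padding_args_for_mlx,
    )

    # 1D case: input length 10, kernel 3, dilation 1, stride 2
    start, end = compute_transpose_padding_args_for_mlx(
        padding="same",
        input_spatial_shape=(10,),
        kernel_spatial_shape=(3,),
        dilation_rate=(1,),
        strides=(2,),
    )
    print(start, end)  # [0] [1]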

keras/src/backend/mlx/nn.py

Lines changed: 135 additions & 22 deletions
@@ -3,6 +3,12 @@
 
 from keras.src.backend import standardize_data_format
 from keras.src.backend import standardize_dtype
+from keras.src.backend.common.backend_utils import (
+    compute_conv_transpose_padding_args_for_mlx,
+)
+from keras.src.backend.common.backend_utils import (
+    compute_transpose_padding_args_for_mlx,
+)
 from keras.src.backend.config import epsilon
 from keras.src.backend.mlx.core import convert_to_tensor
 from keras.src.backend.mlx.core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
     # mlx expects kernel with (out_channels, spatial..., in_channels)
     kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
 
-    if padding == "valid":
-        mlx_padding = 0
-    elif padding == "same":
-        kernel_spatial_shape = kernel.shape[1:-1]
-        start_paddings = []
-        end_paddings = []
-        for dim_size, k_size, d_rate, s in zip(
-            inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
-        ):
-            out_size = (dim_size + s - 1) // s
-            effective_k_size = (k_size - 1) * d_rate + 1
-            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
-            pad_start = total_pad // 2
-            pad_end = total_pad - pad_start
-            start_paddings.append(pad_start)
-            end_paddings.append(pad_end)
-        mlx_padding = (start_paddings, end_paddings)
-    else:
-        raise ValueError(f"Invalid padding value: {padding}")
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
 
     channels = inputs.shape[-1]
     kernel_in_channels = kernel.shape[-1]
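The conv change is a pure refactor: the inline "same"/"valid" branch moves into the shared helper with no behavior change. A minimal usage sketch (assuming the default channels_last data format and the standard Keras backend conv signature):

    import mlx.core as mx

    from keras.src.backend.mlx.nn import conv

    x = mx.zeros((1, 10, 3))  # (batch, width, channels)
    k = mx.zeros((3, 3, 8))   # Keras layout: (kernel_w, in_channels, out_channels)
    y = conv(x, k, strides=2, padding="same")
    print(y.shape)  # (1, 5, 8): "same" keeps ceil(10 / 2) output steps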
@@ -202,7 +198,53 @@ def depthwise_conv(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support depthwise conv yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    feature_group_count = inputs.shape[-1]
+
+    # reshape first for depthwise conv, then transpose to expected mlx format
+    kernel = kernel.reshape(
+        *iter(kernel.shape[:-2]), 1, feature_group_count * kernel.shape[-1]
+    )
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
+
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=strides,
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=1,
+        groups=feature_group_count,
+        flip=False,
+    )
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def separable_conv(
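Depthwise conv is implemented here as a grouped convolution with groups equal to the number of input channels. An illustrative shape walk-through of the kernel handling (hand-checked against the code above; the concrete shapes are made up for the example):

    import mlx.core as mx

    # Keras depthwise kernel layout: (kh, kw, in_channels, depth_multiplier)
    dw = mx.zeros((3, 3, 3, 2))
    # reshape so each output channel sees exactly one input channel...
    k = dw.reshape(3, 3, 1, 3 * 2)
    # ...then move channels into MLX's (out_channels, kh, kw, in_per_group) layout
    k = k.transpose(3, 0, 1, 2)

    x = mx.zeros((1, 8, 8, 3))  # NHWC input
    y = mx.conv_general(x, k, stride=(1, 1), padding=0, groups=3, flip=False)
    print(y.shape)  # (1, 6, 6, 6): 3 input channels * depth_multiplier 2

With groups=3, output channels 2g and 2g+1 convolve only input channel g, which is exactly the depthwise semantics.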
@@ -214,7 +256,23 @@ def separable_conv(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support separable conv yet")
+    data_format = standardize_data_format(data_format)
+    depthwise_conv_output = depthwise_conv(
+        inputs,
+        depthwise_kernel,
+        strides,
+        padding,
+        data_format,
+        dilation_rate,
+    )
+    return conv(
+        depthwise_conv_output,
+        pointwise_kernel,
+        strides=1,
+        padding="valid",
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+    )
 
 
 def conv_transpose(
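separable_conv composes the two primitives above: a depthwise pass with the caller's strides and padding, then a 1x1 pointwise conv with stride 1 and "valid" padding, since the spatial work has already happened. A shape sketch (assuming channels_last; shapes are illustrative):

    import mlx.core as mx

    from keras.src.backend.mlx.nn import separable_conv

    x = mx.zeros((1, 8, 8, 3))
    dw = mx.zeros((3, 3, 3, 2))    # depthwise: 3 channels, depth_multiplier 2
    pw = mx.zeros((1, 1, 6, 16))   # pointwise: 6 = 3 * 2 intermediate channels

    y = separable_conv(x, dw, pw, strides=1, padding="valid")
    print(y.shape)  # (1, 6, 6, 16)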
@@ -226,7 +284,62 @@ def conv_transpose(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support conv transpose yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+    if output_padding is not None:
+        output_padding = standardize_tuple(
+            output_padding, num_spatial_dims, "output_padding"
+        )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
+    kernel_spatial_shape = kernel.shape[1:-1]
+
+    mlx_padding = compute_conv_transpose_padding_args_for_mlx(
+        padding,
+        num_spatial_dims,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+        output_padding,
+    )
+
+    channels = inputs.shape[-1]
+    kernel_in_channels = kernel.shape[-1]
+    if channels % kernel_in_channels > 0:
+        raise ValueError(
+            "The number of input channels must be evenly divisible by "
+            f"kernel's in_channels. Received input channels {channels} and "
+            f"kernel in_channels {kernel_in_channels}. "
+        )
+    groups = channels // kernel_in_channels
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=1,  # stride is handled by input_dilation
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=strides,
+        groups=groups,
+        flip=True,
+    )
+
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
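conv_transpose avoids a dedicated transposed-conv kernel: it runs a regular convolution with input_dilation=strides (zeros inserted between input samples), flip=True, and the asymmetric padding computed by the new helper. A hand-worked 1D check of that helper (not part of the commit): with kernel_size=3, stride=2, dilation 1, and "valid" padding, effective_k_size = 3, output_padding defaults to max(3, 2) - 3 = 0, so pad_left = pad_right = 3 - 1 = 2. An input of length 4 dilates to (4 - 1) * 2 + 1 = 7 samples, pads to 11, and a stride-1 conv with the flipped 3-tap kernel yields 11 - 3 + 1 = 9 = (4 - 1) * 2 + 3 outputs, matching the Keras conv-transpose output size for "valid" padding.

    from keras.src.backend.common.backend_utils import (
        compute_conv_transpose_padding_args_for_mlx,
    )

    start, end = compute_conv_transpose_padding_args_for_mlx(
        padding="valid",
        num_spatial_dims=1,
        kernel_spatial_shape=(3,),
        dilation_rate=(1,),
        strides=(2,),
        output_padding=None,
    )
    print(start, end)  # [2] [2]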
