Skip to content

Commit b8338f7

Browse files
authored
mlx.nn.conv implementation (#20792)
1 parent 0ad4c78 commit b8338f7

File tree

1 file changed

+63
-1
lines changed
  • keras/src/backend/mlx

1 file changed

+63
-1
lines changed

keras/src/backend/mlx/nn.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import mlx.core as mx
22
import mlx.nn as nn
33

4+
from keras.src.backend import standardize_data_format
45
from keras.src.backend import standardize_dtype
56
from keras.src.backend.config import epsilon
67
from keras.src.backend.mlx.core import convert_to_tensor
78
from keras.src.backend.mlx.core import to_mlx_dtype
89
from keras.src.backend.mlx.numpy import clip
10+
from keras.src.utils.argument_validation import standardize_tuple
911

1012

1113
def relu(x):
@@ -129,7 +131,67 @@ def conv(
129131
data_format=None,
130132
dilation_rate=1,
131133
):
132-
raise NotImplementedError("MLX backend doesn't support conv yet")
134+
inputs = convert_to_tensor(inputs)
135+
kernel = convert_to_tensor(kernel)
136+
data_format = standardize_data_format(data_format)
137+
num_spatial_dims = inputs.ndim - 2
138+
139+
strides = standardize_tuple(strides, num_spatial_dims, "strides")
140+
dilation_rate = standardize_tuple(
141+
dilation_rate, num_spatial_dims, "dilation_rate"
142+
)
143+
144+
if data_format == "channels_first":
145+
# mlx expects channels_last
146+
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
147+
148+
# mlx expects kernel with (out_channels, spatial..., in_channels)
149+
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
150+
151+
if padding == "valid":
152+
mlx_padding = 0
153+
elif padding == "same":
154+
kernel_spatial_shape = kernel.shape[1:-1]
155+
start_paddings = []
156+
end_paddings = []
157+
for dim_size, k_size, d_rate, s in zip(
158+
inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
159+
):
160+
out_size = (dim_size + s - 1) // s
161+
effective_k_size = (k_size - 1) * d_rate + 1
162+
total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
163+
pad_start = total_pad // 2
164+
pad_end = total_pad - pad_start
165+
start_paddings.append(pad_start)
166+
end_paddings.append(pad_end)
167+
mlx_padding = (start_paddings, end_paddings)
168+
else:
169+
raise ValueError(f"Invalid padding value: {padding}")
170+
171+
channels = inputs.shape[-1]
172+
kernel_in_channels = kernel.shape[-1]
173+
if channels % kernel_in_channels > 0:
174+
raise ValueError(
175+
"The number of input channels must be evenly divisible by "
176+
f"kernel's in_channels. Received input channels {channels} and "
177+
f"kernel in_channels {kernel_in_channels}. "
178+
)
179+
groups = channels // kernel_in_channels
180+
181+
result = mx.conv_general(
182+
inputs,
183+
kernel,
184+
stride=strides,
185+
padding=mlx_padding,
186+
kernel_dilation=dilation_rate,
187+
input_dilation=1,
188+
groups=groups,
189+
flip=False,
190+
)
191+
if data_format == "channels_first":
192+
result = result.transpose(0, -1, *range(1, result.ndim - 1))
193+
194+
return result
133195

134196

135197
def depthwise_conv(

0 commit comments

Comments (0)