Skip to content

Commit f0e9882

Browse files
fbadineFadi Badine
andauthored
Add max_pool and average_pool for MLX (#20814)
* Added max_pool and avg_pool functionalities * Corrected and enhanced arguments formatting * Replaced printing length of arguments by printing the arguments themselves --------- Co-authored-by: Fadi Badine <fadibadine@Fadis-MacBook-Air.local>
1 parent 2bc4baf commit f0e9882

File tree

1 file changed

+161
-2
lines changed
  • keras/src/backend/mlx

1 file changed

+161
-2
lines changed

keras/src/backend/mlx/nn.py

Lines changed: 161 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import builtins
2+
import math
3+
import operator
4+
from itertools import accumulate
25

36
import mlx.core as mx
47
import mlx.nn as nn
@@ -122,16 +125,172 @@ def log_softmax(x, axis=-1):
122125
return x - mx.logsumexp(x, axis=axis, keepdims=True)
123126

124127

128+
def _calculate_padding(input_shape, pool_size, strides):
129+
ndim = len(input_shape)
130+
131+
padding = ()
132+
for d in range(ndim):
133+
pad = max(0, (pool_size[d] - 1) - ((input_shape[d] - 1) % strides[d]))
134+
padding = padding + (pad,)
135+
136+
return [(p // 2, (p + 1) // 2) for p in padding]
137+
138+
139+
def _non_overlapping_sliding_windows(x, shape, window_shape):
140+
# Compute the intermediate shape
141+
new_shape = [shape[0]]
142+
for s, w in zip(shape[1:], window_shape):
143+
new_shape.append(s // w)
144+
new_shape.append(w)
145+
new_shape.append(shape[-1])
146+
147+
last_axis = len(new_shape) - 1
148+
axis_order = [
149+
0,
150+
*range(1, last_axis, 2),
151+
*range(2, last_axis, 2),
152+
last_axis,
153+
]
154+
155+
x = x.reshape(new_shape)
156+
x = x.transpose(axis_order)
157+
return x
158+
159+
160+
def _sliding_windows(x, window_shape, window_strides):
161+
if x.ndim < 3:
162+
raise ValueError(
163+
"To extract sliding windows at least 1 spatial dimension "
164+
f"(3 total) is needed but the input only has {x.ndim} dimension(s)."
165+
)
166+
167+
spatial_dims = x.shape[1:-1]
168+
if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
169+
raise ValueError(
170+
"To extract sliding windows, the lengths of window_shape and "
171+
"window_strides must be equal to the signal's spatial dimensions. "
172+
f"However, the signal has spatial_dims={spatial_dims} while "
173+
f"window_shape={window_shape} and window_strides={window_strides}."
174+
)
175+
176+
shape = x.shape
177+
if all(
178+
window == stride and size % window == 0
179+
for size, window, stride in zip(
180+
spatial_dims, window_shape, window_strides
181+
)
182+
):
183+
return _non_overlapping_sliding_windows(x, shape, window_shape)
184+
185+
strides = list(
186+
reversed(list(accumulate(reversed(shape + (1,)), operator.mul)))
187+
)[1:]
188+
189+
# Compute the output shape
190+
final_shape = [shape[0]]
191+
final_shape += [
192+
(size - window) // stride + 1
193+
for size, window, stride in zip(
194+
spatial_dims, window_shape, window_strides
195+
)
196+
]
197+
final_shape += window_shape
198+
final_shape += [shape[-1]]
199+
200+
# Compute the output strides
201+
final_strides = strides[:1]
202+
final_strides += [
203+
og_stride * stride
204+
for og_stride, stride in zip(strides[1:-1], window_strides)
205+
]
206+
final_strides += strides[1:-1]
207+
final_strides += strides[-1:] # should always be [1]
208+
209+
return mx.as_strided(x, final_shape, final_strides)
210+
211+
212+
def _pool(
213+
inputs, pool_size, strides, padding, padding_value, data_format, pooling_fn
214+
):
215+
if padding not in ("same", "valid"):
216+
raise ValueError(
217+
f"Invalid padding '{padding}', must be 'same' or 'valid'."
218+
)
219+
220+
if data_format == "channels_first":
221+
# mlx expects channels_last
222+
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
223+
224+
if padding == "same":
225+
pads = _calculate_padding(inputs.shape[1:-1], pool_size, strides)
226+
227+
if any(p[1] > 0 for p in pads):
228+
inputs = mx.pad(
229+
inputs,
230+
[(0, 0)] + pads + [(0, 0)],
231+
constant_values=padding_value,
232+
)
233+
234+
inputs = _sliding_windows(inputs, pool_size, strides)
235+
236+
axes = tuple(range(-len(pool_size) - 1, -1, 1))
237+
result = pooling_fn(inputs, axes)
238+
239+
if data_format == "channels_first":
240+
result = result.transpose(0, -1, *range(1, result.ndim - 1))
241+
return result
242+
243+
125244
def max_pool(
126245
inputs, pool_size, strides=None, padding="valid", data_format=None
127246
):
128-
raise NotImplementedError("MLX backend doesn't support max pooling yet")
247+
inputs = convert_to_tensor(inputs)
248+
data_format = standardize_data_format(data_format)
249+
num_spatial_dims = inputs.ndim - 2
250+
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
251+
strides = pool_size if strides is None else strides
252+
strides = standardize_tuple(strides, num_spatial_dims, "strides")
253+
254+
return _pool(
255+
inputs, pool_size, strides, padding, -mx.inf, data_format, mx.max
256+
)
129257

130258

131259
def average_pool(
132260
inputs, pool_size, strides=None, padding="valid", data_format=None
133261
):
134-
raise NotImplementedError("MLX backend doesn't support average pooling yet")
262+
inputs = convert_to_tensor(inputs)
263+
data_format = standardize_data_format(data_format)
264+
num_spatial_dims = inputs.ndim - 2
265+
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
266+
strides = pool_size if strides is None else strides
267+
strides = standardize_tuple(strides, num_spatial_dims, "strides")
268+
269+
# Create a pool by applying the sum function in each window
270+
pooled = _pool(
271+
inputs, pool_size, strides, padding, 0.0, data_format, mx.sum
272+
)
273+
if padding == "valid":
274+
# No padding needed. Divide by the size of the pool which gives
275+
# the average
276+
return pooled / math.prod(pool_size)
277+
else:
278+
# Create a tensor of ones of the same shape of inputs.
279+
# Then create a pool, padding by zero and using sum as function.
280+
# This will create a tensor of the smae dimensions as pooled tensor
281+
# with values being the sum.
282+
# By dividing pooled by windows_counts, we get the average while
283+
# skipping the padded values.
284+
window_counts = _pool(
285+
mx.ones(inputs.shape, inputs.dtype),
286+
pool_size,
287+
strides,
288+
padding,
289+
0.0,
290+
data_format,
291+
mx.sum,
292+
)
293+
return pooled / window_counts
135294

136295

137296
def conv(

0 commit comments

Comments
 (0)