Skip to content

Commit cb431df

Browse files
Adds 3D pooling (#1526)
1 parent 61d7877 commit cb431df

File tree

3 files changed

+250
-1
lines changed

3 files changed

+250
-1
lines changed

python/mlx/nn/layers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@
7070
LayerNorm,
7171
RMSNorm,
7272
)
73-
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
73+
from mlx.nn.layers.pooling import (
74+
AvgPool1d,
75+
AvgPool2d,
76+
AvgPool3d,
77+
MaxPool1d,
78+
MaxPool2d,
79+
MaxPool3d,
80+
)
7481
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
7582
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
7683
from mlx.nn.layers.recurrent import GRU, LSTM, RNN

python/mlx/nn/layers/pooling.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,30 @@ def __init__(
158158
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
159159

160160

161+
class _Pool3d(_Pool):
162+
def __init__(
163+
self,
164+
pooling_function,
165+
padding_value,
166+
kernel_size: Union[int, Tuple[int, int, int]],
167+
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
168+
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
169+
):
170+
class_name = type(self).__name__
171+
msg = "[{}] '{}' must be an integer or a tuple containing 3 integers"
172+
kernel_size = _value_or_list(
173+
kernel_size, 3, msg.format(class_name, "kernel_size")
174+
)
175+
if stride is not None:
176+
stride = _value_or_list(stride, 3, msg.format(class_name, "stride"))
177+
else:
178+
stride = kernel_size
179+
padding = _value_or_list(padding, 3, msg.format(class_name, "padding"))
180+
padding = [(p, p) for p in padding]
181+
182+
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
183+
184+
161185
class MaxPool1d(_Pool1d):
162186
r"""Applies 1-dimensional max pooling.
163187
@@ -332,3 +356,104 @@ def __init__(
332356
padding: Optional[Union[int, Tuple[int, int]]] = 0,
333357
):
334358
super().__init__(mx.mean, 0, kernel_size, stride, padding)
359+
360+
361+
class MaxPool3d(_Pool3d):
362+
"""
363+
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
364+
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
365+
H_{out}, W_{out}, C)`, given by:
366+
367+
.. math::
368+
\begin{aligned}
369+
\text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
370+
& \text{input}(N_i, \text{stride[0]} \times d + l,
371+
\text{stride[1]} \times h + m,
372+
\text{stride[2]} \times w + n, C_j),
373+
\end{aligned}
374+
375+
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
376+
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
377+
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
378+
379+
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
380+
381+
- a single ``int`` -- in which case the same value is used for the depth,
382+
height and width axis;
383+
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
384+
for the depth axis, the second ``int`` for the height axis, and the third
385+
``int`` for the width axis.
386+
387+
Args:
388+
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
389+
stride (int or tuple(int, int, int), optional): The stride of the pooling
390+
window. Default: ``kernel_size``.
391+
padding (int or tuple(int, int, int), optional): How much negative infinity
392+
padding to apply to the input. The padding is applied on both sides
393+
of the depth, height and width axis. Default: ``0``.
394+
395+
Examples:
396+
>>> import mlx.core as mx
397+
>>> import mlx.nn.layers as nn
398+
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
399+
>>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
400+
>>> pool(x)
401+
"""
402+
403+
def __init__(
404+
self,
405+
kernel_size: Union[int, Tuple[int, int, int]],
406+
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
407+
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
408+
):
409+
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
410+
411+
412+
class AvgPool3d(_Pool3d):
413+
"""
414+
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
415+
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
416+
H_{out}, W_{out}, C)`, given by:
417+
418+
.. math::
419+
\begin{aligned}
420+
\text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
421+
& \text{input}(N_i, \text{stride[0]} \times d + l,
422+
\text{stride[1]} \times h + m,
423+
\text{stride[2]} \times w + n, C_j),
424+
\end{aligned}
425+
426+
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
427+
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
428+
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
429+
430+
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
431+
432+
- a single ``int`` -- in which case the same value is used for the depth,
433+
height and width axis;
434+
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
435+
for the depth axis, the second ``int`` for the height axis, and the third
436+
``int`` for the width axis.
437+
438+
Args:
439+
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
440+
stride (int or tuple(int, int, int), optional): The stride of the pooling
441+
window. Default: ``kernel_size``.
442+
padding (int or tuple(int, int, int), optional): How much zero
443+
padding to apply to the input. The padding is applied on both sides
444+
of the depth, height and width axis. Default: ``0``.
445+
446+
Examples:
447+
>>> import mlx.core as mx
448+
>>> import mlx.nn.layers as nn
449+
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
450+
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
451+
>>> pool(x)
452+
"""
453+
def __init__(
454+
self,
455+
kernel_size: Union[int, Tuple[int, int, int]],
456+
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
457+
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
458+
):
459+
super().__init__(mx.mean, 0, kernel_size, stride, padding)

python/tests/test_nn.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,123 @@ def test_pooling(self):
15891589
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
15901590
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
15911591
)
1592+
# Test 3d pooling
1593+
x = mx.array(
1594+
[
1595+
[
1596+
[
1597+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
1598+
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
1599+
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
1600+
],
1601+
[
1602+
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
1603+
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
1604+
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
1605+
],
1606+
]
1607+
]
1608+
)
1609+
expected_max_pool_output_no_padding_stride_1 = [
1610+
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
1611+
]
1612+
1613+
expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
1614+
expected_max_pool_output_padding_1 = [
1615+
[
1616+
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
1617+
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
1618+
]
1619+
]
1620+
expected_irregular_max_pool_output = [
1621+
[
1622+
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
1623+
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
1624+
]
1625+
]
1626+
1627+
self.assertTrue(
1628+
np.array_equal(
1629+
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
1630+
expected_max_pool_output_no_padding_stride_1,
1631+
)
1632+
)
1633+
self.assertTrue(
1634+
np.array_equal(
1635+
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
1636+
expected_max_pool_output_no_padding_stride_2,
1637+
)
1638+
)
1639+
self.assertTrue(
1640+
np.array_equal(
1641+
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
1642+
expected_max_pool_output_padding_1,
1643+
)
1644+
)
1645+
self.assertTrue(
1646+
np.array_equal(
1647+
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
1648+
expected_irregular_max_pool_output,
1649+
)
1650+
)
1651+
self.assertEqual(
1652+
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
1653+
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
1654+
)
1655+
1656+
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
1657+
[22.5, 23.5, 24.5]],
1658+
[[28.5, 29.5, 30.5],
1659+
[31.5, 32.5, 33.5]]]]
1660+
]
1661+
1662+
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
1663+
expected_avg_pool_output_padding_1 = [
1664+
[[[[0, 0.125, 0.25],
1665+
[1.125, 1.375, 1.625]],
1666+
[[3.375, 3.625, 3.875],
1667+
[9, 9.5, 10]]],
1668+
[[[3.375, 3.5, 3.625],
1669+
[7.875, 8.125, 8.375]],
1670+
[[10.125, 10.375, 10.625],
1671+
[22.5, 23, 23.5]]]]
1672+
]
1673+
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
1674+
[7.5, 8.5, 9.5],
1675+
[10.5, 11.5, 12.5]]],
1676+
[[[31.5, 32.5, 33.5],
1677+
[34.5, 35.5, 36.5],
1678+
[37.5, 38.5, 39.5]]]]
1679+
]
1680+
1681+
self.assertTrue(
1682+
np.array_equal(
1683+
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
1684+
expected_avg_pool_output_no_padding_stride_1,
1685+
)
1686+
)
1687+
self.assertTrue(
1688+
np.array_equal(
1689+
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
1690+
expected_avg_pool_output_no_padding_stride_2,
1691+
)
1692+
)
1693+
self.assertTrue(
1694+
np.array_equal(
1695+
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
1696+
expected_avg_pool_output_padding_1,
1697+
)
1698+
)
1699+
self.assertTrue(
1700+
np.array_equal(
1701+
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
1702+
expected_irregular_avg_pool_output,
1703+
)
1704+
)
1705+
self.assertEqual(
1706+
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
1707+
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
1708+
)
15921709

15931710
def test_set_dtype(self):
15941711
def assert_dtype(layer, dtype):

0 commit comments

Comments
 (0)