@@ -158,6 +158,30 @@ def __init__(
158
158
super ().__init__ (pooling_function , kernel_size , stride , padding , padding_value )
159
159
160
160
161
+ class _Pool3d (_Pool ):
162
+ def __init__ (
163
+ self ,
164
+ pooling_function ,
165
+ padding_value ,
166
+ kernel_size : Union [int , Tuple [int , int , int ]],
167
+ stride : Optional [Union [int , Tuple [int , int , int ]]] = None ,
168
+ padding : Optional [Union [int , Tuple [int , int , int ]]] = 0 ,
169
+ ):
170
+ class_name = type (self ).__name__
171
+ msg = "[{}] '{}' must be an integer or a tuple containing 3 integers"
172
+ kernel_size = _value_or_list (
173
+ kernel_size , 3 , msg .format (class_name , "kernel_size" )
174
+ )
175
+ if stride is not None :
176
+ stride = _value_or_list (stride , 3 , msg .format (class_name , "stride" ))
177
+ else :
178
+ stride = kernel_size
179
+ padding = _value_or_list (padding , 3 , msg .format (class_name , "padding" ))
180
+ padding = [(p , p ) for p in padding ]
181
+
182
+ super ().__init__ (pooling_function , kernel_size , stride , padding , padding_value )
183
+
184
+
161
185
class MaxPool1d (_Pool1d ):
162
186
r"""Applies 1-dimensional max pooling.
163
187
@@ -332,3 +356,104 @@ def __init__(
332
356
padding : Optional [Union [int , Tuple [int , int ]]] = 0 ,
333
357
):
334
358
super ().__init__ (mx .mean , 0 , kernel_size , stride , padding )
359
+
360
+
361
+ class MaxPool3d (_Pool3d ):
362
+ """
363
+ Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
364
+ :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
365
+ H_{out}, W_{out}, C)`, given by:
366
+
367
+ .. math::
368
+ \b egin{aligned}
369
+ \t ext{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
370
+ & \t ext{input}(N_i, \t ext{stride[0]} \t imes d + l,
371
+ \t ext{stride[1]} \t imes h + m,
372
+ \t ext{stride[2]} \t imes w + n, C_j),
373
+ \end{aligned}
374
+
375
+ where :math:`D_{out} = \left\lfloor\f rac{D + 2 * \t ext{padding[0]} - \t ext{kernel\_size[0]}}{\t ext{stride[0]}}\r ight\r floor + 1`,
376
+ :math:`H_{out} = \left\lfloor\f rac{H + 2 * \t ext{padding[1]} - \t ext{kernel\_size[1]}}{\t ext{stride[1]}}\r ight\r floor + 1`,
377
+ :math:`W_{out} = \left\lfloor\f rac{W + 2 * \t ext{padding[2]} - \t ext{kernel\_size[2]}}{\t ext{stride[2]}}\r ight\r floor + 1`.
378
+
379
+ The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
380
+
381
+ - a single ``int`` -- in which case the same value is used for the depth,
382
+ height and width axis;
383
+ - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
384
+ for the depth axis, the second ``int`` for the height axis, and the third
385
+ ``int`` for the width axis.
386
+
387
+ Args:
388
+ kernel_size (int or tuple(int, int, int)): The size of the pooling window.
389
+ stride (int or tuple(int, int, int), optional): The stride of the pooling
390
+ window. Default: ``kernel_size``.
391
+ padding (int or tuple(int, int, int), optional): How much negative infinity
392
+ padding to apply to the input. The padding is applied on both sides
393
+ of the depth, height and width axis. Default: ``0``.
394
+
395
+ Examples:
396
+ >>> import mlx.core as mx
397
+ >>> import mlx.nn.layers as nn
398
+ >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
399
+ >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
400
+ >>> pool(x)
401
+ """
402
+
403
+ def __init__ (
404
+ self ,
405
+ kernel_size : Union [int , Tuple [int , int , int ]],
406
+ stride : Optional [Union [int , Tuple [int , int , int ]]] = None ,
407
+ padding : Optional [Union [int , Tuple [int , int , int ]]] = 0 ,
408
+ ):
409
+ super ().__init__ (mx .max , - float ("inf" ), kernel_size , stride , padding )
410
+
411
+
412
+ class AvgPool3d (_Pool3d ):
413
+ """
414
+ Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
415
+ :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
416
+ H_{out}, W_{out}, C)`, given by:
417
+
418
+ .. math::
419
+ \b egin{aligned}
420
+ \t ext{out}(N_i, d, h, w, C_j) = & \f rac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
421
+ & \t ext{input}(N_i, \t ext{stride[0]} \t imes d + l,
422
+ \t ext{stride[1]} \t imes h + m,
423
+ \t ext{stride[2]} \t imes w + n, C_j),
424
+ \end{aligned}
425
+
426
+ where :math:`D_{out} = \left\lfloor\f rac{D + 2 * \t ext{padding[0]} - \t ext{kernel\_size[0]}}{\t ext{stride[0]}}\r ight\r floor + 1`,
427
+ :math:`H_{out} = \left\lfloor\f rac{H + 2 * \t ext{padding[1]} - \t ext{kernel\_size[1]}}{\t ext{stride[1]}}\r ight\r floor + 1`,
428
+ :math:`W_{out} = \left\lfloor\f rac{W + 2 * \t ext{padding[2]} - \t ext{kernel\_size[2]}}{\t ext{stride[2]}}\r ight\r floor + 1`.
429
+
430
+ The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
431
+
432
+ - a single ``int`` -- in which case the same value is used for the depth,
433
+ height and width axis;
434
+ - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
435
+ for the depth axis, the second ``int`` for the height axis, and the third
436
+ ``int`` for the width axis.
437
+
438
+ Args:
439
+ kernel_size (int or tuple(int, int, int)): The size of the pooling window.
440
+ stride (int or tuple(int, int, int), optional): The stride of the pooling
441
+ window. Default: ``kernel_size``.
442
+ padding (int or tuple(int, int, int), optional): How much zero
443
+ padding to apply to the input. The padding is applied on both sides
444
+ of the depth, height and width axis. Default: ``0``.
445
+
446
+ Examples:
447
+ >>> import mlx.core as mx
448
+ >>> import mlx.nn.layers as nn
449
+ >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
450
+ >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
451
+ >>> pool(x)
452
+ """
453
+ def __init__ (
454
+ self ,
455
+ kernel_size : Union [int , Tuple [int , int , int ]],
456
+ stride : Optional [Union [int , Tuple [int , int , int ]]] = None ,
457
+ padding : Optional [Union [int , Tuple [int , int , int ]]] = 0 ,
458
+ ):
459
+ super ().__init__ (mx .mean , 0 , kernel_size , stride , padding )
0 commit comments