@@ -1,4 +1,7 @@
 import builtins
+import math
+import operator
+from itertools import accumulate
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -122,16 +125,172 @@ def log_softmax(x, axis=-1):
     return x - mx.logsumexp(x, axis=axis, keepdims=True)
 
 
|
+def _calculate_padding(input_shape, pool_size, strides):
+    ndim = len(input_shape)
+
+    padding = ()
+    for d in range(ndim):
+        pad = max(0, (pool_size[d] - 1) - ((input_shape[d] - 1) % strides[d]))
+        padding = padding + (pad,)
+
+    return [(p // 2, (p + 1) // 2) for p in padding]
+
+
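For intuition, `_calculate_padding` reproduces "same"-style padding: each spatial dimension gets just enough total padding for the window grid to cover every input position, split with the extra element on the right. A minimal worked check, assuming the helper above is in scope:

```python
# Illustrative only: "same" padding for a 7x7 input, 3x3 windows, stride 2.
#   pad per dim = max(0, (3 - 1) - ((7 - 1) % 2)) = 2  ->  split as (1, 1)
pads = _calculate_padding((7, 7), (3, 3), (2, 2))
print(pads)  # [(1, 1), (1, 1)] -> padded size 9, output size 4 = ceil(7 / 2)
```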
+def _non_overlapping_sliding_windows(x, shape, window_shape):
+    # Compute the intermediate shape
+    new_shape = [shape[0]]
+    for s, w in zip(shape[1:], window_shape):
+        new_shape.append(s // w)
+        new_shape.append(w)
+    new_shape.append(shape[-1])
+
+    last_axis = len(new_shape) - 1
+    axis_order = [
+        0,
+        *range(1, last_axis, 2),
+        *range(2, last_axis, 2),
+        last_axis,
+    ]
+
+    x = x.reshape(new_shape)
+    x = x.transpose(axis_order)
+    return x
+
+
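When the windows tile the input exactly (stride equals window size and divides the spatial extent), the windows can be materialized with a cheap reshape plus transpose instead of a strided view. A shape walk-through, assuming an NHWC input and the helper above:

```python
import mlx.core as mx

# The reshape splits each spatial dim into an (outer, window) pair:
# (1, 4, 4, 3) -> (1, 2, 2, 2, 2, 3) meaning (B, oH, wH, oW, wW, C).
# The transpose with axis_order [0, 1, 3, 2, 4, 5] then groups all the
# outer dims before all the window dims: (B, oH, oW, wH, wW, C).
x = mx.arange(1 * 4 * 4 * 3).reshape(1, 4, 4, 3)
w = _non_overlapping_sliding_windows(x, x.shape, (2, 2))
print(w.shape)  # (1, 2, 2, 2, 2, 3)
```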
+def _sliding_windows(x, window_shape, window_strides):
+    if x.ndim < 3:
+        raise ValueError(
+            "To extract sliding windows at least 1 spatial dimension "
+            f"(3 total) is needed but the input only has {x.ndim} dimension(s)."
+        )
+
+    spatial_dims = x.shape[1:-1]
+    if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
+        raise ValueError(
+            "To extract sliding windows, the lengths of window_shape and "
+            "window_strides must be equal to the signal's spatial dimensions. "
+            f"However, the signal has spatial_dims={spatial_dims} while "
+            f"window_shape={window_shape} and window_strides={window_strides}."
+        )
+
+    shape = x.shape
+    if all(
+        window == stride and size % window == 0
+        for size, window, stride in zip(
+            spatial_dims, window_shape, window_strides
+        )
+    ):
+        return _non_overlapping_sliding_windows(x, shape, window_shape)
+
+    strides = list(
+        reversed(list(accumulate(reversed(shape + (1,)), operator.mul)))
+    )[1:]
+
+    # Compute the output shape
+    final_shape = [shape[0]]
+    final_shape += [
+        (size - window) // stride + 1
+        for size, window, stride in zip(
+            spatial_dims, window_shape, window_strides
+        )
+    ]
+    final_shape += window_shape
+    final_shape += [shape[-1]]
+
+    # Compute the output strides
+    final_strides = strides[:1]
+    final_strides += [
+        og_stride * stride
+        for og_stride, stride in zip(strides[1:-1], window_strides)
+    ]
+    final_strides += strides[1:-1]
+    final_strides += strides[-1:]  # should always be [1]
+
+    return mx.as_strided(x, final_shape, final_strides)
+
+
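In the general (overlapping) case the helper builds a strided view: the input's row-major element strides are computed with `accumulate`, the output-position strides are those strides scaled by the window strides, and the window dims reuse the original strides unchanged. A small 1D sketch, assuming the helper above:

```python
import mlx.core as mx

# A 5-element signal in NWC layout, window 3, stride 1.
x = mx.arange(5, dtype=mx.float32).reshape(1, 5, 1)
w = _sliding_windows(x, (3,), (1,))
print(w.shape)  # (1, 3, 3, 1): batch, 3 window positions, window of 3, channels
print(w[0, :, :, 0])
# [[0, 1, 2],
#  [1, 2, 3],
#  [2, 3, 4]]
```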
+def _pool(
+    inputs, pool_size, strides, padding, padding_value, data_format, pooling_fn
+):
+    if padding not in ("same", "valid"):
+        raise ValueError(
+            f"Invalid padding '{padding}', must be 'same' or 'valid'."
+        )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    if padding == "same":
+        pads = _calculate_padding(inputs.shape[1:-1], pool_size, strides)
+
+        if any(p[1] > 0 for p in pads):
+            inputs = mx.pad(
+                inputs,
+                [(0, 0)] + pads + [(0, 0)],
+                constant_values=padding_value,
+            )
+
+    inputs = _sliding_windows(inputs, pool_size, strides)
+
+    axes = tuple(range(-len(pool_size) - 1, -1, 1))
+    result = pooling_fn(inputs, axes)
+
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+    return result
+
+
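After windowing, the reduction runs over the window axes, which sit just before the channel axis. For 2D pooling the windowed tensor has shape (batch, out_h, out_w, win_h, win_w, channels), so:

```python
pool_size = (2, 2)
axes = tuple(range(-len(pool_size) - 1, -1, 1))
print(axes)  # (-3, -2): the win_h and win_w axes
```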
 def max_pool(
     inputs, pool_size, strides=None, padding="valid", data_format=None
 ):
-    raise NotImplementedError("MLX backend doesn't support max pooling yet")
+    inputs = convert_to_tensor(inputs)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
+    strides = pool_size if strides is None else strides
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+
+    return _pool(
+        inputs, pool_size, strides, padding, -mx.inf, data_format, mx.max
+    )
 
 
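A hypothetical usage sketch (the input values are made up; `convert_to_tensor`, `standardize_data_format`, and `standardize_tuple` come from the surrounding backend module):

```python
import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 8, 3))  # channels_last batch
y = max_pool(x, pool_size=2, strides=2, padding="valid")
print(y.shape)  # (2, 4, 4, 3)
```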
 def average_pool(
     inputs, pool_size, strides=None, padding="valid", data_format=None
 ):
-    raise NotImplementedError("MLX backend doesn't support average pooling yet")
+    inputs = convert_to_tensor(inputs)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
+    strides = pool_size if strides is None else strides
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+
+    # Create a pool by applying the sum function in each window
+    pooled = _pool(
+        inputs, pool_size, strides, padding, 0.0, data_format, mx.sum
+    )
+    if padding == "valid":
+        # No padding was added, so dividing by the window size gives
+        # the average directly
+        return pooled / math.prod(pool_size)
+    else:
+        # Create a tensor of ones with the same shape as inputs, then
+        # sum-pool it with zero padding. This yields a tensor with the
+        # same dimensions as the pooled tensor whose values count the
+        # non-padded elements in each window. Dividing pooled by
+        # window_counts gives the average while excluding the padded
+        # values.
+        window_counts = _pool(
+            mx.ones(inputs.shape, inputs.dtype),
+            pool_size,
+            strides,
+            padding,
+            0.0,
+            data_format,
+            mx.sum,
+        )
+        return pooled / window_counts
 
 
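One way to see that the `window_counts` trick excludes the zero padding: with a constant input and "same" padding, edge windows contain fewer real elements, yet the average should still be the constant. A sanity-check sketch, assuming the function above:

```python
import mlx.core as mx

x = mx.ones((1, 5, 5, 1))
y = average_pool(x, pool_size=2, strides=2, padding="same")
print(y.shape)           # (1, 3, 3, 1)
print(mx.all(y == 1.0))  # True: padded zeros are not counted in the mean
```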
 def conv(