Skip to content

Commit ba0c886

Browse files
Implement missing functions in mlx backend (#19574)
* Implement missing functions in mlx backend * fixing median function implementation in mlx backend * Refactor median function implementation in mlx backend to pass all the tests * no lambda * use convert_to_tensor
1 parent bf15326 commit ba0c886

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

keras/src/backend/mlx/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def logsumexp(x, axis=None, keepdims=False):
3636

3737

3838
def qr(x, mode="reduced"):
39+
# TODO https://ml-explore.github.io/mlx/build/html/python/linalg.html
3940
raise NotImplementedError("QR decomposition not supported in mlx yet")
4041

4142

@@ -54,18 +55,22 @@ def extract_sequences(x, sequence_length, sequence_stride):
5455

5556

5657
def fft(x):
58+
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
5759
raise NotImplementedError("fft not yet implemented in mlx")
5860

5961

6062
def fft2(x):
63+
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
6164
raise NotImplementedError("fft not yet implemented in mlx")
6265

6366

6467
def rfft(x, fft_length=None):
68+
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
6569
raise NotImplementedError("fft not yet implemented in mlx")
6670

6771

6872
def irfft(x, fft_length=None):
73+
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
6974
raise NotImplementedError("fft not yet implemented in mlx")
7075

7176

keras/src/backend/mlx/numpy.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -585,19 +585,48 @@ def maximum(x1, x2):
585585

586586
def median(x, axis=-1, keepdims=False):
587587
x = convert_to_tensor(x)
588-
x_sorted = mx.sort(x, axis=axis)
589-
axis_size = x_sorted.shape[axis]
590-
medians = mx.take(
591-
x_sorted, indices=mx.array([(axis_size // 2) - 1]), axis=axis
592-
)
593-
if not keepdims:
594-
medians = mx.squeeze(medians, axis=axis)
588+
589+
if axis is None:
590+
x = x.flatten()
591+
axis = (0,)
592+
elif isinstance(axis, int):
593+
axis = (axis,)
594+
595+
axis = tuple(sorted(ax if ax >= 0 else ax + x.ndim for ax in axis))
596+
597+
transposed_axes = [i for i in range(x.ndim) if i not in axis] + list(axis)
598+
x = x.transpose(*transposed_axes)
599+
600+
shape_without_axes = tuple(x.shape[i] for i in range(x.ndim - len(axis)))
601+
x = x.reshape(shape_without_axes + (-1,))
602+
603+
x_sorted = mx.sort(x, axis=-1)
604+
mid_index = x_sorted.shape[-1] // 2
605+
if x_sorted.shape[-1] % 2 == 0:
606+
lower = mx.take(x_sorted, mx.array([mid_index - 1]), axis=-1)
607+
upper = mx.take(x_sorted, mx.array([mid_index]), axis=-1)
608+
medians = (lower + upper) / 2
609+
else:
610+
medians = mx.take(x_sorted, mx.array([mid_index]), axis=-1)
611+
612+
if keepdims:
613+
final_shape = list(shape_without_axes) + [1] * len(axis)
614+
medians = medians.reshape(final_shape)
615+
index_value_pairs = [
616+
(i, transposed_axes[i]) for i in range(len(transposed_axes))
617+
]
618+
index_value_pairs.sort(key=lambda pair: pair[1])
619+
sorted_indices = [pair[0] for pair in index_value_pairs]
620+
medians = medians.transpose(*sorted_indices)
621+
else:
622+
medians = medians.squeeze()
623+
595624
return medians
596625

597626

598627
def meshgrid(*x, indexing="xy"):
599-
# TODO: Implement inline like linspace
600-
raise NotImplementedError("The MLX backend doesn't support meshgrid yet")
628+
x = [convert_to_tensor(xi) for xi in x]
629+
return mx.meshgrid(*x, indexing=indexing)
601630

602631

603632
def min(x, axis=None, keepdims=False, initial=None):
@@ -826,6 +855,7 @@ def tensordot(x1, x2, axes=2):
826855

827856

828857
def round(x, decimals=0):
858+
x = convert_to_tensor(x)
829859
return mx.round(x, decimals=decimals)
830860

831861

0 commit comments

Comments
 (0)