Skip to content

Commit 5368176

Browse files
authored
add linalg to mlx backend (#19698)
1 parent 1f9139c commit 5368176

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

keras/src/backend/mlx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from keras.src.backend.mlx import core
44
from keras.src.backend.mlx import image
5+
from keras.src.backend.mlx import linalg
56
from keras.src.backend.mlx import math
67
from keras.src.backend.mlx import nn
78
from keras.src.backend.mlx import numpy

keras/src/backend/mlx/linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import mlx.core as mx
2+
3+
from keras.src.backend.common import standardize_dtype
4+
from keras.src.backend.mlx.core import convert_to_tensor
5+
6+
7+
def norm(x, ord=None, axis=None, keepdims=False):
8+
dtype = standardize_dtype(x.dtype)
9+
if "int" in dtype or dtype == "bool":
10+
dtype = dtypes.result_type(x.dtype, "float32")
11+
x = convert_to_tensor(x, dtype=dtype)
12+
return mx.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)

keras/src/ops/linalg_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ def _reconstruct(lu, pivots, m, n):
444444
)
445445
)
446446
def test_norm(self, ndim, ord, axis, keepdims):
447+
if backend.backend() == "mlx" and ord == "nuc":
448+
self.skipTest("ord='nuc' not supported in MLX")
447449
if ndim == 1:
448450
x = np.random.random((5,))
449451
else:

0 commit comments

Comments
 (0)