Skip to content

Commit c8bc015

Browse files
authored
linalg ops fix (#19763)
1 parent 5d45a9c commit c8bc015

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

keras/src/backend/mlx/linalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,52 @@
44
from keras.src.backend.mlx.core import convert_to_tensor
55

66

7+
def det(a):
8+
raise NotImplementedError("det not yet implemented in mlx.")
9+
10+
11+
def eig(a):
12+
raise NotImplementedError("eig not yet implemented in mlx.")
13+
14+
15+
def eigh(a):
16+
raise NotImplementedError("eigh not yet implemented in mlx.")
17+
18+
19+
def lu_factor(a):
20+
raise NotImplementedError("lu_factor not yet implemented in mlx.")
21+
22+
23+
def solve(a, b):
24+
raise NotImplementedError("solve_triangular not yet implemented in mlx.")
25+
26+
27+
def solve_triangular(a, b, lower=False):
28+
raise NotImplementedError("solve_triangular not yet implemented in mlx.")
29+
30+
31+
def qr(x, mode="reduced"):
32+
return mx.linalg.qr(x)
33+
34+
35+
def svd(x, full_matrices=True, compute_uv=True):
36+
with mx.stream(mx.cpu):
37+
return mx.linalg.svd(x)
38+
39+
40+
def cholesky(a):
41+
with mx.stream(mx.cpu):
42+
return mx.linalg.cholesky(a)
43+
44+
745
def norm(x, ord=None, axis=None, keepdims=False):
846
dtype = standardize_dtype(x.dtype)
947
if "int" in dtype or dtype == "bool":
1048
dtype = dtypes.result_type(x.dtype, "float32")
1149
x = convert_to_tensor(x, dtype=dtype)
1250
return mx.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
51+
52+
53+
def inv(a):
54+
with mx.stream(mx.cpu):
55+
return mx.linalg.inv(a)

0 commit comments

Comments
 (0)