Skip to content

Commit 3442cca

Browse files
authored
Several linalg functions support (#21260)
1 parent 93d52b0 commit 3442cca

File tree

1 file changed

+54
-4
lines changed

1 file changed

+54
-4
lines changed

keras/src/backend/mlx/linalg.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def det(a):
3535

3636

3737
def eig(a):
38-
raise NotImplementedError("eig not yet implemented in mlx.")
38+
# Using numpy for now, as mlx does not support eig yet.
39+
return np.linalg.eig(a)
3940

4041

4142
def eigh(a):
@@ -44,7 +45,9 @@ def eigh(a):
4445

4546

4647
def lu_factor(a):
47-
raise NotImplementedError("lu_factor not yet implemented in mlx.")
48+
with mx.stream(mx.cpu):
49+
# This op is not yet supported on the GPU.
50+
return mx.linalg.lu_factor(a)
4851

4952

5053
def solve(a, b):
@@ -55,7 +58,15 @@ def solve(a, b):
5558

5659

5760
def solve_triangular(a, b, lower=False):
58-
raise NotImplementedError("solve_triangular not yet implemented in mlx.")
61+
upper = not lower
62+
with mx.stream(mx.cpu):
63+
# This op is not yet supported on the GPU.
64+
if b.ndim == a.ndim - 1:
65+
b = mx.expand_dims(b, axis=-1)
66+
return mx.squeeze(
67+
mx.linalg.solve_triangular(a, b, upper=upper), axis=-1
68+
)
69+
return mx.linalg.solve_triangular(a, b, upper=upper)
5970

6071

6172
def qr(x, mode="reduced"):
@@ -103,4 +114,43 @@ def inv(a):
103114

104115

105116
def lstsq(a, b, rcond=None):
106-
raise NotImplementedError("lstsq not yet implemented in mlx.")
117+
a = convert_to_tensor(a)
118+
b = convert_to_tensor(b)
119+
if a.shape[0] != b.shape[0]:
120+
raise ValueError(
121+
"Incompatible shapes: a and b must have the same number of rows."
122+
)
123+
b_orig_ndim = b.ndim
124+
if b.ndim == 1:
125+
b = mx.expand_dims(b, axis=-1)
126+
elif b.ndim > 2:
127+
raise ValueError("b must be 1D or 2D.")
128+
129+
if b.ndim != 2:
130+
raise ValueError("b must be 1D or 2D.")
131+
132+
m, n = a.shape
133+
dtype = a.dtype
134+
135+
eps = np.finfo(np.array(a).dtype).eps
136+
if a.shape == ():
137+
s = mx.zeros((0,), dtype=dtype)
138+
x = mx.zeros((n, *b.shape[1:]), dtype=dtype)
139+
else:
140+
if rcond is None:
141+
rcond = eps * max(m, n)
142+
else:
143+
rcond = mx.where(rcond < 0, eps, rcond)
144+
u, s, vt = svd(a, full_matrices=False)
145+
146+
mask = s >= mx.array(rcond, dtype=s.dtype) * s[0]
147+
safe_s = mx.array(mx.where(mask, s, 1), dtype=dtype)
148+
s_inv = mx.where(mask, 1 / safe_s, 0)
149+
s_inv = mx.expand_dims(s_inv, axis=1)
150+
u_t_b = mx.matmul(mx.transpose(mx.conj(u)), b)
151+
x = mx.matmul(mx.transpose(mx.conj(vt)), s_inv * u_t_b)
152+
153+
if b_orig_ndim == 1:
154+
x = mx.squeeze(x, axis=-1)
155+
156+
return x

0 commit comments

Comments
 (0)