@@ -35,7 +35,8 @@ def det(a):
35
35
36
36
37
37
def eig (a ):
38
- raise NotImplementedError ("eig not yet implemented in mlx." )
38
+ # Using numpy for now, as mlx does not support eig yet.
39
+ return np .linalg .eig (a )
39
40
40
41
41
42
def eigh (a ):
@@ -44,7 +45,9 @@ def eigh(a):
44
45
45
46
46
47
def lu_factor (a ):
47
- raise NotImplementedError ("lu_factor not yet implemented in mlx." )
48
+ with mx .stream (mx .cpu ):
49
+ # This op is not yet supported on the GPU.
50
+ return mx .linalg .lu_factor (a )
48
51
49
52
50
53
def solve (a , b ):
@@ -55,7 +58,15 @@ def solve(a, b):
55
58
56
59
57
60
def solve_triangular (a , b , lower = False ):
58
- raise NotImplementedError ("solve_triangular not yet implemented in mlx." )
61
+ upper = not lower
62
+ with mx .stream (mx .cpu ):
63
+ # This op is not yet supported on the GPU.
64
+ if b .ndim == a .ndim - 1 :
65
+ b = mx .expand_dims (b , axis = - 1 )
66
+ return mx .squeeze (
67
+ mx .linalg .solve_triangular (a , b , upper = upper ), axis = - 1
68
+ )
69
+ return mx .linalg .solve_triangular (a , b , upper = upper )
59
70
60
71
61
72
def qr (x , mode = "reduced" ):
@@ -103,4 +114,43 @@ def inv(a):
103
114
104
115
105
116
def lstsq (a , b , rcond = None ):
106
- raise NotImplementedError ("lstsq not yet implemented in mlx." )
117
+ a = convert_to_tensor (a )
118
+ b = convert_to_tensor (b )
119
+ if a .shape [0 ] != b .shape [0 ]:
120
+ raise ValueError (
121
+ "Incompatible shapes: a and b must have the same number of rows."
122
+ )
123
+ b_orig_ndim = b .ndim
124
+ if b .ndim == 1 :
125
+ b = mx .expand_dims (b , axis = - 1 )
126
+ elif b .ndim > 2 :
127
+ raise ValueError ("b must be 1D or 2D." )
128
+
129
+ if b .ndim != 2 :
130
+ raise ValueError ("b must be 1D or 2D." )
131
+
132
+ m , n = a .shape
133
+ dtype = a .dtype
134
+
135
+ eps = np .finfo (np .array (a ).dtype ).eps
136
+ if a .shape == ():
137
+ s = mx .zeros ((0 ,), dtype = dtype )
138
+ x = mx .zeros ((n , * b .shape [1 :]), dtype = dtype )
139
+ else :
140
+ if rcond is None :
141
+ rcond = eps * max (m , n )
142
+ else :
143
+ rcond = mx .where (rcond < 0 , eps , rcond )
144
+ u , s , vt = svd (a , full_matrices = False )
145
+
146
+ mask = s >= mx .array (rcond , dtype = s .dtype ) * s [0 ]
147
+ safe_s = mx .array (mx .where (mask , s , 1 ), dtype = dtype )
148
+ s_inv = mx .where (mask , 1 / safe_s , 0 )
149
+ s_inv = mx .expand_dims (s_inv , axis = 1 )
150
+ u_t_b = mx .matmul (mx .transpose (mx .conj (u )), b )
151
+ x = mx .matmul (mx .transpose (mx .conj (vt )), s_inv * u_t_b )
152
+
153
+ if b_orig_ndim == 1 :
154
+ x = mx .squeeze (x , axis = - 1 )
155
+
156
+ return x
0 commit comments