|
4 | 4 | from keras.src.backend.mlx.core import convert_to_tensor
|
5 | 5 |
|
6 | 6 |
|
| 7 | +def det(a): |
| 8 | + raise NotImplementedError("det not yet implemented in mlx.") |
| 9 | + |
| 10 | + |
| 11 | +def eig(a): |
| 12 | + raise NotImplementedError("eig not yet implemented in mlx.") |
| 13 | + |
| 14 | + |
| 15 | +def eigh(a): |
| 16 | + raise NotImplementedError("eigh not yet implemented in mlx.") |
| 17 | + |
| 18 | + |
| 19 | +def lu_factor(a): |
| 20 | + raise NotImplementedError("lu_factor not yet implemented in mlx.") |
| 21 | + |
| 22 | + |
| 23 | +def solve(a, b): |
| 24 | + raise NotImplementedError("solve_triangular not yet implemented in mlx.") |
| 25 | + |
| 26 | + |
| 27 | +def solve_triangular(a, b, lower=False): |
| 28 | + raise NotImplementedError("solve_triangular not yet implemented in mlx.") |
| 29 | + |
| 30 | + |
| 31 | +def qr(x, mode="reduced"): |
| 32 | + return mx.linalg.qr(x) |
| 33 | + |
| 34 | + |
| 35 | +def svd(x, full_matrices=True, compute_uv=True): |
| 36 | + with mx.stream(mx.cpu): |
| 37 | + return mx.linalg.svd(x) |
| 38 | + |
| 39 | + |
| 40 | +def cholesky(a): |
| 41 | + with mx.stream(mx.cpu): |
| 42 | + return mx.linalg.cholesky(a) |
| 43 | + |
| 44 | + |
7 | 45 | def norm(x, ord=None, axis=None, keepdims=False):
|
8 | 46 | dtype = standardize_dtype(x.dtype)
|
9 | 47 | if "int" in dtype or dtype == "bool":
|
10 | 48 | dtype = dtypes.result_type(x.dtype, "float32")
|
11 | 49 | x = convert_to_tensor(x, dtype=dtype)
|
12 | 50 | return mx.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
|
| 51 | + |
| 52 | + |
| 53 | +def inv(a): |
| 54 | + with mx.stream(mx.cpu): |
| 55 | + return mx.linalg.inv(a) |
0 commit comments