Skip to content

Commit 756e243

Browse files
authored
mlx - fix diag and diagonal in numpy (#19714)
1 parent 0641a1e commit 756e243

File tree

1 file changed

+4
-24
lines changed

1 file changed

+4
-24
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -314,34 +314,14 @@ def _diagonal_indices(H, W, k):
314314

315315
def diag(x, k=0):
316316
x = convert_to_tensor(x)
317-
318-
if len(x.shape) == 2:
319-
return x[_diagonal_indices(*x.shape, k)]
320-
321-
elif len(x.shape) == 1:
322-
N = x.shape[0] + abs(k)
323-
zeros = mx.zeros((N, N))
324-
zeros[_diagonal_indices(N, N, k)] = x
325-
return zeros
326-
327-
else:
328-
raise ValueError("Input must be 1d or 2d")
317+
if x.dtype in [mx.int64, mx.uint64]:
318+
return mx.diag(x, k=k, stream=mx.Device(type=mx.DeviceType.cpu))
319+
return mx.diag(x, k=k)
329320

330321

331322
def diagonal(x, offset=0, axis1=0, axis2=1):
332323
x = convert_to_tensor(x)
333-
334-
ndim = x.ndim
335-
axis1 = (ndim + axis1) % ndim
336-
axis2 = (ndim + axis2) % ndim
337-
338-
max_axis = builtins.max(axis1, axis2)
339-
indices = [slice(None) for _ in range(max_axis + 1)]
340-
indices[axis1], indices[axis2] = _diagonal_indices(
341-
x.shape[axis1], x.shape[axis2], offset
342-
)
343-
344-
return x[indices]
324+
return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
345325

346326

347327
def diff(x, n=1, axis=-1):

0 commit comments

Comments
 (0)