Skip to content

Commit 82b7c57

Browse files
authored
mlx - fix repeat and eye (#19734)
1 parent 756e243 commit 82b7c57

File tree

1 file changed

+2
-16
lines changed

1 file changed

+2
-16
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -751,21 +751,7 @@ def reciprocal(x):
751751

752752
def repeat(x, repeats, axis=None):
753753
x = convert_to_tensor(x)
754-
755-
if axis is None:
756-
x = x.reshape(-1)
757-
axis = 0
758-
759-
shape = x.shape
760-
shape.insert(axis + 1, 1)
761-
x = x.reshape(shape)
762-
shape[axis + 1] = repeats
763-
x = mx.broadcast_to(x, shape)
764-
shape.pop(axis + 1)
765-
shape[axis] *= repeats
766-
x = x.reshape(shape)
767-
768-
return x
754+
return mx.repeat(x, repeats, axis=axis)
769755

770756

771757
def reshape(x, new_shape):
@@ -1004,7 +990,7 @@ def eye(N, M=None, k=None, dtype=None):
1004990
dtype = to_mlx_dtype(dtype or config.floatx())
1005991
M = N if M is None else M
1006992
k = 0 if k is None else k
1007-
return mx.eye(N, M, k)
993+
return mx.eye(N, M, k, dtype=dtype)
1008994

1009995

1010996
def floor_divide(x1, x2):

0 commit comments

Comments
 (0)