Skip to content

Commit 1f9139c

Browse files
authored
add argpartition to mlx numpy (#19680)
1 parent ad948c2 commit 1f9139c

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ def argsort(x, axis=-1):
180180
return mx.argsort(x, axis=axis)
181181

182182

183+
def argpartition(x, kth, axis=-1):
184+
x = convert_to_tensor(x)
185+
return mx.argpartition(x, kth, axis).astype(mx.int32)
186+
187+
183188
def array(x, dtype=None):
184189
return convert_to_tensor(x, dtype=dtype)
185190

0 commit comments

Comments
 (0)