@@ -314,34 +314,14 @@ def _diagonal_indices(H, W, k):
314
314
315
315
def diag (x , k = 0 ):
316
316
x = convert_to_tensor (x )
317
-
318
- if len (x .shape ) == 2 :
319
- return x [_diagonal_indices (* x .shape , k )]
320
-
321
- elif len (x .shape ) == 1 :
322
- N = x .shape [0 ] + abs (k )
323
- zeros = mx .zeros ((N , N ))
324
- zeros [_diagonal_indices (N , N , k )] = x
325
- return zeros
326
-
327
- else :
328
- raise ValueError ("Input must be 1d or 2d" )
317
+ if x .dtype in [mx .int64 , mx .uint64 ]:
318
+ return mx .diag (x , k = k , stream = mx .Device (type = mx .DeviceType .cpu ))
319
+ return mx .diag (x , k = k )
329
320
330
321
331
322
def diagonal (x , offset = 0 , axis1 = 0 , axis2 = 1 ):
332
323
x = convert_to_tensor (x )
333
-
334
- ndim = x .ndim
335
- axis1 = (ndim + axis1 ) % ndim
336
- axis2 = (ndim + axis2 ) % ndim
337
-
338
- max_axis = builtins .max (axis1 , axis2 )
339
- indices = [slice (None ) for _ in range (max_axis + 1 )]
340
- indices [axis1 ], indices [axis2 ] = _diagonal_indices (
341
- x .shape [axis1 ], x .shape [axis2 ], offset
342
- )
343
-
344
- return x [indices ]
324
+ return mx .diagonal (x , offset = offset , axis1 = axis1 , axis2 = axis2 )
345
325
346
326
347
327
def diff (x , n = 1 , axis = - 1 ):
0 commit comments