@@ -585,19 +585,48 @@ def maximum(x1, x2):
585
585
586
586
def median (x , axis = - 1 , keepdims = False ):
587
587
x = convert_to_tensor (x )
588
- x_sorted = mx .sort (x , axis = axis )
589
- axis_size = x_sorted .shape [axis ]
590
- medians = mx .take (
591
- x_sorted , indices = mx .array ([(axis_size // 2 ) - 1 ]), axis = axis
592
- )
593
- if not keepdims :
594
- medians = mx .squeeze (medians , axis = axis )
588
+
589
+ if axis is None :
590
+ x = x .flatten ()
591
+ axis = (0 ,)
592
+ elif isinstance (axis , int ):
593
+ axis = (axis ,)
594
+
595
+ axis = tuple (sorted (ax if ax >= 0 else ax + x .ndim for ax in axis ))
596
+
597
+ transposed_axes = [i for i in range (x .ndim ) if i not in axis ] + list (axis )
598
+ x = x .transpose (* transposed_axes )
599
+
600
+ shape_without_axes = tuple (x .shape [i ] for i in range (x .ndim - len (axis )))
601
+ x = x .reshape (shape_without_axes + (- 1 ,))
602
+
603
+ x_sorted = mx .sort (x , axis = - 1 )
604
+ mid_index = x_sorted .shape [- 1 ] // 2
605
+ if x_sorted .shape [- 1 ] % 2 == 0 :
606
+ lower = mx .take (x_sorted , mx .array ([mid_index - 1 ]), axis = - 1 )
607
+ upper = mx .take (x_sorted , mx .array ([mid_index ]), axis = - 1 )
608
+ medians = (lower + upper ) / 2
609
+ else :
610
+ medians = mx .take (x_sorted , mx .array ([mid_index ]), axis = - 1 )
611
+
612
+ if keepdims :
613
+ final_shape = list (shape_without_axes ) + [1 ] * len (axis )
614
+ medians = medians .reshape (final_shape )
615
+ index_value_pairs = [
616
+ (i , transposed_axes [i ]) for i in range (len (transposed_axes ))
617
+ ]
618
+ index_value_pairs .sort (key = lambda pair : pair [1 ])
619
+ sorted_indices = [pair [0 ] for pair in index_value_pairs ]
620
+ medians = medians .transpose (* sorted_indices )
621
+ else :
622
+ medians = medians .squeeze ()
623
+
595
624
return medians
596
625
597
626
598
627
def meshgrid (* x , indexing = "xy" ):
599
- # TODO: Implement inline like linspace
600
- raise NotImplementedError ( "The MLX backend doesn't support meshgrid yet" )
628
+ x = [ convert_to_tensor ( xi ) for xi in x ]
629
+ return mx . meshgrid ( * x , indexing = indexing )
601
630
602
631
603
632
def min (x , axis = None , keepdims = False , initial = None ):
@@ -826,6 +855,7 @@ def tensordot(x1, x2, axes=2):
826
855
827
856
828
857
def round (x , decimals = 0 ):
858
+ x = convert_to_tensor (x )
829
859
return mx .round (x , decimals = decimals )
830
860
831
861
0 commit comments