12
12
13
13
14
14
def add (x1 , x2 ):
15
- x1 , x2 = convert_to_tensors (x1 , x2 )
15
+ x1 = maybe_convert_to_tensor (x1 )
16
+ x2 = maybe_convert_to_tensor (x2 )
16
17
return mx .add (x1 , x2 )
17
18
18
19
@@ -21,7 +22,8 @@ def einsum(subscripts, *operands, **kwargs):
21
22
22
23
23
24
def subtract (x1 , x2 ):
24
- x1 , x2 = convert_to_tensors (x1 , x2 )
25
+ x1 = maybe_convert_to_tensor (x1 )
26
+ x2 = maybe_convert_to_tensor (x2 )
25
27
return mx .subtract (x1 , x2 )
26
28
27
29
@@ -31,7 +33,8 @@ def matmul(x1, x2):
31
33
32
34
33
35
def multiply (x1 , x2 ):
34
- x1 , x2 = convert_to_tensors (x1 , x2 )
36
+ x1 = maybe_convert_to_tensor (x1 )
37
+ x2 = maybe_convert_to_tensor (x2 )
35
38
return mx .multiply (x1 , x2 )
36
39
37
40
@@ -121,7 +124,10 @@ def arange(start, stop=None, step=1, dtype=None):
121
124
if stop is not None :
122
125
dtypes_to_resolve .append (getattr (stop , "dtype" , type (stop )))
123
126
dtype = result_type (* dtypes_to_resolve )
124
- dtype = standardize_dtype (dtype )
127
+ dtype = to_mlx_dtype (dtype )
128
+ if stop is None :
129
+ stop = start
130
+ start = 0
125
131
return mx .arange (start , stop , step = step , dtype = dtype )
126
132
127
133
@@ -383,8 +389,8 @@ def empty(shape, dtype=None):
383
389
384
390
385
391
def equal (x1 , x2 ):
386
- x1 = convert_to_tensor (x1 )
387
- x2 = convert_to_tensor (x2 )
392
+ x1 = maybe_convert_to_tensor (x1 )
393
+ x2 = maybe_convert_to_tensor (x2 )
388
394
return mx .equal (x1 , x2 )
389
395
390
396
@@ -437,12 +443,14 @@ def full_like(x, fill_value, dtype=None):
437
443
438
444
439
445
def greater (x1 , x2 ):
440
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
446
+ x1 = maybe_convert_to_tensor (x1 )
447
+ x2 = maybe_convert_to_tensor (x2 )
441
448
return mx .greater (x1 , x2 )
442
449
443
450
444
451
def greater_equal (x1 , x2 ):
445
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
452
+ x1 = maybe_convert_to_tensor (x1 )
453
+ x2 = maybe_convert_to_tensor (x2 )
446
454
return mx .greater_equal (x1 , x2 )
447
455
448
456
@@ -496,12 +504,14 @@ def isnan(x):
496
504
497
505
498
506
def less (x1 , x2 ):
499
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
507
+ x1 = maybe_convert_to_tensor (x1 )
508
+ x2 = maybe_convert_to_tensor (x2 )
500
509
return mx .less (x1 , x2 )
501
510
502
511
503
512
def less_equal (x1 , x2 ):
504
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
513
+ x1 = maybe_convert_to_tensor (x1 )
514
+ x2 = maybe_convert_to_tensor (x2 )
505
515
return mx .less_equal (x1 , x2 )
506
516
507
517
@@ -547,8 +557,8 @@ def log2(x):
547
557
548
558
549
559
def logaddexp (x1 , x2 ):
550
- x1 = convert_to_tensor (x1 )
551
- x2 = convert_to_tensor (x2 )
560
+ x1 = maybe_convert_to_tensor (x1 )
561
+ x2 = maybe_convert_to_tensor (x2 )
552
562
return mx .logaddexp (x1 , x2 )
553
563
554
564
@@ -578,8 +588,8 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
578
588
579
589
580
590
def maximum (x1 , x2 ):
581
- x1 = convert_to_tensor (x1 )
582
- x2 = convert_to_tensor (x2 )
591
+ x1 = maybe_convert_to_tensor (x1 )
592
+ x2 = maybe_convert_to_tensor (x2 )
583
593
return mx .maximum (x1 , x2 )
584
594
585
595
@@ -647,14 +657,14 @@ def min(x, axis=None, keepdims=False, initial=None):
647
657
648
658
649
659
def minimum (x1 , x2 ):
650
- x1 = convert_to_tensor (x1 )
651
- x2 = convert_to_tensor (x2 )
660
+ x1 = maybe_convert_to_tensor (x1 )
661
+ x2 = maybe_convert_to_tensor (x2 )
652
662
return mx .minimum (x1 , x2 )
653
663
654
664
655
665
def mod (x1 , x2 ):
656
- x1 = convert_to_tensor (x1 )
657
- x2 = convert_to_tensor (x2 )
666
+ x1 = maybe_convert_to_tensor (x1 )
667
+ x2 = maybe_convert_to_tensor (x2 )
658
668
return mx .remainder (x1 , x2 )
659
669
660
670
@@ -690,7 +700,8 @@ def nonzero(x):
690
700
691
701
692
702
def not_equal (x1 , x2 ):
693
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
703
+ x1 = maybe_convert_to_tensor (x1 )
704
+ x2 = maybe_convert_to_tensor (x2 )
694
705
return x1 != x2
695
706
696
707
@@ -950,13 +961,14 @@ def divide_no_nan(x1, x2):
950
961
951
962
952
963
def true_divide (x1 , x2 ):
953
- x1 = convert_to_tensor (x1 )
954
- x2 = convert_to_tensor (x2 )
964
+ x1 = maybe_convert_to_tensor (x1 )
965
+ x2 = maybe_convert_to_tensor (x2 )
955
966
return divide (x1 , x2 )
956
967
957
968
958
969
def power (x1 , x2 ):
959
- x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
970
+ x1 = maybe_convert_to_tensor (x1 )
971
+ x2 = maybe_convert_to_tensor (x2 )
960
972
return mx .power (x1 , x2 )
961
973
962
974
@@ -1014,3 +1026,9 @@ def floor_divide(x1, x2):
1014
1026
def logical_xor (x1 , x2 ):
1015
1027
x1 , x2 = convert_to_tensor (x1 ), convert_to_tensor (x2 )
1016
1028
return x1 .astype (mx .bool_ ) - x2 .astype (mx .bool_ )
1029
+
1030
+
1031
+ def maybe_convert_to_tensor (x ):
1032
+ if isinstance (x , (int , float , bool )):
1033
+ return x
1034
+ return convert_to_tensor (x )
0 commit comments