Skip to content

Commit 722883a

Browse files
authored
fix mlx/numpy (#19613)
skip float64 tests dont convert scalar to tensor if the mx api accepts it.
1 parent ba0c886 commit 722883a

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313

1414
def add(x1, x2):
15-
x1, x2 = convert_to_tensors(x1, x2)
15+
x1 = maybe_convert_to_tensor(x1)
16+
x2 = maybe_convert_to_tensor(x2)
1617
return mx.add(x1, x2)
1718

1819

@@ -21,7 +22,8 @@ def einsum(subscripts, *operands, **kwargs):
2122

2223

2324
def subtract(x1, x2):
24-
x1, x2 = convert_to_tensors(x1, x2)
25+
x1 = maybe_convert_to_tensor(x1)
26+
x2 = maybe_convert_to_tensor(x2)
2527
return mx.subtract(x1, x2)
2628

2729

@@ -31,7 +33,8 @@ def matmul(x1, x2):
3133

3234

3335
def multiply(x1, x2):
34-
x1, x2 = convert_to_tensors(x1, x2)
36+
x1 = maybe_convert_to_tensor(x1)
37+
x2 = maybe_convert_to_tensor(x2)
3538
return mx.multiply(x1, x2)
3639

3740

@@ -121,7 +124,10 @@ def arange(start, stop=None, step=1, dtype=None):
121124
if stop is not None:
122125
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
123126
dtype = result_type(*dtypes_to_resolve)
124-
dtype = standardize_dtype(dtype)
127+
dtype = to_mlx_dtype(dtype)
128+
if stop is None:
129+
stop = start
130+
start = 0
125131
return mx.arange(start, stop, step=step, dtype=dtype)
126132

127133

@@ -383,8 +389,8 @@ def empty(shape, dtype=None):
383389

384390

385391
def equal(x1, x2):
386-
x1 = convert_to_tensor(x1)
387-
x2 = convert_to_tensor(x2)
392+
x1 = maybe_convert_to_tensor(x1)
393+
x2 = maybe_convert_to_tensor(x2)
388394
return mx.equal(x1, x2)
389395

390396

@@ -437,12 +443,14 @@ def full_like(x, fill_value, dtype=None):
437443

438444

439445
def greater(x1, x2):
440-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
446+
x1 = maybe_convert_to_tensor(x1)
447+
x2 = maybe_convert_to_tensor(x2)
441448
return mx.greater(x1, x2)
442449

443450

444451
def greater_equal(x1, x2):
445-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
452+
x1 = maybe_convert_to_tensor(x1)
453+
x2 = maybe_convert_to_tensor(x2)
446454
return mx.greater_equal(x1, x2)
447455

448456

@@ -496,12 +504,14 @@ def isnan(x):
496504

497505

498506
def less(x1, x2):
499-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
507+
x1 = maybe_convert_to_tensor(x1)
508+
x2 = maybe_convert_to_tensor(x2)
500509
return mx.less(x1, x2)
501510

502511

503512
def less_equal(x1, x2):
504-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
513+
x1 = maybe_convert_to_tensor(x1)
514+
x2 = maybe_convert_to_tensor(x2)
505515
return mx.less_equal(x1, x2)
506516

507517

@@ -547,8 +557,8 @@ def log2(x):
547557

548558

549559
def logaddexp(x1, x2):
550-
x1 = convert_to_tensor(x1)
551-
x2 = convert_to_tensor(x2)
560+
x1 = maybe_convert_to_tensor(x1)
561+
x2 = maybe_convert_to_tensor(x2)
552562
return mx.logaddexp(x1, x2)
553563

554564

@@ -578,8 +588,8 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
578588

579589

580590
def maximum(x1, x2):
581-
x1 = convert_to_tensor(x1)
582-
x2 = convert_to_tensor(x2)
591+
x1 = maybe_convert_to_tensor(x1)
592+
x2 = maybe_convert_to_tensor(x2)
583593
return mx.maximum(x1, x2)
584594

585595

@@ -647,14 +657,14 @@ def min(x, axis=None, keepdims=False, initial=None):
647657

648658

649659
def minimum(x1, x2):
650-
x1 = convert_to_tensor(x1)
651-
x2 = convert_to_tensor(x2)
660+
x1 = maybe_convert_to_tensor(x1)
661+
x2 = maybe_convert_to_tensor(x2)
652662
return mx.minimum(x1, x2)
653663

654664

655665
def mod(x1, x2):
656-
x1 = convert_to_tensor(x1)
657-
x2 = convert_to_tensor(x2)
666+
x1 = maybe_convert_to_tensor(x1)
667+
x2 = maybe_convert_to_tensor(x2)
658668
return mx.remainder(x1, x2)
659669

660670

@@ -690,7 +700,8 @@ def nonzero(x):
690700

691701

692702
def not_equal(x1, x2):
693-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
703+
x1 = maybe_convert_to_tensor(x1)
704+
x2 = maybe_convert_to_tensor(x2)
694705
return x1 != x2
695706

696707

@@ -950,13 +961,14 @@ def divide_no_nan(x1, x2):
950961

951962

952963
def true_divide(x1, x2):
953-
x1 = convert_to_tensor(x1)
954-
x2 = convert_to_tensor(x2)
964+
x1 = maybe_convert_to_tensor(x1)
965+
x2 = maybe_convert_to_tensor(x2)
955966
return divide(x1, x2)
956967

957968

958969
def power(x1, x2):
959-
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
970+
x1 = maybe_convert_to_tensor(x1)
971+
x2 = maybe_convert_to_tensor(x2)
960972
return mx.power(x1, x2)
961973

962974

@@ -1014,3 +1026,9 @@ def floor_divide(x1, x2):
10141026
def logical_xor(x1, x2):
10151027
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
10161028
return x1.astype(mx.bool_) - x2.astype(mx.bool_)
1029+
1030+
1031+
def maybe_convert_to_tensor(x):
1032+
if isinstance(x, (int, float, bool)):
1033+
return x
1034+
return convert_to_tensor(x)

keras/src/ops/numpy_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4750,6 +4750,9 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
47504750
INT_DTYPES = [
47514751
x for x in INT_DTYPES if x not in ["uint16", "uint32", "uint64"]
47524752
]
4753+
elif backend.backend() == "mlx":
4754+
ALL_DTYPES = [x for x in ALL_DTYPES if x != "float64"]
4755+
# FLOAT_DTYPES = [x for x in FLOAT_DTYPES if x != "float64" ]
47534756
# Remove float8 dtypes for the following tests
47544757
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
47554758

0 commit comments

Comments
 (0)