Skip to content

Commit ad948c2

Browse files
authored
handle i64 for scatter and cumsum (#19666)
1 parent 5a3542b commit ad948c2

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

keras/src/backend/mlx/core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ def convert_to_tensor(x, dtype=None, sparse=None):
7171
return x.value
7272

7373
if isinstance(x, np.ndarray):
74-
if x.dtype == np.int64:
75-
x = x.astype(np.int32)
7674
x = x.astype(standardize_dtype(x.dtype))
7775
return mx.array(x, dtype=mlx_dtype)
7876

@@ -211,6 +209,10 @@ def vectorized_map(function, elements):
211209
def scatter(indices, values, shape):
212210
indices = convert_to_tensor(indices)
213211
values = convert_to_tensor(values)
212+
if values.dtype == mx.int64:
213+
values = values.astype(mx.int32)
214+
elif values.dtype == mx.uint64:
215+
values = values.astype(mx.uint32)
214216
zeros = mx.zeros(shape, dtype=values.dtype)
215217
indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
216218
zeros = zeros.at[indices].add(values)
@@ -222,6 +224,10 @@ def scatter_update(inputs, indices, updates):
222224
inputs = convert_to_tensor(inputs)
223225
indices = convert_to_tensor(indices)
224226
updates = convert_to_tensor(updates)
227+
if inputs.dtype == mx.int64:
228+
inputs = inputs.astype(mx.int32)
229+
elif inputs.dtype == mx.uint64:
230+
inputs = inputs.astype(mx.uint32)
225231
indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
226232
inputs[indices] = updates
227233

keras/src/backend/mlx/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,19 @@ def cumprod(x, axis=None, dtype=None):
278278
x = convert_to_tensor(x)
279279
if dtype is not None:
280280
x = cast(x, dtype)
281+
if x.dtype in [mx.int64, mx.uint64]:
282+
return mx.cumprod(
283+
x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu)
284+
)
281285
return mx.cumprod(x, axis=axis)
282286

283287

284288
def cumsum(x, axis=None, dtype=None):
285289
x = convert_to_tensor(x)
286290
if dtype is not None:
287291
x = cast(x, dtype)
292+
if x.dtype in [mx.int64, mx.uint64]:
293+
return mx.cumsum(x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu))
288294
return mx.cumsum(x, axis=axis)
289295

290296

keras/src/ops/numpy_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4943,7 +4943,7 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
49434943
]
49444944
elif backend.backend() == "mlx":
49454945
ALL_DTYPES = [x for x in ALL_DTYPES if x != "float64"]
4946-
# FLOAT_DTYPES = [x for x in FLOAT_DTYPES if x != "float64" ]
4946+
FLOAT_DTYPES = tuple([x for x in FLOAT_DTYPES if x != "float64"])
49474947
# Remove float8 dtypes for the following tests
49484948
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
49494949

0 commit comments

Comments
 (0)