Skip to content

Commit 93d52b0

Browse files
authored
mlx - remat and continued test updates (#21244)
* all tests excl remaining linalg ops passing on Apple silicon * comment message
1 parent 4970a00 commit 93d52b0

File tree

10 files changed

+55
-90
lines changed

10 files changed

+55
-90
lines changed

keras/src/backend/mlx/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
SUPPORTS_SPARSE_TENSORS = False
2323
SUPPORTS_RAGGED_TENSORS = False
24-
IS_THREAD_SAFE = True
24+
# TODO: follow updates and adjust to thread safe when possible
25+
IS_THREAD_SAFE = False # False as of mlx 0.24.0
2526

2627
MLX_DTYPES = {
2728
"float16": mx.float16,
@@ -596,6 +597,18 @@ def __call__(self, *args, **kwargs):
596597
return outputs
597598

598599

600+
def remat(f):
601+
"""Implementation of rematerialization.
602+
603+
Args:
604+
f: The function or operation to rematerialize.
605+
Returns:
606+
A function wrapping f that defines a custom gradient, which
607+
recomputes f on the backwards pass of a gradient call.
608+
"""
609+
return mx.checkpoint(f)
610+
611+
599612
def enable_float64():
600613
"""Returns context manager forcing operations on cpu
601614

keras/src/dtype_policies/dtype_policy_map_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from keras.src import backend
45
from keras.src import dtype_policies
56
from keras.src import layers
67
from keras.src import models
@@ -23,6 +24,9 @@ def tearDown(self):
2324

2425
@pytest.mark.requires_trainable_backend
2526
def test_basic_usage(self):
27+
if backend.backend() == "mlx":
28+
self.skipTest("mlx backend does not yet support quantization")
29+
2630
# Create a subclass that might contain mixing dtype policies for
2731
# sublayers.
2832
# It is important to ensure that `dtype` is passed to sublayers and

keras/src/initializers/constant_initializers_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def test_identity_initializer(self):
7777

7878
@skip_if_backend("openvino", "openvino backend does not support `arange`")
7979
def test_stft_initializer(self):
80+
if backend.backend() == "mlx":
81+
# for mlx backend force on to cpu for float64
82+
self.mlx_cpu_context = backend.core.enable_float64()
83+
self.mlx_cpu_context.__enter__()
84+
8085
shape = (256, 1, 513)
8186
time_range = np.arange(256).reshape((-1, 1, 1))
8287
freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1))
@@ -142,3 +147,6 @@ def test_stft_initializer(self):
142147
# Test compatible class_name
143148
initializer = initializers.get("STFTInitializer")
144149
self.assertIsInstance(initializer, initializers.STFT)
150+
151+
if backend.backend() == "mlx":
152+
self.mlx_cpu_context.__exit__(None, None, None)

keras/src/layers/attention/grouped_query_attention_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,9 @@ def test_masking(self, use_causal_mask):
287287
mask = mask & np.array(
288288
[[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]
289289
).astype(bool)
290-
del masked_query._keras_mask
291-
del masked_value._keras_mask
290+
if backend.backend() != "mlx":
291+
del masked_query._keras_mask
292+
del masked_value._keras_mask
292293
output_with_manual_mask = layer(
293294
query=masked_query, value=masked_value, attention_mask=mask
294295
)

keras/src/layers/attention/multi_head_attention_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,9 @@ def test_masking(self, use_causal_mask):
356356
)
357357
if use_causal_mask:
358358
mask = mask & np.array([[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3])
359-
del masked_query._keras_mask
360-
del masked_value._keras_mask
359+
if backend.backend() != "mlx":
360+
del masked_query._keras_mask
361+
del masked_value._keras_mask
361362
output_with_manual_mask = layer(
362363
query=masked_query, value=masked_value, attention_mask=mask
363364
)

keras/src/layers/layer_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def call(self, x):
199199

200200
def test_quantized_layer_with_remat(self):
201201
"""Test rematerialization on a quantized layer."""
202+
if backend.backend() == "mlx":
203+
self.skipTest("float8 is not yet supported in mlx backend.")
202204
with patch(
203205
"keras.src.backend.common.remat.remat", wraps=remat.remat
204206
) as mock_remat:

keras/src/models/model_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,8 @@ def test_export_error(self):
12381238
with self.assertRaisesRegex(
12391239
NotImplementedError,
12401240
(
1241-
r"`export_saved_model` only currently supports the "
1242-
r"tensorflow, jax and torch backends."
1241+
r"`ExportArchive` is only compatible with TensorFlow, "
1242+
r"JAX and Torch backends."
12431243
),
12441244
):
12451245
model.export(temp_filepath, format="tf_saved_model")

keras/src/ops/linalg_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ def test_cholesky(self):
198198
out = linalg.cholesky(x)
199199
self.assertEqual(out.shape, (4, 3, 3))
200200

201+
if backend.backend() == "mlx":
202+
# mlx backend currently cannot mimic numpy ValueError
203+
# for bad Cholesky decomp, e.g. if matrix not pos semi-def
204+
return
205+
201206
x = KerasTensor([10, 20, 15])
202207
with self.assertRaises(ValueError):
203208
linalg.cholesky(x)
@@ -340,8 +345,11 @@ def test_svd(self):
340345
class LinalgOpsCorrectnessTest(testing.TestCase):
341346
def test_cholesky(self):
342347
x = np.random.rand(4, 3, 3).astype("float32")
343-
with self.assertRaises(ValueError):
344-
linalg.cholesky(x)
348+
if backend.backend() != "mlx":
349+
# mlx backend currently cannot mimic numpy ValueError
350+
# for bad Cholesky decomp, e.g. if matrix not pos semi-def
351+
with self.assertRaises(ValueError):
352+
linalg.cholesky(x)
345353
x_psd = x @ x.transpose((0, 2, 1)) + 1e-5 * np.eye(3)
346354
out = linalg.cholesky(x_psd)
347355
self.assertAllClose(out, np.linalg.cholesky(x_psd), atol=1e-4)

0 commit comments

Comments
 (0)