Skip to content

Commit a008598

Browse files
authored
mlx - fix error with loading models with h5 and update mlx.core.while_loop (#20819)
* fix error with loading models with h5 and core while_loop * adjust h5py import error
1 parent f0e9882 commit a008598

File tree

1 file changed

+41
-7
lines changed

1 file changed

+41
-7
lines changed

keras/src/backend/mlx/core.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from keras.src.backend.common.keras_tensor import KerasTensor
88
from keras.src.backend.common.stateless_scope import StatelessScope
99

10+
try:
11+
import h5py
12+
except ImportError:
13+
h5py = None
14+
1015
SUPPORTS_SPARSE_TENSORS = False
1116

1217
MLX_DTYPES = {
@@ -55,6 +60,13 @@ def __array__(self, dtype=None):
5560
return value
5661

5762

63+
def _is_h5py_dataset(obj):
64+
return (
65+
type(obj).__module__.startswith("h5py.")
66+
and type(obj).__name__ == "Dataset"
67+
)
68+
69+
5870
def convert_to_tensor(x, dtype=None, sparse=None):
5971
if sparse:
6072
raise ValueError("`sparse=True` is not supported with mlx backend")
@@ -89,6 +101,14 @@ def to_scalar_list(x):
89101

90102
return mx.array(to_scalar_list(x), dtype=mlx_dtype)
91103

104+
if _is_h5py_dataset(x):
105+
if h5py is None:
106+
raise ImportError(
107+
"h5py must be installed in order to load HDF5 datasets."
108+
)
109+
# load h5py._hl.dataset.Dataset object with numpy
110+
x = np.array(x)
111+
92112
return mx.array(x, dtype=mlx_dtype)
93113

94114

@@ -279,18 +299,32 @@ def while_loop(
279299
loop_vars,
280300
maximum_iterations=None,
281301
):
282-
# TODO: How should we avoid evaluating cond when tracing?
283302
current_iter = 0
284303
iteration_check = (
285304
lambda iter: maximum_iterations is None or iter < maximum_iterations
286305
)
287-
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
288-
while cond(*loop_vars) and iteration_check(current_iter):
289-
loop_vars = body(*loop_vars)
290-
if not isinstance(loop_vars, (list, tuple)):
291-
loop_vars = (loop_vars,)
292-
loop_vars = tuple(loop_vars)
306+
307+
is_sequence = isinstance(loop_vars, (tuple, list))
308+
309+
if is_sequence:
310+
loop_vars = tuple(convert_to_tensor(v) for v in loop_vars)
311+
else:
312+
loop_vars = tree.map_structure(convert_to_tensor, loop_vars)
313+
314+
while (
315+
cond(*loop_vars) if is_sequence else cond(loop_vars)
316+
) and iteration_check(current_iter):
317+
new_vars = body(*loop_vars) if is_sequence else body(loop_vars)
318+
319+
if is_sequence:
320+
if not isinstance(new_vars, (tuple, list)):
321+
new_vars = (new_vars,)
322+
loop_vars = tuple(convert_to_tensor(v) for v in new_vars)
323+
else:
324+
loop_vars = tree.map_structure(convert_to_tensor, new_vars)
325+
293326
current_iter += 1
327+
294328
return loop_vars
295329

296330

0 commit comments

Comments
 (0)