|
7 | 7 | from keras.src.backend.common.keras_tensor import KerasTensor
|
8 | 8 | from keras.src.backend.common.stateless_scope import StatelessScope
|
9 | 9 |
|
| 10 | +try: |
| 11 | + import h5py |
| 12 | +except ImportError: |
| 13 | + h5py = None |
| 14 | + |
10 | 15 | SUPPORTS_SPARSE_TENSORS = False
|
11 | 16 |
|
12 | 17 | MLX_DTYPES = {
|
@@ -55,6 +60,13 @@ def __array__(self, dtype=None):
|
55 | 60 | return value
|
56 | 61 |
|
57 | 62 |
|
| 63 | +def _is_h5py_dataset(obj): |
| 64 | + return ( |
| 65 | + type(obj).__module__.startswith("h5py.") |
| 66 | + and type(obj).__name__ == "Dataset" |
| 67 | + ) |
| 68 | + |
| 69 | + |
58 | 70 | def convert_to_tensor(x, dtype=None, sparse=None):
|
59 | 71 | if sparse:
|
60 | 72 | raise ValueError("`sparse=True` is not supported with mlx backend")
|
@@ -89,6 +101,14 @@ def to_scalar_list(x):
|
89 | 101 |
|
90 | 102 | return mx.array(to_scalar_list(x), dtype=mlx_dtype)
|
91 | 103 |
|
| 104 | + if _is_h5py_dataset(x): |
| 105 | + if h5py is None: |
| 106 | + raise ImportError( |
| 107 | + "h5py must be installed in order to load HDF5 datasets." |
| 108 | + ) |
| 109 | + # load h5py._hl.dataset.Dataset object with numpy |
| 110 | + x = np.array(x) |
| 111 | + |
92 | 112 | return mx.array(x, dtype=mlx_dtype)
|
93 | 113 |
|
94 | 114 |
|
@@ -279,18 +299,32 @@ def while_loop(
|
279 | 299 | loop_vars,
|
280 | 300 | maximum_iterations=None,
|
281 | 301 | ):
|
282 |
| - # TODO: How should we avoid evaluating cond when tracing? |
283 | 302 | current_iter = 0
|
284 | 303 | iteration_check = (
|
285 | 304 | lambda iter: maximum_iterations is None or iter < maximum_iterations
|
286 | 305 | )
|
287 |
| - loop_vars = tuple([convert_to_tensor(v) for v in loop_vars]) |
288 |
| - while cond(*loop_vars) and iteration_check(current_iter): |
289 |
| - loop_vars = body(*loop_vars) |
290 |
| - if not isinstance(loop_vars, (list, tuple)): |
291 |
| - loop_vars = (loop_vars,) |
292 |
| - loop_vars = tuple(loop_vars) |
| 306 | + |
| 307 | + is_sequence = isinstance(loop_vars, (tuple, list)) |
| 308 | + |
| 309 | + if is_sequence: |
| 310 | + loop_vars = tuple(convert_to_tensor(v) for v in loop_vars) |
| 311 | + else: |
| 312 | + loop_vars = tree.map_structure(convert_to_tensor, loop_vars) |
| 313 | + |
| 314 | + while ( |
| 315 | + cond(*loop_vars) if is_sequence else cond(loop_vars) |
| 316 | + ) and iteration_check(current_iter): |
| 317 | + new_vars = body(*loop_vars) if is_sequence else body(loop_vars) |
| 318 | + |
| 319 | + if is_sequence: |
| 320 | + if not isinstance(new_vars, (tuple, list)): |
| 321 | + new_vars = (new_vars,) |
| 322 | + loop_vars = tuple(convert_to_tensor(v) for v in new_vars) |
| 323 | + else: |
| 324 | + loop_vars = tree.map_structure(convert_to_tensor, new_vars) |
| 325 | + |
293 | 326 | current_iter += 1
|
| 327 | + |
294 | 328 | return loop_vars
|
295 | 329 |
|
296 | 330 |
|
|
0 commit comments