Skip to content

Commit 381aa7d

Browse files
committed
Pass array names to optimize functions
1 parent 4a86044 commit 381aa7d

File tree

4 files changed

+24
-11
lines changed

4 files changed

+24
-11
lines changed

cubed/core/array.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,7 @@ def compute(
285285
optimize_graph=optimize_graph,
286286
optimize_function=optimize_function,
287287
resume=resume,
288+
array_names=tuple(a.name for a in arrays),
288289
spec=spec,
289290
**kwargs,
290291
)
@@ -335,6 +336,7 @@ def visualize(
335336
optimize_graph=optimize_graph,
336337
optimize_function=optimize_function,
337338
show_hidden=show_hidden,
339+
array_names=tuple(a.name for a in arrays),
338340
)
339341

340342

cubed/core/optimization.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15-
def simple_optimize_dag(dag):
15+
def simple_optimize_dag(dag, array_names=None):
1616
"""Apply map blocks fusion."""
1717

1818
# note there is no need to prune the dag, since the way it is built
@@ -205,6 +205,7 @@ def fuse_predecessors(
205205
dag,
206206
name,
207207
*,
208+
array_names=None,
208209
max_total_source_arrays=4,
209210
max_total_num_input_blocks=None,
210211
always_fuse=None,
@@ -258,6 +259,7 @@ def fuse_predecessors(
258259
def multiple_inputs_optimize_dag(
259260
dag,
260261
*,
262+
array_names=None,
261263
max_total_source_arrays=4,
262264
max_total_num_input_blocks=None,
263265
always_fuse=None,
@@ -270,6 +272,7 @@ def multiple_inputs_optimize_dag(
270272
dag = fuse_predecessors(
271273
dag,
272274
name,
275+
array_names=array_names,
273276
max_total_source_arrays=max_total_source_arrays,
274277
max_total_num_input_blocks=max_total_num_input_blocks,
275278
always_fuse=always_fuse,
@@ -278,18 +281,20 @@ def multiple_inputs_optimize_dag(
278281
return dag
279282

280283

281-
def fuse_all_optimize_dag(dag):
284+
def fuse_all_optimize_dag(dag, array_names=None):
282285
"""Force all operations to be fused."""
283286
dag = dag.copy()
284287
always_fuse = [op for op in dag.nodes() if op.startswith("op-")]
285-
return multiple_inputs_optimize_dag(dag, always_fuse=always_fuse)
288+
return multiple_inputs_optimize_dag(
289+
dag, array_names=array_names, always_fuse=always_fuse
290+
)
286291

287292

288-
def fuse_only_optimize_dag(dag, *, only_fuse=None):
293+
def fuse_only_optimize_dag(dag, *, array_names=None, only_fuse=None):
289294
"""Force only specified operations to be fused, all others will be left even if they are suitable for fusion."""
290295
dag = dag.copy()
291296
always_fuse = only_fuse
292297
never_fuse = set(op for op in dag.nodes() if op.startswith("op-")) - set(only_fuse)
293298
return multiple_inputs_optimize_dag(
294-
dag, always_fuse=always_fuse, never_fuse=never_fuse
299+
dag, array_names=array_names, always_fuse=always_fuse, never_fuse=never_fuse
295300
)

cubed/core/plan.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import uuid
77
from datetime import datetime
88
from functools import lru_cache
9-
from typing import Callable, Optional
9+
from typing import Callable, Optional, Tuple
1010

1111
import networkx as nx
1212
import zarr
@@ -154,10 +154,11 @@ def arrays_to_plan(cls, *arrays):
154154
def optimize(
155155
self,
156156
optimize_function: Optional[Callable[..., nx.MultiDiGraph]] = None,
157+
array_names: Optional[Tuple[str]] = None,
157158
):
158159
if optimize_function is None:
159160
optimize_function = multiple_inputs_optimize_dag
160-
dag = optimize_function(self.dag)
161+
dag = optimize_function(self.dag, array_names=array_names)
161162
return Plan(dag)
162163

163164
def _create_lazy_zarr_arrays(self, dag):
@@ -243,8 +244,9 @@ def _finalize(
243244
optimize_graph: bool = True,
244245
optimize_function=None,
245246
compile_function: Optional[Decorator] = None,
247+
array_names=None,
246248
) -> "FinalizedPlan":
247-
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
249+
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
248250
# create a copy since _create_lazy_zarr_arrays mutates the dag
249251
dag = dag.copy()
250252
if callable(compile_function):
@@ -260,11 +262,12 @@ def execute(
260262
optimize_function=None,
261263
compile_function=None,
262264
resume=None,
265+
array_names=None,
263266
spec=None,
264267
**kwargs,
265268
):
266269
finalized_plan = self._finalize(
267-
optimize_graph, optimize_function, compile_function
270+
optimize_graph, optimize_function, compile_function, array_names=array_names
268271
)
269272
dag = finalized_plan.dag
270273

@@ -293,8 +296,11 @@ def visualize(
293296
optimize_graph=True,
294297
optimize_function=None,
295298
show_hidden=False,
299+
array_names=None,
296300
):
297-
finalized_plan = self._finalize(optimize_graph, optimize_function)
301+
finalized_plan = self._finalize(
302+
optimize_graph, optimize_function, array_names=array_names
303+
)
298304
dag = finalized_plan.dag
299305
dag = dag.copy() # make a copy since we mutate the DAG below
300306

cubed/tests/test_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_custom_optimize_function(spec):
198198
< num_tasks_with_no_optimization
199199
)
200200

201-
def custom_optimize_function(dag):
201+
def custom_optimize_function(dag, array_names=None):
202202
# leave DAG unchanged
203203
return dag
204204

0 commit comments

Comments
 (0)