@@ -207,7 +207,9 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
207
207
"""Compiles functions from all blockwise ops by mutating the input dag."""
208
208
# Recommended: make a copy of the dag before calling this function.
209
209
210
- compile_with_config = 'config' in inspect .getfullargspec (compile_function ).kwonlyargs
210
+ compile_with_config = (
211
+ "config" in inspect .getfullargspec (compile_function ).kwonlyargs
212
+ )
211
213
212
214
for n in dag .nodes :
213
215
node = dag .nodes [n ]
@@ -219,31 +221,36 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
219
221
continue
220
222
221
223
if compile_with_config :
222
- compiled = compile_function (node ["pipeline" ].config .function , config = node ["pipeline" ].config )
224
+ compiled = compile_function (
225
+ node ["pipeline" ].config .function , config = node ["pipeline" ].config
226
+ )
223
227
else :
224
228
compiled = compile_function (node ["pipeline" ].config .function )
225
229
226
230
# node is a blockwise primitive_op.
227
231
# maybe we should investigate some sort of optics library for frozen dataclasses...
228
232
new_pipeline = dataclasses .replace (
229
233
node ["pipeline" ],
230
- config = dataclasses .replace (node ["pipeline" ].config , function = compiled )
234
+ config = dataclasses .replace (node ["pipeline" ].config , function = compiled ),
231
235
)
232
236
node ["pipeline" ] = new_pipeline
233
237
234
238
return dag
235
239
236
240
@lru_cache
237
- def _finalize_dag (
238
- self , optimize_graph : bool = True , optimize_function = None , compile_function : Optional [Decorator ] = None ,
239
- ) -> nx .MultiDiGraph :
241
+ def _finalize (
242
+ self ,
243
+ optimize_graph : bool = True ,
244
+ optimize_function = None ,
245
+ compile_function : Optional [Decorator ] = None ,
246
+ ) -> "FinalizedPlan" :
240
247
dag = self .optimize (optimize_function ).dag if optimize_graph else self .dag
241
248
# create a copy since _create_lazy_zarr_arrays mutates the dag
242
249
dag = dag .copy ()
243
250
if callable (compile_function ):
244
251
dag = self ._compile_blockwise (dag , compile_function )
245
252
dag = self ._create_lazy_zarr_arrays (dag )
246
- return nx .freeze (dag )
253
+ return FinalizedPlan ( nx .freeze (dag ) )
247
254
248
255
def execute (
249
256
self ,
@@ -256,7 +263,10 @@ def execute(
256
263
spec = None ,
257
264
** kwargs ,
258
265
):
259
- dag = self ._finalize_dag (optimize_graph , optimize_function , compile_function )
266
+ finalized_plan = self ._finalize (
267
+ optimize_graph , optimize_function , compile_function
268
+ )
269
+ dag = finalized_plan .dag
260
270
261
271
compute_id = f"compute-{ datetime .now ().strftime ('%Y%m%dT%H%M%S.%f' )} "
262
272
@@ -275,43 +285,6 @@ def execute(
275
285
event = ComputeEndEvent (compute_id , dag )
276
286
[callback .on_compute_end (event ) for callback in callbacks ]
277
287
278
- def num_tasks (self , optimize_graph = True , optimize_function = None , resume = None ):
279
- """Return the number of tasks needed to execute this plan."""
280
- dag = self ._finalize_dag (optimize_graph , optimize_function )
281
- tasks = 0
282
- for _ , node in visit_nodes (dag , resume = resume ):
283
- tasks += node ["primitive_op" ].num_tasks
284
- return tasks
285
-
286
- def num_arrays (self , optimize_graph : bool = True , optimize_function = None ) -> int :
287
- """Return the number of arrays in this plan."""
288
- dag = self ._finalize_dag (optimize_graph , optimize_function )
289
- return sum (d .get ("type" ) == "array" for _ , d in dag .nodes (data = True ))
290
-
291
- def max_projected_mem (
292
- self , optimize_graph = True , optimize_function = None , resume = None
293
- ):
294
- """Return the maximum projected memory across all tasks to execute this plan."""
295
- dag = self ._finalize_dag (optimize_graph , optimize_function )
296
- projected_mem_values = [
297
- node ["primitive_op" ].projected_mem
298
- for _ , node in visit_nodes (dag , resume = resume )
299
- ]
300
- return max (projected_mem_values ) if len (projected_mem_values ) > 0 else 0
301
-
302
- def total_nbytes_written (
303
- self , optimize_graph : bool = True , optimize_function = None
304
- ) -> int :
305
- """Return the total number of bytes written for all materialized arrays in this plan."""
306
- dag = self ._finalize_dag (optimize_graph , optimize_function )
307
- nbytes = 0
308
- for _ , d in dag .nodes (data = True ):
309
- if d .get ("type" ) == "array" :
310
- target = d ["target" ]
311
- if isinstance (target , LazyZarrArray ):
312
- nbytes += target .nbytes
313
- return nbytes
314
-
315
288
def visualize (
316
289
self ,
317
290
filename = "cubed" ,
@@ -321,7 +294,8 @@ def visualize(
321
294
optimize_function = None ,
322
295
show_hidden = False ,
323
296
):
324
- dag = self ._finalize_dag (optimize_graph , optimize_function )
297
+ finalized_plan = self ._finalize (optimize_graph , optimize_function )
298
+ dag = finalized_plan .dag
325
299
dag = dag .copy () # make a copy since we mutate the DAG below
326
300
327
301
# remove edges from create-arrays output node to avoid cluttering the diagram
@@ -336,9 +310,9 @@ def visualize(
336
310
"rankdir" : rankdir ,
337
311
"label" : (
338
312
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
339
- rf"num tasks: { self .num_tasks (optimize_graph , optimize_function )} \l"
340
- rf"max projected memory: { memory_repr (self .max_projected_mem (optimize_graph , optimize_function ))} \l"
341
- rf"total nbytes written: { memory_repr (self .total_nbytes_written (optimize_graph , optimize_function ))} \l"
313
+ rf"num tasks: { finalized_plan .num_tasks ()} \l"
314
+ rf"max projected memory: { memory_repr (finalized_plan .max_projected_mem ())} \l"
315
+ rf"total nbytes written: { memory_repr (finalized_plan .total_nbytes_written ())} \l"
342
316
rf"optimized: { optimize_graph } \l"
343
317
),
344
318
"labelloc" : "bottom" ,
@@ -474,6 +448,49 @@ def visualize(
474
448
return None
475
449
476
450
451
+ class FinalizedPlan :
452
+ """A plan that is ready to be run.
453
+
454
+ Finalizing a plan involves the following steps:
455
+ 1. optimization (optional)
456
+ 2. adding housekeping nodes to create arrays
457
+ 3. compiling functions (optional)
458
+ 4. freezing the final DAG so it can't be changed
459
+ """
460
+
461
+ def __init__ (self , dag ):
462
+ self .dag = dag
463
+
464
+ def max_projected_mem (self , resume = None ):
465
+ """Return the maximum projected memory across all tasks to execute this plan."""
466
+ projected_mem_values = [
467
+ node ["primitive_op" ].projected_mem
468
+ for _ , node in visit_nodes (self .dag , resume = resume )
469
+ ]
470
+ return max (projected_mem_values ) if len (projected_mem_values ) > 0 else 0
471
+
472
+ def num_arrays (self ) -> int :
473
+ """Return the number of arrays in this plan."""
474
+ return sum (d .get ("type" ) == "array" for _ , d in self .dag .nodes (data = True ))
475
+
476
+ def num_tasks (self , resume = None ):
477
+ """Return the number of tasks needed to execute this plan."""
478
+ tasks = 0
479
+ for _ , node in visit_nodes (self .dag , resume = resume ):
480
+ tasks += node ["primitive_op" ].num_tasks
481
+ return tasks
482
+
483
+ def total_nbytes_written (self ) -> int :
484
+ """Return the total number of bytes written for all materialized arrays in this plan."""
485
+ nbytes = 0
486
+ for _ , d in self .dag .nodes (data = True ):
487
+ if d .get ("type" ) == "array" :
488
+ target = d ["target" ]
489
+ if isinstance (target , LazyZarrArray ):
490
+ nbytes += target .nbytes
491
+ return nbytes
492
+
493
+
477
494
def arrays_to_dag (* arrays ):
478
495
from .array import check_array_specs
479
496
0 commit comments