Skip to content

Commit fcd4d21

Browse files
authored
Introduce FinalizedPlan (#563)
The idea here is to formalise the planning process - compose -> finalize (optimize, compile, housekeeping) -> compute/visualize.
1 parent d2ba5e1 commit fcd4d21

File tree

5 files changed

+118
-94
lines changed

5 files changed

+118
-94
lines changed

cubed/array_api/array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _repr_html_(self):
6565
grid=grid,
6666
nbytes=nbytes,
6767
cbytes=cbytes,
68-
arrs_in_plan=f"{self.plan.num_arrays()} arrays in Plan",
68+
arrs_in_plan=f"{self.plan._finalize().num_arrays()} arrays in Plan",
6969
arrtype="np.ndarray",
7070
)
7171

cubed/core/plan.py

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
207207
"""Compiles functions from all blockwise ops by mutating the input dag."""
208208
# Recommended: make a copy of the dag before calling this function.
209209

210-
compile_with_config = 'config' in inspect.getfullargspec(compile_function).kwonlyargs
210+
compile_with_config = (
211+
"config" in inspect.getfullargspec(compile_function).kwonlyargs
212+
)
211213

212214
for n in dag.nodes:
213215
node = dag.nodes[n]
@@ -219,31 +221,36 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
219221
continue
220222

221223
if compile_with_config:
222-
compiled = compile_function(node["pipeline"].config.function, config=node["pipeline"].config)
224+
compiled = compile_function(
225+
node["pipeline"].config.function, config=node["pipeline"].config
226+
)
223227
else:
224228
compiled = compile_function(node["pipeline"].config.function)
225229

226230
# node is a blockwise primitive_op.
227231
# maybe we should investigate some sort of optics library for frozen dataclasses...
228232
new_pipeline = dataclasses.replace(
229233
node["pipeline"],
230-
config=dataclasses.replace(node["pipeline"].config, function=compiled)
234+
config=dataclasses.replace(node["pipeline"].config, function=compiled),
231235
)
232236
node["pipeline"] = new_pipeline
233237

234238
return dag
235239

236240
@lru_cache
237-
def _finalize_dag(
238-
self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
239-
) -> nx.MultiDiGraph:
241+
def _finalize(
242+
self,
243+
optimize_graph: bool = True,
244+
optimize_function=None,
245+
compile_function: Optional[Decorator] = None,
246+
) -> "FinalizedPlan":
240247
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
241248
# create a copy since _compile_blockwise and _create_lazy_zarr_arrays both mutate the dag
242249
dag = dag.copy()
243250
if callable(compile_function):
244251
dag = self._compile_blockwise(dag, compile_function)
245252
dag = self._create_lazy_zarr_arrays(dag)
246-
return nx.freeze(dag)
253+
return FinalizedPlan(nx.freeze(dag))
247254

248255
def execute(
249256
self,
@@ -256,7 +263,10 @@ def execute(
256263
spec=None,
257264
**kwargs,
258265
):
259-
dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)
266+
finalized_plan = self._finalize(
267+
optimize_graph, optimize_function, compile_function
268+
)
269+
dag = finalized_plan.dag
260270

261271
compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
262272

@@ -275,43 +285,6 @@ def execute(
275285
event = ComputeEndEvent(compute_id, dag)
276286
[callback.on_compute_end(event) for callback in callbacks]
277287

278-
def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None):
279-
"""Return the number of tasks needed to execute this plan."""
280-
dag = self._finalize_dag(optimize_graph, optimize_function)
281-
tasks = 0
282-
for _, node in visit_nodes(dag, resume=resume):
283-
tasks += node["primitive_op"].num_tasks
284-
return tasks
285-
286-
def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int:
287-
"""Return the number of arrays in this plan."""
288-
dag = self._finalize_dag(optimize_graph, optimize_function)
289-
return sum(d.get("type") == "array" for _, d in dag.nodes(data=True))
290-
291-
def max_projected_mem(
292-
self, optimize_graph=True, optimize_function=None, resume=None
293-
):
294-
"""Return the maximum projected memory across all tasks to execute this plan."""
295-
dag = self._finalize_dag(optimize_graph, optimize_function)
296-
projected_mem_values = [
297-
node["primitive_op"].projected_mem
298-
for _, node in visit_nodes(dag, resume=resume)
299-
]
300-
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
301-
302-
def total_nbytes_written(
303-
self, optimize_graph: bool = True, optimize_function=None
304-
) -> int:
305-
"""Return the total number of bytes written for all materialized arrays in this plan."""
306-
dag = self._finalize_dag(optimize_graph, optimize_function)
307-
nbytes = 0
308-
for _, d in dag.nodes(data=True):
309-
if d.get("type") == "array":
310-
target = d["target"]
311-
if isinstance(target, LazyZarrArray):
312-
nbytes += target.nbytes
313-
return nbytes
314-
315288
def visualize(
316289
self,
317290
filename="cubed",
@@ -321,7 +294,8 @@ def visualize(
321294
optimize_function=None,
322295
show_hidden=False,
323296
):
324-
dag = self._finalize_dag(optimize_graph, optimize_function)
297+
finalized_plan = self._finalize(optimize_graph, optimize_function)
298+
dag = finalized_plan.dag
325299
dag = dag.copy() # make a copy since we mutate the DAG below
326300

327301
# remove edges from create-arrays output node to avoid cluttering the diagram
@@ -336,9 +310,9 @@ def visualize(
336310
"rankdir": rankdir,
337311
"label": (
338312
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
339-
rf"num tasks: {self.num_tasks(optimize_graph, optimize_function)}\l"
340-
rf"max projected memory: {memory_repr(self.max_projected_mem(optimize_graph, optimize_function))}\l"
341-
rf"total nbytes written: {memory_repr(self.total_nbytes_written(optimize_graph, optimize_function))}\l"
313+
rf"num tasks: {finalized_plan.num_tasks()}\l"
314+
rf"max projected memory: {memory_repr(finalized_plan.max_projected_mem())}\l"
315+
rf"total nbytes written: {memory_repr(finalized_plan.total_nbytes_written())}\l"
342316
rf"optimized: {optimize_graph}\l"
343317
),
344318
"labelloc": "bottom",
@@ -474,6 +448,49 @@ def visualize(
474448
return None
475449

476450

451+
class FinalizedPlan:
452+
"""A plan that is ready to be run.
453+
454+
Finalizing a plan involves the following steps:
455+
1. optimization (optional)
456+
2. adding housekeeping nodes to create arrays
457+
3. compiling functions (optional)
458+
4. freezing the final DAG so it can't be changed
459+
"""
460+
461+
def __init__(self, dag):
462+
self.dag = dag
463+
464+
def max_projected_mem(self, resume=None):
465+
"""Return the maximum projected memory across all tasks to execute this plan."""
466+
projected_mem_values = [
467+
node["primitive_op"].projected_mem
468+
for _, node in visit_nodes(self.dag, resume=resume)
469+
]
470+
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
471+
472+
def num_arrays(self) -> int:
473+
"""Return the number of arrays in this plan."""
474+
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))
475+
476+
def num_tasks(self, resume=None):
477+
"""Return the number of tasks needed to execute this plan."""
478+
tasks = 0
479+
for _, node in visit_nodes(self.dag, resume=resume):
480+
tasks += node["primitive_op"].num_tasks
481+
return tasks
482+
483+
def total_nbytes_written(self) -> int:
484+
"""Return the total number of bytes written for all materialized arrays in this plan."""
485+
nbytes = 0
486+
for _, d in self.dag.nodes(data=True):
487+
if d.get("type") == "array":
488+
target = d["target"]
489+
if isinstance(target, LazyZarrArray):
490+
nbytes += target.nbytes
491+
return nbytes
492+
493+
477494
def arrays_to_dag(*arrays):
478495
from .array import check_array_specs
479496

cubed/tests/test_core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,13 +373,14 @@ def test_reduction_multiple_rounds(tmp_path, executor):
373373
a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec)
374374
b = xp.sum(a, axis=0, dtype=np.uint8)
375375
# check that there is > 1 blockwise step (after optimization)
376+
finalized_plan = b.plan._finalize()
376377
blockwises = [
377378
n
378-
for (n, d) in b.plan.dag.nodes(data=True)
379+
for (n, d) in finalized_plan.dag.nodes(data=True)
379380
if d.get("op_name", None) == "blockwise"
380381
]
381382
assert len(blockwises) > 1
382-
assert b.plan.max_projected_mem() <= 1000
383+
assert finalized_plan.max_projected_mem() <= 1000
383384
assert_array_equal(b.compute(executor=executor), np.ones((100, 10)).sum(axis=0))
384385

385386

@@ -555,7 +556,7 @@ def test_plan_scaling(tmp_path, factor):
555556
)
556557
c = xp.matmul(a, b)
557558

558-
assert c.plan.num_tasks() > 0
559+
assert c.plan._finalize().num_tasks() > 0
559560
c.visualize(filename=tmp_path / "c")
560561

561562

@@ -568,7 +569,7 @@ def test_plan_quad_means(tmp_path, t_length):
568569
uv = u * v
569570
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)
570571

571-
assert m.plan.num_tasks() > 0
572+
assert m.plan._finalize().num_tasks() > 0
572573
m.visualize(
573574
filename=tmp_path / "quad_means_unoptimized",
574575
optimize_graph=False,

cubed/tests/test_executor_features.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test_resume(spec, executor):
181181
d = xp.negative(c)
182182

183183
num_created_arrays = 2 # c, d
184-
assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 8
184+
assert d.plan._finalize(optimize_graph=False).num_tasks() == num_created_arrays + 8
185185

186186
task_counter = TaskCounter()
187187
c.compute(executor=executor, callbacks=[task_counter], optimize_graph=False)
@@ -321,13 +321,15 @@ def test_check_runtime_memory_processes(spec, executor):
321321

322322
try:
323323
from numba import jit as numba_jit
324+
324325
COMPILE_FUNCTIONS.append(numba_jit)
325326
except ModuleNotFoundError:
326327
pass
327328

328329
try:
329-
if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''):
330+
if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
330331
from jax import jit as jax_jit
332+
331333
COMPILE_FUNCTIONS.append(jax_jit)
332334
except ModuleNotFoundError:
333335
pass
@@ -339,7 +341,8 @@ def test_check_compilation(spec, executor, compile_function):
339341
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
340342
c = xp.add(a, b)
341343
assert_array_equal(
342-
c.compute(executor=executor, compile_function=compile_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
344+
c.compute(executor=executor, compile_function=compile_function),
345+
np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
343346
)
344347

345348

@@ -352,7 +355,7 @@ def compile_function(func):
352355
c = xp.add(a, b)
353356
with pytest.raises(NotImplementedError) as excinfo:
354357
c.compute(executor=executor, compile_function=compile_function)
355-
358+
356359
assert "add" in str(excinfo.value), "Compile function was applied to add operation."
357360

358361

@@ -365,5 +368,7 @@ def compile_function(func, *, config=None):
365368
c = xp.add(a, b)
366369
with pytest.raises(NotImplementedError) as excinfo:
367370
c.compute(executor=executor, compile_function=compile_function)
368-
369-
assert "BlockwiseSpec" in str(excinfo.value), "Compile function was applied with a config argument."
371+
372+
assert "BlockwiseSpec" in str(
373+
excinfo.value
374+
), "Compile function was applied with a config argument."

0 commit comments

Comments
 (0)