Commit b62d9ea

Refactor asyncio executors to share common code (#685)
- Remove map_unordered
- Move pipeline_to_stream to asyncio.py
- Add general async_map_dag function
- Move cubed/runtime/executors/asyncio.py to cubed/runtime/asyncio.py
1 parent ea3496b commit b62d9ea

File tree

8 files changed (+337 / -428 lines)
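The heart of the refactor is a plug-in contract: each executor now supplies only a create_futures_func, and the shared async_map_dag (first file below) does the DAG walking. As a rough sketch of that contract, here is a hypothetical in-process implementation, not part of the commit, that runs tasks on the default thread pool and returns the (input, future) pairs the shared machinery expects:

# Hypothetical sketch, not part of this commit: the shape of the
# create_futures_func callable that async_map_dag relies on.
import asyncio
from typing import Any, Callable, Iterable, List, Tuple

def local_create_futures_func(function: Callable[..., Any]):
    def create_futures_func(
        inputs: Iterable[Any], **kwargs
    ) -> List[Tuple[Any, asyncio.Future]]:
        # Pair each input with a future for its result, so completions can be
        # matched back to the inputs that produced them. A real executor
        # submits to a cluster here (see the dask version below) and forwards
        # the pipeline's `func`/`config` kwargs into the remote call.
        loop = asyncio.get_running_loop()
        return [
            (i, asyncio.ensure_future(loop.run_in_executor(None, function, i)))
            for i in inputs
        ]

    return create_futures_func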

cubed/runtime/executors/asyncio.py renamed to cubed/runtime/asyncio.py

Lines changed: 80 additions & 2 deletions
@@ -2,10 +2,30 @@
 import copy
 import time
 from asyncio import Future
-from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+)
+
+from aiostream import stream
+from aiostream.core import Stream
+from networkx import MultiDiGraph
 
 from cubed.runtime.backup import should_launch_backup
-from cubed.runtime.utils import batched
+from cubed.runtime.pipeline import visit_node_generations, visit_nodes
+from cubed.runtime.types import Callback, CubedPipeline
+from cubed.runtime.utils import (
+    batched,
+    handle_callbacks,
+    handle_operation_start_callbacks,
+)
 
 
 async def async_map_unordered(
@@ -100,3 +120,61 @@ async def async_map_unordered(
         pending.update(new_tasks.keys())
         t = time.monotonic()
         start_times = {f: t for f in new_tasks.keys()}
+
+
+async def async_map_dag(
+    create_futures_func: Callable,
+    dag: MultiDiGraph,
+    callbacks: Optional[Sequence[Callback]] = None,
+    resume: Optional[bool] = None,
+    compute_arrays_in_parallel: Optional[bool] = None,
+    **kwargs,
+) -> None:
+    """
+    Asynchronous parallel map over multiple pipelines from a DAG, with support for backups and batching.
+    """
+    if not compute_arrays_in_parallel:
+        # run one pipeline at a time
+        for name, node in visit_nodes(dag, resume=resume):
+            handle_operation_start_callbacks(callbacks, name)
+            st = pipeline_to_stream(
+                create_futures_func, name, node["pipeline"], **kwargs
+            )
+            async with st.stream() as streamer:
+                async for result, stats in streamer:
+                    handle_callbacks(callbacks, result, stats)
+    else:
+        for gen in visit_node_generations(dag, resume=resume):
+            # run pipelines in the same topological generation in parallel by merging their streams
+            streams = [
+                pipeline_to_stream(
+                    create_futures_func, name, node["pipeline"], **kwargs
+                )
+                for name, node in gen
+            ]
+            merged_stream = stream.merge(*streams)
+            async with merged_stream.stream() as streamer:
+                async for result, stats in streamer:
+                    handle_callbacks(callbacks, result, stats)
+
+
+def pipeline_to_stream(
+    create_futures_func: Callable,
+    name: str,
+    pipeline: CubedPipeline,
+    **kwargs,
+) -> Stream:
+    """
+    Turn a pipeline into an asynchronous stream of results.
+    """
+    return stream.iterate(
+        async_map_unordered(
+            create_futures_func,
+            pipeline.mappable,
+            return_stats=True,
+            name=name,
+            func=pipeline.function,
+            config=pipeline.config,
+            **kwargs,
+        )
+    )
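For the compute_arrays_in_parallel case, async_map_dag leans on aiostream's merge to interleave results from pipelines in the same topological generation. A self-contained toy, independent of cubed, showing the same consumption pattern:

# Toy illustration (not from the commit) of the aiostream pattern used above:
# merge several async generators and consume (result, stats) pairs as they
# arrive, regardless of which stream produced them.
import asyncio
from aiostream import stream

async def fake_pipeline(name: str, n: int):
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{name}-{i}", {"index": i}  # stand-in for (result, stats)

async def main():
    streams = [stream.iterate(fake_pipeline(name, 3)) for name in ("op-1", "op-2")]
    merged = stream.merge(*streams)
    async with merged.stream() as streamer:
        async for result, stats in streamer:
            print(result, stats)

asyncio.run(main())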

cubed/runtime/executors/dask.py

Lines changed: 63 additions & 122 deletions
@@ -1,33 +1,13 @@
 import asyncio
-from typing import (
-    Any,
-    AsyncIterator,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
-
-from aiostream import stream
-from aiostream.core import Stream
+from typing import Any, Callable, Dict, Optional, Sequence
+
 from dask.distributed import Client
 from networkx import MultiDiGraph
 
+from cubed.runtime.asyncio import async_map_dag
 from cubed.runtime.backup import use_backups_default
-from cubed.runtime.executors.asyncio import async_map_unordered
-from cubed.runtime.pipeline import visit_node_generations, visit_nodes
-from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
-from cubed.runtime.utils import (
-    asyncio_run,
-    execution_stats,
-    gensym,
-    handle_callbacks,
-    handle_operation_start_callbacks,
-)
+from cubed.runtime.types import Callback, DagExecutor
+from cubed.runtime.utils import asyncio_run, execution_stats, gensym
 from cubed.spec import Spec
 
 
@@ -38,68 +18,6 @@ def run_func(input, pipeline_func=None, config=None, name=None, compute_id=None)
     return result
 
 
-async def map_unordered(
-    client: Client,
-    map_function: Callable[..., Any],
-    map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
-    retries: int = 2,
-    use_backups: bool = False,
-    batch_size: Optional[int] = None,
-    return_stats: bool = False,
-    name: Optional[str] = None,
-    **kwargs,
-) -> AsyncIterator[Any]:
-    def create_futures_func(input, **kwargs):
-        input = list(input)  # dask expects a sequence (it calls `len` on it)
-        key = name or gensym("map")
-        key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
-        return [
-            (i, asyncio.ensure_future(f))
-            for i, f in zip(
-                input,
-                client.map(map_function, input, key=key, retries=retries, **kwargs),
-            )
-        ]
-
-    def create_backup_futures_func(input, **kwargs):
-        input = list(input)  # dask expects a sequence (it calls `len` on it)
-        key = name or gensym("backup")
-        key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
-        return [
-            (i, asyncio.ensure_future(f))
-            for i, f in zip(input, client.map(map_function, input, key=key, **kwargs))
-        ]
-
-    async for result in async_map_unordered(
-        create_futures_func,
-        map_iterdata,
-        use_backups=use_backups,
-        create_backup_futures_func=create_backup_futures_func,
-        batch_size=batch_size,
-        return_stats=return_stats,
-        name=name,
-        **kwargs,
-    ):
-        yield result
-
-
-def pipeline_to_stream(
-    client: Client, name: str, pipeline: CubedPipeline, **kwargs
-) -> Stream:
-    return stream.iterate(
-        map_unordered(
-            client,
-            run_func,
-            pipeline.mappable,
-            return_stats=True,
-            name=name,
-            pipeline_func=pipeline.function,
-            config=pipeline.config,
-            **kwargs,
-        )
-    )
-
-
 def check_runtime_memory(spec, client):
     allowed_mem = spec.allowed_mem if spec is not None else None
     scheduler_info = client.scheduler_info()
@@ -114,40 +32,27 @@ def check_runtime_memory(spec, client):
     )
 
 
-async def async_execute_dag(
-    dag: MultiDiGraph,
-    callbacks: Optional[Sequence[Callback]] = None,
-    resume: Optional[bool] = None,
-    spec: Optional[Spec] = None,
-    compute_arrays_in_parallel: Optional[bool] = None,
-    compute_kwargs: Optional[Dict[str, Any]] = None,
-    **kwargs,
-) -> None:
-    compute_kwargs = compute_kwargs or {}
-    async with Client(asynchronous=True, **compute_kwargs) as client:
-        if spec is not None:
-            check_runtime_memory(spec, client)
-        if "use_backups" not in kwargs and use_backups_default(spec):
-            kwargs["use_backups"] = True
-        if not compute_arrays_in_parallel:
-            # run one pipeline at a time
-            for name, node in visit_nodes(dag, resume=resume):
-                handle_operation_start_callbacks(callbacks, name)
-                st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
-                async with st.stream() as streamer:
-                    async for result, stats in streamer:
-                        handle_callbacks(callbacks, result, stats)
-        else:
-            for gen in visit_node_generations(dag, resume=resume):
-                # run pipelines in the same topological generation in parallel by merging their streams
-                streams = [
-                    pipeline_to_stream(client, name, node["pipeline"], **kwargs)
-                    for name, node in gen
-                ]
-                merged_stream = stream.merge(*streams)
-                async with merged_stream.stream() as streamer:
-                    async for result, stats in streamer:
-                        handle_callbacks(callbacks, result, stats)
+def dask_create_futures_func(
+    client,
+    function: Callable[..., Any],
+    name: Optional[str] = None,
+    retries: Optional[str] = None,
+):
+    def create_futures_func(input, **kwargs):
+        input = list(input)  # dask expects a sequence (it calls `len` on it)
+        key = name or gensym("map")
+        key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
+        if "func" in kwargs:
+            kwargs["pipeline_func"] = kwargs.pop("func")  # rename to avoid clash
+        return [
+            (i, asyncio.ensure_future(f))
+            for i, f in zip(
                input,
+                client.map(function, input, key=key, retries=retries, **kwargs),
+            )
+        ]
+
+    return create_futures_func
 
 
 class DaskExecutor(DagExecutor):
@@ -172,7 +77,7 @@ def execute_dag(
     ) -> None:
         merged_kwargs = {**self.kwargs, **kwargs}
         asyncio_run(
-            async_execute_dag(
+            self._async_execute_dag(
                 dag,
                 callbacks=callbacks,
                 resume=resume,
@@ -182,3 +87,39 @@ def execute_dag(
                 **merged_kwargs,
             )
        )
+
+    async def _async_execute_dag(
+        self,
+        dag: MultiDiGraph,
+        callbacks: Optional[Sequence[Callback]] = None,
+        resume: Optional[bool] = None,
+        spec: Optional[Spec] = None,
+        compute_kwargs: Optional[Dict[str, Any]] = None,
+        compute_arrays_in_parallel: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
+        compute_kwargs = compute_kwargs or {}
+        retries = kwargs.pop("retries", 2)
+        name = kwargs.get("name", None)
+        async with Client(asynchronous=True, **compute_kwargs) as client:
+            if spec is not None:
+                check_runtime_memory(spec, client)
+            if "use_backups" not in kwargs and use_backups_default(spec):
+                kwargs["use_backups"] = True
+
+            create_futures_func = dask_create_futures_func(
+                client, run_func, name, retries=retries
+            )
+            create_backup_futures_func = dask_create_futures_func(
+                client, run_func, name
+            )
+
+            await async_map_dag(
+                create_futures_func,
+                dag=dag,
+                callbacks=callbacks,
+                resume=resume,
+                compute_arrays_in_parallel=compute_arrays_in_parallel,
+                create_backup_futures_func=create_backup_futures_func,
+                **kwargs,
+            )
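A detail worth noting in dask_create_futures_func: async_map_dag forwards the pipeline function under the keyword func, but func is also the name of the first positional parameter of dask's Client.map, hence the rename to pipeline_func (which run_func accepts). A stand-in sketch, not dask itself, of the collision being avoided:

# Stand-in showing why the `func` kwarg must be renamed before submission:
# the mapping function's own first parameter is also called `func`.
def map_like(func, *iterables, **kwargs):
    return [func(i, **kwargs) for i in iterables[0]]

def run(x, pipeline_func=None, **_):
    return pipeline_func(x)

items = [1, 2, 3]
# map_like(run, items, func=lambda x: x + 1)
#   -> TypeError: map_like() got multiple values for argument 'func'
print(map_like(run, items, pipeline_func=lambda x: x + 1))  # [2, 3, 4]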
