import asyncio
- from typing import (
-     Any,
-     AsyncIterator,
-     Callable,
-     Dict,
-     Iterable,
-     List,
-     Optional,
-     Sequence,
-     Tuple,
-     Union,
- )
-
- from aiostream import stream
- from aiostream.core import Stream
+ from typing import Any, Callable, Dict, Optional, Sequence
+
from dask.distributed import Client
from networkx import MultiDiGraph

+ from cubed.runtime.asyncio import async_map_dag
from cubed.runtime.backup import use_backups_default
- from cubed.runtime.executors.asyncio import async_map_unordered
- from cubed.runtime.pipeline import visit_node_generations, visit_nodes
- from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
- from cubed.runtime.utils import (
-     asyncio_run,
-     execution_stats,
-     gensym,
-     handle_callbacks,
-     handle_operation_start_callbacks,
- )
+ from cubed.runtime.types import Callback, DagExecutor
+ from cubed.runtime.utils import asyncio_run, execution_stats, gensym
from cubed.spec import Spec
@@ -38,68 +18,6 @@ def run_func(input, pipeline_func=None, config=None, name=None, compute_id=None)
    return result


- async def map_unordered(
-     client: Client,
-     map_function: Callable[..., Any],
-     map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
-     retries: int = 2,
-     use_backups: bool = False,
-     batch_size: Optional[int] = None,
-     return_stats: bool = False,
-     name: Optional[str] = None,
-     **kwargs,
- ) -> AsyncIterator[Any]:
-     def create_futures_func(input, **kwargs):
-         input = list(input)  # dask expects a sequence (it calls `len` on it)
-         key = name or gensym("map")
-         key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
-         return [
-             (i, asyncio.ensure_future(f))
-             for i, f in zip(
-                 input,
-                 client.map(map_function, input, key=key, retries=retries, **kwargs),
-             )
-         ]
-
-     def create_backup_futures_func(input, **kwargs):
-         input = list(input)  # dask expects a sequence (it calls `len` on it)
-         key = name or gensym("backup")
-         key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
-         return [
-             (i, asyncio.ensure_future(f))
-             for i, f in zip(input, client.map(map_function, input, key=key, **kwargs))
-         ]
-
-     async for result in async_map_unordered(
-         create_futures_func,
-         map_iterdata,
-         use_backups=use_backups,
-         create_backup_futures_func=create_backup_futures_func,
-         batch_size=batch_size,
-         return_stats=return_stats,
-         name=name,
-         **kwargs,
-     ):
-         yield result
-
-
- def pipeline_to_stream(
-     client: Client, name: str, pipeline: CubedPipeline, **kwargs
- ) -> Stream:
-     return stream.iterate(
-         map_unordered(
-             client,
-             run_func,
-             pipeline.mappable,
-             return_stats=True,
-             name=name,
-             pipeline_func=pipeline.function,
-             config=pipeline.config,
-             **kwargs,
-         )
-     )
-
-
def check_runtime_memory(spec, client):
    allowed_mem = spec.allowed_mem if spec is not None else None
    scheduler_info = client.scheduler_info()
@@ -114,40 +32,27 @@ def check_runtime_memory(spec, client):
        )


- async def async_execute_dag(
-     dag: MultiDiGraph,
-     callbacks: Optional[Sequence[Callback]] = None,
-     resume: Optional[bool] = None,
-     spec: Optional[Spec] = None,
-     compute_arrays_in_parallel: Optional[bool] = None,
-     compute_kwargs: Optional[Dict[str, Any]] = None,
-     **kwargs,
- ) -> None:
-     compute_kwargs = compute_kwargs or {}
-     async with Client(asynchronous=True, **compute_kwargs) as client:
-         if spec is not None:
-             check_runtime_memory(spec, client)
-         if "use_backups" not in kwargs and use_backups_default(spec):
-             kwargs["use_backups"] = True
-         if not compute_arrays_in_parallel:
-             # run one pipeline at a time
-             for name, node in visit_nodes(dag, resume=resume):
-                 handle_operation_start_callbacks(callbacks, name)
-                 st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
-                 async with st.stream() as streamer:
-                     async for result, stats in streamer:
-                         handle_callbacks(callbacks, result, stats)
-         else:
-             for gen in visit_node_generations(dag, resume=resume):
-                 # run pipelines in the same topological generation in parallel by merging their streams
-                 streams = [
-                     pipeline_to_stream(client, name, node["pipeline"], **kwargs)
-                     for name, node in gen
-                 ]
-                 merged_stream = stream.merge(*streams)
-                 async with merged_stream.stream() as streamer:
-                     async for result, stats in streamer:
-                         handle_callbacks(callbacks, result, stats)
+ def dask_create_futures_func(
+     client,
+     function: Callable[..., Any],
+     name: Optional[str] = None,
+     retries: Optional[int] = None,
+ ):
+     def create_futures_func(input, **kwargs):
+         input = list(input)  # dask expects a sequence (it calls `len` on it)
+         key = name or gensym("map")
+         key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
+         if "func" in kwargs:
+             kwargs["pipeline_func"] = kwargs.pop("func")  # rename to avoid clash
+         return [
+             (i, asyncio.ensure_future(f))
+             for i, f in zip(
+                 input,
+                 client.map(function, input, key=key, retries=retries, **kwargs),
+             )
+         ]
+
+     return create_futures_func


class DaskExecutor(DagExecutor):
@@ -172,7 +77,7 @@ def execute_dag(
    ) -> None:
        merged_kwargs = {**self.kwargs, **kwargs}
        asyncio_run(
-             async_execute_dag(
+             self._async_execute_dag(
                dag,
                callbacks=callbacks,
                resume=resume,
@@ -182,3 +87,39 @@ def execute_dag(
                **merged_kwargs,
            )
        )
+
+     async def _async_execute_dag(
+         self,
+         dag: MultiDiGraph,
+         callbacks: Optional[Sequence[Callback]] = None,
+         resume: Optional[bool] = None,
+         spec: Optional[Spec] = None,
+         compute_kwargs: Optional[Dict[str, Any]] = None,
+         compute_arrays_in_parallel: Optional[bool] = None,
+         **kwargs,
+     ) -> None:
+         compute_kwargs = compute_kwargs or {}
+         retries = kwargs.pop("retries", 2)
+         name = kwargs.get("name", None)
+         async with Client(asynchronous=True, **compute_kwargs) as client:
+             if spec is not None:
+                 check_runtime_memory(spec, client)
+             if "use_backups" not in kwargs and use_backups_default(spec):
+                 kwargs["use_backups"] = True
+
+             create_futures_func = dask_create_futures_func(
+                 client, run_func, name, retries=retries
+             )
+             create_backup_futures_func = dask_create_futures_func(
+                 client, run_func, name
+             )
+
+             await async_map_dag(
+                 create_futures_func,
+                 dag=dag,
+                 callbacks=callbacks,
+                 resume=resume,
+                 compute_arrays_in_parallel=compute_arrays_in_parallel,
+                 create_backup_futures_func=create_backup_futures_func,
+                 **kwargs,
+             )
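Note on the new helper: `dask_create_futures_func` is a closure factory. It captures the Dask `client`, the task `function`, an optional `name` (used as the Dask key prefix), and `retries`, and returns a plain `create_futures_func(input, **kwargs)` that `async_map_dag` can call per operation, yielding `(input, future)` pairs. A minimal, self-contained sketch of that pairing pattern, runnable against an in-process cluster (the `double` function and key are illustrative, not part of this change):

```python
import asyncio

from dask.distributed import Client


def double(x):
    return 2 * x


async def demo():
    # In-process cluster purely for illustration.
    async with Client(asynchronous=True, processes=False) as client:
        inputs = [1, 2, 3]
        # The same (input, asyncio future) pairing the closure builds
        # around client.map; the dask futures are awaitable in async mode.
        pairs = [
            (i, asyncio.ensure_future(f))
            for i, f in zip(inputs, client.map(double, inputs, key="demo_map"))
        ]
        for i, fut in pairs:
            print(i, "->", await fut)  # 1 -> 2, 2 -> 4, 3 -> 6


asyncio.run(demo())
```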
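For end-to-end context, a rough usage sketch of the refactored executor, assuming the standard cubed user-facing API (import path inferred from this file's location; sizes, worker counts, and memory values are illustrative):

```python
import cubed
import cubed.array_api as xp
from cubed.runtime.executors.dask import DaskExecutor

# Spec caps per-task memory; check_runtime_memory verifies the
# Dask workers have at least this much.
spec = cubed.Spec(work_dir="tmp", allowed_mem="2GB")

a = xp.ones((5000, 5000), chunks=(500, 500), spec=spec)
b = xp.ones((5000, 5000), chunks=(500, 500), spec=spec)
c = xp.add(a, b)

# Executor kwargs are merged into execute_dag: compute_kwargs are
# forwarded to dask.distributed.Client, and retries reaches client.map
# via dask_create_futures_func.
executor = DaskExecutor(compute_kwargs={"n_workers": 4}, retries=2)
c.compute(executor=executor)
```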