Skip to content

Commit 3656030

Browse files
authored
Merge pull request #584 from JuliaParallel/jps/stream-teardown
streaming: Add DAG teardown option
2 parents a765cbe + 3c5c389 commit 3656030

File tree

6 files changed

+223
-25
lines changed

6 files changed

+223
-25
lines changed

docs/src/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,4 +427,5 @@ wait(t)
427427
The above example demonstrates a streaming region that generates random numbers
428428
continuously and writes each random number to a file. The streaming region is
429429
terminated when a random number less than 0.01 is generated, which is done by
430-
calling `Dagger.finish_stream()` (this exits the current streaming task).
430+
calling `Dagger.finish_stream()` (this terminates the current task, and will
431+
also terminate all streaming tasks launched by `spawn_streaming`).

docs/src/streaming.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,8 @@ end
7979
```
8080

8181
If you want to stop the streaming DAG and tear it all down, you can call
82-
`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
83-
terminate each streaming task. In the future, a more convenient way to tear
84-
down a full DAG will be added; for now, each task must be cancelled individually.
82+
`Dagger.cancel!(all_vals[1])` (or with any other task in the streaming DAG) to
83+
terminate all streaming tasks.
8584

8685
Alternatively, tasks can stop themselves from the inside with
8786
`finish_stream`, optionally returning a value that can be `fetch`'d. Let's

src/dtask.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false)
8585
end
8686
return fetch(t.future; raw)
8787
end
88+
"""
    waitany(tasks::Vector{DTask})

Block until at least one task in `tasks` has finished, then return.
Returns immediately if `tasks` is empty.

One listener task is spawned per waited-on task; the first listener whose
task completes wakes this function. Remaining listeners are tracked by the
scheduler's error monitor and exit on their own once their task finishes.
"""
function waitany(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    cond = Threads.Condition()
    # Completion flag, only read/written under `cond`'s lock. Without it, a
    # listener that fires before we reach `wait(cond)` would `notify` with no
    # waiter registered, the wakeup would be lost, and we could block forever.
    done = Ref(false)
    for task in tasks
        Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
            wait(task)
            @lock cond begin
                done[] = true
                notify(cond)
            end
        end)
    end
    # Re-check the flag after every wakeup (standard condition-variable loop),
    # and skip waiting entirely if a task finished before we got here.
    @lock cond while !done[]
        wait(cond)
    end
    return
end
102+
"""
    waitall(tasks::Vector{DTask})

Block until every task in `tasks` has finished, then return.
Returns immediately if `tasks` is empty.
"""
function waitall(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    # `@sync` joins all spawned waiters, so it alone provides the "wait for
    # all" semantics. (The previous version also did `@lock cond notify(cond)`
    # inside each waiter, but `cond` was never defined in this function —
    # an `UndefVarError` as soon as any task finished.)
    @sync for task in tasks
        Threads.@spawn wait(task)
    end
    return
end
88114
function Base.show(io::IO, t::DTask)
89115
status = if istaskstarted(t)
90116
isready(t) ? "finished" : "running"

src/stream.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task)
426426
end
427427
end
428428

429-
"""
    spawn_streaming(f::Base.Callable; teardown::Bool=true)

Starts a streaming region, within which all tasks run continuously and
concurrently. Any `DTask` argument that is itself a streaming task will be
treated as a streaming input/output. The streaming region will automatically
handle the buffering and synchronization of these tasks' values.

# Keyword Arguments
- `teardown::Bool=true`: If `true`, the streaming region will automatically
cancel all tasks if any task fails or is cancelled. Otherwise, a failing task
will not cancel the other tasks, which will continue running.
"""
function spawn_streaming(f::Base.Callable; teardown::Bool=true)
    queue = StreamingTaskQueue()
    # Run the user's closure with our queue installed, collecting all tasks
    # spawned within it instead of eagerly launching them.
    result = with_options(f; task_queue=queue)
    if length(queue.tasks) > 0
        finalize_streaming!(queue.tasks, queue.self_streams)
        enqueue!(queue.tasks)

        if teardown
            # Start teardown monitor
            dtasks = map(last, queue.tasks)::Vector{DTask}
            Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin
                # Wait for any task to finish
                waitany(dtasks)

                # Cancel all tasks (non-graceful, so blocked tasks are
                # interrupted rather than allowed to drain)
                for task in dtasks
                    cancel!(task; graceful=false)
                end
            end)
        end
    end
    return result
end

src/utils/tasks.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer)
1818
end
1919
@assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
2020
end
21+
22+
# Provide `waitany`/`waitall` over collections of `Task`s: re-export Base's
# implementations on Julia versions that have them, otherwise fall back to a
# vendored copy of Base's implementation below.
if isdefined(Base, :waitany)
    import Base: waitany, waitall
else
    # Vendored from Base
    # License is MIT

    # Wait until at least one task is done; throws by default if one failed.
    waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
    # Wait until all tasks are done; `failfast` allows returning early once a
    # failure is observed.
    waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
    # Shared implementation for `waitany`/`waitall`.
    # Returns `(done_tasks, remaining_tasks)`. When `throwexc` is set and the
    # relevant mode observed a failure, throws a `CompositeException` of the
    # failed tasks' `TaskFailedException`s instead.
    function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
        tasks = Task[]

        for t in waiting_tasks
            t isa Task || error("Expected an iterator of `Task` object")
            push!(tasks, t)
        end

        if (all && !failfast) || length(tasks) <= 1
            exception = false
            # Force everything to finish synchronously for the case of waitall
            # with failfast=false
            for t in tasks
                # NOTE(review): `_wait` is a Base internal — confirm it
                # resolves in this module (may need to be `Base._wait`).
                _wait(t)
                exception |= istaskfailed(t)
            end
            if exception && throwexc
                exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return tasks, Task[]
            end
        end

        exception = false
        nremaining::Int = length(tasks)
        done_mask = falses(nremaining)
        # First pass: account for tasks that are already done without waiting.
        for (i, t) in enumerate(tasks)
            if istaskdone(t)
                done_mask[i] = true
                exception |= istaskfailed(t)
                nremaining -= 1
            else
                done_mask[i] = false
            end
        end

        if nremaining == 0
            return tasks, Task[]
        elseif any(done_mask) && (!all || (failfast && exception))
            # Early return possible: waitany with something already done, or
            # waitall+failfast with a failure already seen.
            if throwexc && (!all || failfast) && exception
                exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return tasks[done_mask], tasks[.~done_mask]
            end
        end

        # One waiter task per pending task; a waiter reports its task's index
        # on `chan` when that task finishes.
        chan = Channel{Int}(Inf)
        # The current task serves as a sentinel marking "no waiter" slots.
        sentinel = current_task()
        waiter_tasks = fill(sentinel, length(tasks))

        for (i, done) in enumerate(done_mask)
            done && continue
            t = tasks[i]
            if istaskdone(t)
                # Task finished between the first pass and now.
                done_mask[i] = true
                exception |= istaskfailed(t)
                nremaining -= 1
                exception && failfast && break
            else
                waiter = @task put!(chan, i)
                waiter.sticky = false
                # NOTE(review): `_wait2` is a Base internal (it registers
                # `waiter` on `t`'s donenotify queue) — confirm it resolves in
                # this module (may need `Base._wait2`).
                _wait2(t, waiter)
                waiter_tasks[i] = waiter
            end
        end

        while nremaining > 0
            i = take!(chan)
            t = tasks[i]
            waiter_tasks[i] = sentinel
            done_mask[i] = true
            exception |= istaskfailed(t)
            nremaining -= 1

            # stop early if requested, unless there is something immediately
            # ready to consume from the channel (using a race-y check)
            if (!all || (failfast && exception)) && !isready(chan)
                break
            end
        end

        close(chan)

        if nremaining == 0
            return tasks, Task[]
        else
            # Unregister waiters of still-pending tasks so they don't try to
            # `put!` into the now-closed channel later.
            remaining_mask = .~done_mask
            for i in findall(remaining_mask)
                waiter = waiter_tasks[i]
                # NOTE(review): `ThreadSynchronizer` and `list_deletefirst!`
                # are Base internals — confirm they resolve in this module.
                donenotify = tasks[i].donenotify::ThreadSynchronizer
                @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
            end
            done_tasks = tasks[done_mask]
            if throwexc && exception
                exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return done_tasks, tasks[remaining_mask]
            end
        end
    end
end

0 commit comments

Comments
 (0)