Skip to content

Commit e1ccbfe

Browse files
committed
fixup! Make Dagger.finish_stream() propagate downstream
1 parent 3f5f8d4 commit e1ccbfe

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

src/stream.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B}
4040
end
4141
@dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)"
4242
for output_uid in keys(store.output_streams)
43-
if !haskey(store.output_buffers, output_uid)
44-
initialize_output_stream!(store, output_uid)
45-
end
4643
buffer = store.output_buffers[output_uid]
4744
while isfull(buffer)
4845
if !isopen(store)
@@ -257,10 +254,13 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
257254
end
258255
initialize_input_stream!(our_store::StreamStore, arg) = arg
259256
function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B}
260-
@assert islocked(our_store.lock)
261257
@dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid"
262-
buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount)
263-
our_store.output_buffers[output_uid] = buffer
258+
local buffer
259+
@lock our_store.lock begin
260+
buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount)
261+
our_store.output_buffers[output_uid] = buffer
262+
end
263+
264264
our_uid = our_store.uid
265265
output_stream = our_store.output_streams[output_uid]
266266
output_fetcher = our_store.output_fetchers[output_uid]
@@ -595,6 +595,16 @@ function stream!(sf::StreamingFunction, uid,
595595
f = move(thunk_processor(), sf.f)
596596
counter = 0
597597

598+
# Initialize output streams. We can't do this in add_waiters!() because the
599+
# output handlers depend on the DTaskTLS, so they have to be set up from
600+
# within the DTask.
601+
store = sf.stream.store
602+
for output_uid in keys(store.output_streams)
603+
if !haskey(store.output_buffers, output_uid)
604+
initialize_output_stream!(store, output_uid)
605+
end
606+
end
607+
598608
while true
599609
# Yield to other (streaming) tasks
600610
yield()

test/streaming.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,19 @@ for idx in 1:5
413413
# end
414414

415415
@testset "Graceful finishing" begin
416+
@test test_finishes("finish_stream() without return value") do
417+
B = Dagger.spawn_streaming() do
418+
A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream()
419+
420+
Dagger.@spawn scope=rand(scopes) accumulator(A)
421+
end
422+
423+
fetch(B)
424+
# Since we don't return any value in the call to finish_stream(), B
425+
# should never execute.
426+
@test isempty(ACCUMULATOR)
427+
end
428+
416429
@test test_finishes("finish_stream() with one downstream task") do
417430
B = Dagger.spawn_streaming() do
418431
A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(42)

0 commit comments

Comments
 (0)