From af692b57b840fcfc29b167fd8422c487006b2e68 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:42:50 -0700 Subject: [PATCH 01/14] Add metadata to EagerThunk --- src/eager_thunk.jl | 14 +++++++++++++- src/submission.jl | 9 ++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/eager_thunk.jl b/src/eager_thunk.jl index 8a021737e..bb0766be1 100644 --- a/src/eager_thunk.jl +++ b/src/eager_thunk.jl @@ -29,6 +29,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + EagerThunkMetadata + +Represents some useful metadata pertaining to an `EagerThunk`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct EagerThunkMetadata + return_type::Type +end + """ EagerThunk @@ -39,9 +49,11 @@ be `fetch`'d or `wait`'d on at any time. mutable struct EagerThunk uid::UInt future::ThunkFuture + metadata::EagerThunkMetadata finalizer_ref::DRef thunk_ref::DRef - EagerThunk(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + EagerThunk(uid, future, metadata, finalizer_ref) = + new(uid, future, metadata, finalizer_ref) end Base.isready(t::EagerThunk) = isready(t.future) diff --git a/src/submission.jl b/src/submission.jl index 3898c922a..c791e8c76 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -183,14 +183,21 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end +function EagerThunkMetadata(spec::EagerTaskSpec) + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base._return_type(spec.f, Base.to_tuple_type(arg_types)) + return EagerThunkMetadata(return_type) +end +chunktype(t::EagerThunk) = t.metadata.return_type function eager_spawn(spec::EagerTaskSpec) # Generate new EagerThunk uid = eager_next_id() future = ThunkFuture() + metadata = EagerThunkMetadata(spec) finalizer_ref = poolset(EagerThunkFinalizer(uid); device=MemPool.CPURAMDevice()) # Return unlaunched EagerThunk - return EagerThunk(uid, future, finalizer_ref) + return EagerThunk(uid, future, metadata, finalizer_ref) end function eager_launch!((spec, task)::Pair{EagerTaskSpec,EagerThunk}) # Lookup EagerThunk -> ThunkID From 7204709d9948a46d0a2f7ed67defc2633ca294e5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:44:14 -0700 Subject: [PATCH 02/14] Sch: Allow occupancy key to be Any --- src/sch/util.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index c9b72f2c9..71e62a531 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -362,12 +362,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] From 
b40018982019a6a11127ac4d8d27399b7cd30862 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 12 Sep 2023 10:56:47 -0500 Subject: [PATCH 03/14] Add streaming API --- docs/make.jl | 1 + docs/src/streaming.md | 105 ++++++++++++++ src/Dagger.jl | 3 + src/eager_thunk.jl | 12 +- src/sch/eager.jl | 7 + src/stream.jl | 313 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 440 insertions(+), 1 deletion(-) create mode 100644 docs/src/streaming.md create mode 100644 src/stream.jl diff --git a/docs/make.jl b/docs/make.jl index 679c3b9e9..d6a86bcaa 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", "Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 000000000..0a13a1472 --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,105 @@ +# Streaming Tasks + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. + +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. All of these tasks are better suited to a +"streaming" model of data processing, where data is simply piped into a +continuously-running task (or DAG of tasks) forever, or until the data runs +out. + +Thankfully, if you have a problem which is best modeled as a streaming system +of tasks, Dagger has you covered! Building on its support for +["Task Queues"](@ref), Dagger provides a means to convert an entire DAG of +tasks into a streaming DAG, where data flows into and out of each task +asynchronously, using the `spawn_streaming` function: + +```julia +Dagger.spawn_streaming() do # enters a streaming region + vals = Dagger.@spawn rand() + print_vals = Dagger.@spawn println(vals) +end # exits the streaming region, and starts the DAG running +``` + +In the above example, `vals` is a Dagger task which has been transformed to run +in a streaming manner - instead of just calling `rand()` once and returning its +result, it will re-run `rand()` endlessly, continuously producing new random +values. In typical Dagger style, `print_vals` is a Dagger task which depends on +`vals`, but in streaming form - it will continuously `println` the random +values produced from `vals`. Both tasks will run forever, and will run +efficiently, only doing the work necessary to generate, transfer, and consume +values. + +As the comments point out, `spawn_streaming` creates a streaming region, during +which `vals` and `print_vals` are created and configured. Both tasks are halted +until `spawn_streaming` returns, allowing large DAGs to be built all at once, +without any task losing a single value. 
If desired, streaming regions can be +connected, although some values might be lost while tasks are being connected: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.@spawn rand() +end + +# Some values might be generated by `vals` but thrown away +# before `print_vals` is fully setup and connected to it + +print_vals = Dagger.spawn_streaming() do + Dagger.@spawn println(vals) +end +``` + +More complicated streaming DAGs can be easily constructed, without doing +anything different. For example, we can generate multiple streams of random +numbers, write them all to their own files, and print the combined results: + +```julia +Dagger.spawn_streaming() do + all_vals = [Dagger.spawn(rand) for i in 1:4] + all_vals_written = map(1:4) do i + Dagger.spawn(all_vals[i]) do val + open("results_$i.txt"; write=true, create=true, append=true) do io + println(io, repr(val)) + end + return val + end + end + Dagger.spawn(all_vals_written...) do all_vals_written... + vals_sum = sum(all_vals_written) + println(vals_sum) + end +end +``` + +If you want to stop the streaming DAG and tear it all down, you can call +`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc., the +kill propagates throughout the DAG). + +Alternatively, tasks can stop themselves from the inside with +`finish_streaming`, optionally returning a value that can be `fetch`'d. Let's +do this when our randomly-drawn number falls within some arbitrary range: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.spawn() do + x = rand() + if x < 0.001 + # That's good enough, let's be done + return Dagger.finish_streaming("Finished!") + end + return x + end +end +fetch(vals) +``` + +In this example, the call to `fetch` will hang (while random numbers continue +to be drawn), until a drawn number is less than 0.001; at that point, `fetch` +will return with "Finished!", and the task `vals` will have terminated. diff --git a/src/Dagger.jl b/src/Dagger.jl index be6ee3075..6c19fec6c 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -42,6 +42,9 @@ include("utils/system_uuid.jl") include("utils/caching.jl") include("sch/Sch.jl"); using .Sch +# Streaming +include("stream.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") diff --git a/src/eager_thunk.jl b/src/eager_thunk.jl index bb0766be1..00856c863 100644 --- a/src/eager_thunk.jl +++ b/src/eager_thunk.jl @@ -67,7 +67,17 @@ function Base.fetch(t::EagerThunk; raw=false) if !isdefined(t, :thunk_ref) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `EagerThunk`")) end - return fetch(t.future; raw) + stream = task_to_stream(t.uid) + if stream !== nothing + add_waiters!(stream, [0]) + end + try + return fetch(t.future; raw) + finally + if stream !== nothing + remove_waiters!(stream, [0]) + end + end end function Base.show(io::IO, t::EagerThunk) status = if isdefined(t, :thunk_ref) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 3d62ed8bf..ee58ccd7d 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -116,6 +116,13 @@ function eager_cleanup(state, uid) # N.B. 
cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.EagerThunk) diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..68022587f --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,313 @@ +mutable struct StreamStore + waiters::Vector{Int} + buffers::Dict{Int,Vector{Any}} + open::Bool + lock::Threads.Condition + StreamStore() = new(zeros(Int, 0), Dict{Int,Vector{Any}}(), true, Threads.Condition()) +end +tid() = Dagger.Sch.sch_handle().thunk_id.id +function uid() + thunk_id = tid() + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end +function Base.put!(store::StreamStore, @nospecialize(value)) + @lock store.lock begin + while length(store.waiters) == 0 && isopen(store) + @dagdebug nothing :stream_put "[$(uid())] no waiters, not putting" + wait(store.lock) + end + if !isopen(store) + @dagdebug nothing :stream_put "[$(uid())] closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug nothing :stream_put "[$(uid())] adding $value" + for buffer in values(store.buffers) + #elem = StreamElement(value) + push!(buffer, value) + end + notify(store.lock) + end +end +function Base.take!(store::StreamStore, id::UInt) + @lock store.lock begin + buffer = store.buffers[id] + while length(buffer) == 0 && isopen(store, id) + @dagdebug nothing :stream_take "[$(uid())] no elements, not taking" + wait(store.lock) + end + @dagdebug nothing :stream_take "[$(uid())] wait finished" + if !isopen(store, id) + @dagdebug nothing :stream_take "[$(uid())] closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + value = popfirst!(buffer) + @dagdebug nothing :stream_take "[$(uid())] value accepted" + return value + end +end +"Returns whether the store is actively open. Only check this when deciding if new values can be pushed." +Base.isopen(store::StreamStore) = store.open +"Returns whether the store is actively open, or if closing, still has remaining messages for `id`. Only check this when deciding if existing values can be taken." 
+function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !isempty(store.buffers[id]) + return true + end + return store.open + end +end +function Base.close(store::StreamStore) + store.open = false + @lock store.lock notify(store.lock) +end +function add_waiters!(store::StreamStore, waiters::Vector{Int}) + @lock store.lock begin + for w in waiters + store.buffers[w] = Any[] + end + append!(store.waiters, waiters) + notify(store.lock) + end +end +function remove_waiters!(store::StreamStore, waiters::Vector{Int}) + @lock store.lock begin + for w in waiters + delete!(store.buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + end + notify(store.lock) + end +end + +mutable struct Stream{T} <: AbstractChannel{T} + ref::Chunk + function Stream{T}() where T + store = tochunk(StreamStore()) + return new{T}(store) + end +end +Stream() = Stream{Any}() + +function Base.put!(stream::Stream, @nospecialize(value)) + tls = Dagger.get_tls() + remotecall_wait(stream.ref.handle.owner, stream.ref.handle, value) do ref, value + Dagger.set_tls!(tls) + @nospecialize value + store = MemPool.poolget(ref)::StreamStore + put!(store, value) + end +end +function Base.take!(stream::Stream{T}, id::UInt) where T + tls = Dagger.get_tls() + return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref + Dagger.set_tls!(tls) + store = MemPool.poolget(ref)::StreamStore + return take!(store, id)::T + end +end +function Base.isopen(stream::Stream, id::UInt)::Bool + return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref + return isopen(MemPool.poolget(ref)::StreamStore, id) + end +end +function Base.close(stream::Stream) + remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + close(MemPool.poolget(ref)::StreamStore) + end +end +function add_waiters!(stream::Stream, waiters::Vector{Int}) + remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + add_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + end +end +add_waiters!(stream::Stream, waiter::Integer) = + add_waiters!(stream::Stream, Int[waiter]) +function remove_waiters!(stream::Stream, waiters::Vector{Int}) + remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + remove_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + end +end +remove_waiters!(stream::Stream, waiter::Integer) = + remove_waiters!(stream::Stream, Int[waiter]) + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{EagerTaskSpec,EagerThunk}} + self_streams::Dict{UInt,Stream} + StreamingTaskQueue() = new(Pair{EagerTaskSpec,EagerThunk}[], + Dict{UInt,Stream}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) 
+end +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end +function initialize_streaming!(self_streams, spec, task) + if !isa(spec.f, StreamingFunction) + # Adapt called function for streaming and generate output Streams + # FIXME: Infer type + stream = Stream() + self_streams[task.uid] = stream + + spec.f = StreamingFunction(spec.f, stream) + # FIXME: Generalize to other processors + spec.options = merge(spec.options, (;occupancy=Dict(ThreadProc=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end + end +end + +function spawn_streaming(f::Base.Callable) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + end + return result +end + +struct FinishedStreaming{T} + value::T +end +finish_streaming(value=nothing) = FinishedStreaming(value) + +struct StreamingFunction{F, T} + f::F + stream::Stream{T} +end +function (sf::StreamingFunction)(args...; kwargs...) + @nospecialize sf args kwargs + result = nothing + stream_args = Base.mapany(identity, args) + stream_kwargs = Base.mapany(identity, kwargs) + thunk_id = tid() + # FIXME: Fetch from worker 1 + uid = lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end + try + while true + # Get values from Stream args/kwargs + for (idx, arg) in enumerate(args) + if arg isa Stream + stream_args[idx] = take!(arg, uid) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + stream_kwargs[idx] = pos => take!(arg, uid) + end + end + + # Run a single cycle of f + stream_result = sf.f(stream_args...; stream_kwargs...) 
+ + # Exit streaming on graceful request + if stream_result isa FinishedStreaming + return stream_result.value + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + + # Allow other tasks to run + yield() + end + finally + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end + end + for stream in streams + @dagdebug nothing :stream_close "[$uid] dropping waiter" + remove_waiters!(stream, uid) + end + + # Ensure downstream tasks also terminate + @dagdebug nothing :stream_close "[$uid] closed stream" + close(sf.stream) + end +end + +# FIXME: Ensure this gets cleaned up +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Stream}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Int}}() + + for (spec, task) in tasks + if !haskey(self_streams, task.uid) + continue + end + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa EagerThunk + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + spec.args[idx] = pos => other_stream + changes = get!(stream_waiter_changes, arg.uid) do + Int[] + end + push!(changes, task.uid) + elseif (other_stream = task_to_stream(arg.uid)) !== nothing + spec.args[idx] = pos => other_stream + changes = get!(stream_waiter_changes, arg.uid) do + Int[] + end + push!(changes, task.uid) + end + end + end + end + + # Adjust waiter count of Streams with dependencies + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end + +# TODO: Allow stopping arbitrary tasks +kill!(t::EagerThunk) = close(task_to_stream(t.uid)) From 78146e6f222dea1bb88885cdb44d47cc3d866cf2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:49:48 -0700 Subject: [PATCH 04/14] fixup! 
Add streaming API --- src/stream.jl | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 68022587f..9a82cefd9 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -1,9 +1,11 @@ -mutable struct StreamStore +mutable struct StreamStore{T} waiters::Vector{Int} buffers::Dict{Int,Vector{Any}} open::Bool lock::Threads.Condition - StreamStore() = new(zeros(Int, 0), Dict{Int,Vector{Any}}(), true, Threads.Condition()) + StreamStore{T}() where T = + new{T}(zeros(Int, 0), Dict{Int,Vector{T}}(), + true, Threads.Condition()) end tid() = Dagger.Sch.sch_handle().thunk_id.id function uid() @@ -16,7 +18,7 @@ function uid() end end end -function Base.put!(store::StreamStore, @nospecialize(value)) +function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T @lock store.lock begin while length(store.waiters) == 0 && isopen(store) @dagdebug nothing :stream_put "[$(uid())] no waiters, not putting" @@ -89,7 +91,7 @@ end mutable struct Stream{T} <: AbstractChannel{T} ref::Chunk function Stream{T}() where T - store = tochunk(StreamStore()) + store = tochunk(StreamStore{T}()) return new{T}(store) end end @@ -157,13 +159,16 @@ end function initialize_streaming!(self_streams, spec, task) if !isa(spec.f, StreamingFunction) # Adapt called function for streaming and generate output Streams - # FIXME: Infer type - stream = Stream() + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishedStreaming) ? only(t.parameters) : t, T_old) + # We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any + stream = Stream{T}() self_streams[task.uid] = stream spec.f = StreamingFunction(spec.f, stream) - # FIXME: Generalize to other processors - spec.options = merge(spec.options, (;occupancy=Dict(ThreadProc=>0))) + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) # Register Stream globally remotecall_wait(1, task.uid, stream) do uid, stream From d43a26221300e517ff42ec1046c0db70aa4da054 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 18 Dec 2023 19:11:24 -0700 Subject: [PATCH 05/14] fixup! fixup! 
Add streaming API --- src/eager_thunk.jl | 4 +- src/stream.jl | 147 +++++++++++++++++++++++++++++++-------------- 2 files changed, 104 insertions(+), 47 deletions(-) diff --git a/src/eager_thunk.jl b/src/eager_thunk.jl index 00856c863..cf778444a 100644 --- a/src/eager_thunk.jl +++ b/src/eager_thunk.jl @@ -68,13 +68,13 @@ function Base.fetch(t::EagerThunk; raw=false) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `EagerThunk`")) end stream = task_to_stream(t.uid) - if stream !== nothing + if stream isa Stream add_waiters!(stream, [0]) end try return fetch(t.future; raw) finally - if stream !== nothing + if stream isa Stream remove_waiters!(stream, [0]) end end diff --git a/src/stream.jl b/src/stream.jl index 9a82cefd9..191730c75 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -30,7 +30,6 @@ function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T end @dagdebug nothing :stream_put "[$(uid())] adding $value" for buffer in values(store.buffers) - #elem = StreamElement(value) push!(buffer, value) end notify(store.lock) @@ -139,11 +138,31 @@ end remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream::Stream, Int[waiter]) +struct NullStream end +Base.put!(ns::NullStream, x) = nothing +Base.take!(ns::NullStream) = throw(ConcurrencyViolationError("Cannot `take!` from a `NullStream`")) + +mutable struct StreamWrapper{S} + stream::S + open::Bool + StreamWrapper(stream::S) where S = new{S}(stream, true) +end +Base.isopen(sw::StreamWrapper) = sw.open +Base.close(sw::StreamWrapper) = (sw.open = false;) +function Base.put!(sw::StreamWrapper, x) + isopen(sw) || throw(InvalidStateException("Stream is closed.", :closed)) + put!(sw.stream, x) +end +function Base.take!(sw::StreamWrapper) + isopen(sw) || throw(InvalidStateException("Stream is closed.", :closed)) + take!(sw.stream) +end + struct StreamingTaskQueue <: AbstractTaskQueue tasks::Vector{Pair{EagerTaskSpec,EagerThunk}} - self_streams::Dict{UInt,Stream} + self_streams::Dict{UInt,Any} StreamingTaskQueue() = new(Pair{EagerTaskSpec,EagerThunk}[], - Dict{UInt,Stream}()) + Dict{UInt,Any}()) end function enqueue!(queue::StreamingTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk}) @@ -164,7 +183,20 @@ function initialize_streaming!(self_streams, spec, task) # We treat non-dominating error paths as unreachable T_old = filter(t->t !== Union{}, T_old) T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any - stream = Stream{T}() + if haskey(spec.options, :stream) + if spec.options.stream !== nothing + # Use the user-provided stream + @warn "Replace StreamWrapper with Stream" maxlog=1 + stream = StreamWrapper(spec.options.stream) + else + # Use a non-readable, non-writing stream + stream = StreamWrapper(NullStream()) + end + spec.options = NamedTuple(filter(opt -> opt[1] != :stream, Base.pairs(spec.options))) + else + # Create a built-in Stream object + stream = Stream{T}() + end self_streams[task.uid] = stream spec.f = StreamingFunction(spec.f, stream) @@ -190,56 +222,33 @@ function spawn_streaming(f::Base.Callable) end struct FinishedStreaming{T} - value::T + value::Union{Some{T},Nothing} end -finish_streaming(value=nothing) = FinishedStreaming(value) +finish_streaming(value) = FinishedStreaming{Any}(Some{T}(value)) +finish_streaming() = FinishedStreaming{Union{}}(nothing) -struct StreamingFunction{F, T} +struct StreamingFunction{F, S} f::F - stream::Stream{T} + stream::S end function (sf::StreamingFunction)(args...; kwargs...) 
@nospecialize sf args kwargs result = nothing - stream_args = Base.mapany(identity, args) - stream_kwargs = Base.mapany(identity, kwargs) thunk_id = tid() - # FIXME: Fetch from worker 1 - uid = lock(Sch.EAGER_ID_MAP) do id_map - for (uid, otid) in id_map - if thunk_id == otid - return uid + @warn "Fetch from worker 1 more efficiently" maxlog=1 + uid = remotecall_fetch(1, thunk_id) do thunk_id + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end end end end try - while true - # Get values from Stream args/kwargs - for (idx, arg) in enumerate(args) - if arg isa Stream - stream_args[idx] = take!(arg, uid) - end - end - for (idx, (pos, arg)) in enumerate(kwargs) - if arg isa Stream - stream_kwargs[idx] = pos => take!(arg, uid) - end - end - - # Run a single cycle of f - stream_result = sf.f(stream_args...; stream_kwargs...) - - # Exit streaming on graceful request - if stream_result isa FinishedStreaming - return stream_result.value - end - - # Put the result into the output stream - put!(sf.stream, stream_result) - - # Allow other tasks to run - yield() - end + kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) + kwarg_values = map(last, (kwargs...,)) + return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) finally # Remove ourself as a waiter for upstream Streams streams = Set{Stream}() @@ -263,9 +272,55 @@ function (sf::StreamingFunction)(args...; kwargs...) close(sf.stream) end end +# N.B We specialize to minimize/eliminate allocations +function stream!(sf::StreamingFunction, uid, + args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + while true + #@time begin + # Get values from Stream args/kwargs + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + + # Run a single cycle of f + stream_result = sf.f(stream_args...; stream_kwargs...) + + # Exit streaming on graceful request + if stream_result isa FinishedStreaming + @info "Terminating!" + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + return value + end + return nothing + end -# FIXME: Ensure this gets cleaned up -const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Stream}()) + # Put the result into the output stream + put!(sf.stream, stream_result) + #end + end +end +function _stream_take_values!(args) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa Stream + take!(arg, uid) + elseif arg isa Union{AbstractChannel,RemoteChannel,StreamWrapper} # FIXME: Use trait query + take!(arg) + else + arg + end + end +end +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) 
+ NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) function task_to_stream(uid::UInt) if myid() != 1 return remotecall_fetch(task_to_stream, 1, uid) @@ -310,7 +365,9 @@ function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self # Adjust waiter count of Streams with dependencies for (uid, waiters) in stream_waiter_changes stream = task_to_stream(uid) - add_waiters!(stream, waiters) + if stream isa Stream # FIXME: Use trait query + add_waiters!(stream, waiters) + end end end From 8eb1a6aa5baf8e70da3940e535144704d70ccd24 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 18 Dec 2023 19:11:44 -0700 Subject: [PATCH 06/14] TEMP: Add stream migration support --- src/stream.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/stream.jl b/src/stream.jl index 191730c75..d90bc3ed3 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -138,6 +138,20 @@ end remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream::Stream, Int[waiter]) +function migrate_stream!(stream::Stream, w::Integer=myid()) + # Take lock to prevent any further modifications + # N.B. Serialization automatically unlocks + remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + lock((MemPool.poolget(ref)::StreamStore).lock) + end + + # Perform migration of the StreamStore + # MemPool will block access to the new ref until the migration completes + if stream.ref.handle.owner != w + MemPool.migrate!(stream.ref.handle, w) + end +end + struct NullStream end Base.put!(ns::NullStream, x) = nothing Base.take!(ns::NullStream) = throw(ConcurrencyViolationError("Cannot `take!` from a `NullStream`")) @@ -245,6 +259,9 @@ function (sf::StreamingFunction)(args...; kwargs...) end end end + if sf.stream isa Stream + migrate_stream!(sf.stream) + end try kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) kwarg_values = map(last, (kwargs...,)) From b22567c13ac352d85e1dfbd2652bd8c97dea6119 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 11 Mar 2024 09:31:24 -0700 Subject: [PATCH 07/14] fixup! 
Add metadata to EagerThunk --- Project.toml | 1 + src/submission.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c0452482b..dd092a9a0 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" +Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/submission.jl b/src/submission.jl index c791e8c76..776315f65 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -184,8 +184,9 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) end end function EagerThunkMetadata(spec::EagerTaskSpec) + f = chunktype(spec.f).instance arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) - return_type = Base._return_type(spec.f, Base.to_tuple_type(arg_types)) + return_type = Base._return_type(f, Base.to_tuple_type(arg_types)) return EagerThunkMetadata(return_type) end chunktype(t::EagerThunk) = t.metadata.return_type From bc79b842cf1238866da46454f778dc85746c8428 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 11 Mar 2024 09:37:53 -0700 Subject: [PATCH 08/14] fixup! fixup! fixup! Add streaming API --- Manifest.toml | 50 ++++---- Project.toml | 1 - src/Dagger.jl | 2 + src/eager_thunk.jl | 12 +- src/stream-buffers.jl | 204 +++++++++++++++++++++++++++++++++ src/stream-fetchers.jl | 24 ++++ src/stream.jl | 252 ++++++++++++++++++++++------------------- 7 files changed, 392 insertions(+), 153 deletions(-) create mode 100644 src/stream-buffers.jl create mode 100644 src/stream-fetchers.jl diff --git a/Manifest.toml b/Manifest.toml index 0cf9adc65..3b0885f3c 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "5333a6c200b6e6add81c46547527f66ddc0dc16c" +project_hash = "1e12d6aa088ae431916872c11d09544380c7a130" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -12,9 +12,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "b66b8f8e3db5d7835fb8cbe2589ffd1cd456e491" +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.17.0" +version = "1.23.0" [[deps.ChangesOfVariables]] deps = ["InverseFunctions", "LinearAlgebra", "Test"] @@ -23,10 +23,10 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.8" [[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c" +deps = ["Dates", "LinearAlgebra", "TOML", "UUIDs"] +git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.0" +version = "4.14.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] @@ -34,15 +34,15 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "1.0.1+0" [[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" +version = "1.16.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = 
"3dbd312d370723b6bb43ba9d02fc36abade4518d" +git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" +version = "0.18.18" [[deps.Dates]] deps = ["Printf"] @@ -91,28 +91,28 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.26" +version = "0.3.27" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.11" +version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MemPool]] -deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"] -git-tree-sha1 = "b9c1a032c3c1310a857c061ce487c632eaa1faa4" +deps = ["DataStructures", "Distributed", "Mmap", "Random", "ScopedValues", "Serialization", "Sockets"] +git-tree-sha1 = "60dd4ac427d39e0b3f15b193845324523ee71c03" uuid = "f9f48841-c794-520a-933b-121f7ba6ed94" -version = "0.4.4" +version = "0.4.6" [[deps.Missings]] deps = ["DataAPI"] @@ -133,9 +133,9 @@ uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" version = "0.3.20+0" [[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" +version = "1.6.3" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -145,9 +145,9 @@ version = "1.2.0" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.1" +version = "1.4.3" [[deps.Printf]] deps = ["Unicode"] @@ -173,9 +173,9 @@ version = "0.7.0" [[deps.ScopedValues]] deps = ["HashArrayMappedTries", "Logging"] -git-tree-sha1 = "e3b5e4ccb1702db2ae9ac2a660d4b6b2a8595742" +git-tree-sha1 = "c27d546a4749c81f70d1fabd604da6aa5054e3d2" uuid = "7e506255-f358-4e82-b7e4-beb19740aa63" -version = "1.1.0" +version = "1.2.0" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -189,9 +189,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" +version = "1.2.1" [[deps.SparseArrays]] deps = ["LinearAlgebra", "Random"] diff --git a/Project.toml b/Project.toml index dd092a9a0..c0452482b 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" -Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/Dagger.jl b/src/Dagger.jl index 6c19fec6c..cb95ce7a3 
100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -43,6 +43,8 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Streaming +include("stream-buffers.jl") +include("stream-fetchers.jl") include("stream.jl") # Array computations diff --git a/src/eager_thunk.jl b/src/eager_thunk.jl index cf778444a..bb0766be1 100644 --- a/src/eager_thunk.jl +++ b/src/eager_thunk.jl @@ -67,17 +67,7 @@ function Base.fetch(t::EagerThunk; raw=false) if !isdefined(t, :thunk_ref) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `EagerThunk`")) end - stream = task_to_stream(t.uid) - if stream isa Stream - add_waiters!(stream, [0]) - end - try - return fetch(t.future; raw) - finally - if stream isa Stream - remove_waiters!(stream, [0]) - end - end + return fetch(t.future; raw) end function Base.show(io::IO, t::EagerThunk) status = if isdefined(t, :thunk_ref) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 000000000..9c242dea1 --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,204 @@ +using Mmap + +""" +A buffer that drops all elements put into it. Only to be used as the output +buffer for a task - will throw if attached as an input. +""" +struct DropBuffer{T} end +DropBuffer{T}(_) where T = DropBuffer{T}() +Base.isempty(::DropBuffer) = true +isfull(::DropBuffer) = false +Base.put!(::DropBuffer, _) = nothing +Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer") + +"A process-local buffer backed by a `Channel{T}`." +struct ChannelBuffer{T} + channel::Channel{T} + len::Int + count::Threads.Atomic{Int} + ChannelBuffer{T}(len::Int=1024) where T = + new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0)) +end +Base.isempty(cb::ChannelBuffer) = isempty(cb.channel) +isfull(cb::ChannelBuffer) = cb.count[] == cb.len +function Base.put!(cb::ChannelBuffer{T}, x) where T + put!(cb.channel, convert(T, x)) + Threads.atomic_add!(cb.count, 1) +end +function Base.take!(cb::ChannelBuffer) + take!(cb.channel) + Threads.atomic_sub!(cb.count, 1) +end + +"A cross-worker buffer backed by a `RemoteChannel{T}`." +struct RemoteChannelBuffer{T} + channel::RemoteChannel{Channel{T}} + len::Int + count::Threads.Atomic{Int} + RemoteChannelBuffer{T}(len::Int=1024) where T = + new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0)) +end +Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel) +isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len +function Base.put!(cb::RemoteChannelBuffer{T}, x) where T + put!(cb.channel, convert(T, x)) + Threads.atomic_add!(cb.count, 1) +end +function Base.take!(cb::RemoteChannelBuffer) + take!(cb.channel) + Threads.atomic_sub!(cb.count, 1) +end + +"A process-local ring buffer." 
+mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +#= TODO +"A server-local ring buffer backed by shared-memory." +mutable struct ServerRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function ServerRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0 +function Base.put!(rb::ServerRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ServerRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end +=# + +#= +"A TCP-based ring buffer." +mutable struct TCPRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function TCPRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0 +function Base.put!(rb::TCPRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::TCPRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end +=# + +#= +""" +A flexible puller which switches to the most efficient buffer type based +on the sender and receiver locations. 
+""" +mutable struct UniBuffer{T} + buffer::Union{ProcessRingBuffer{T}, Nothing} +end +function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T + if buffer_amount == 0 + error("Return NullBuffer") + end + send_osproc = get_parent(send_proc) + recv_osproc = get_parent(recv_proc) + if send_osproc.pid == recv_osproc.pid + inner = RingBuffer{T}(buffer_amount) + elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid) + inner = ProcessBuffer{T}(buffer_amount) + else + inner = RemoteBuffer{T}(buffer_amount) + end + return UniBuffer{T}(buffer_amount) +end + +struct LocalPuller{T,B} + buffer::B{T} + id::UInt + function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B} + buffer = initialize_stream_buffer!(B, T, buffer_amount) + return new{T,B}(buffer, id) + end +end +function Base.take!(pull::LocalPuller{T,B}) where {T,B} + if pull.buffer === nothing + pull.buffer = + error("Return NullBuffer") + end + value = take!(pull.buffer) +end +function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B} + local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id + local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount) + ref.buffers[id] = remote_buffer + return local_buffer + end + stream.buffer = local_buffer + return stream +end +=# diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl new file mode 100644 index 000000000..f8660cdf1 --- /dev/null +++ b/src/stream-fetchers.jl @@ -0,0 +1,24 @@ +struct RemoteFetcher end +function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} + if store_ref.handle.owner == myid() + store = fetch(store_ref)::Store_remote + while !isfull(buffer) + value = take!(store, id)::T + put!(buffer, value) + end + else + tls = Dagger.get_tls() + values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote + store = MemPool.poolget(store_ref)::Store_remote + values = T[] + while !isempty(store, id) + value = take!(store, id)::T + push!(values, value) + end + return values + end::Vector{T} + for value in values + put!(buffer, value) + end + end +end diff --git a/src/stream.jl b/src/stream.jl index d90bc3ed3..67f3143d2 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -1,11 +1,12 @@ -mutable struct StreamStore{T} +mutable struct StreamStore{T,B} waiters::Vector{Int} - buffers::Dict{Int,Vector{Any}} + buffers::Dict{Int,B} + buffer_amount::Int open::Bool lock::Threads.Condition - StreamStore{T}() where T = - new{T}(zeros(Int, 0), Dict{Int,Vector{T}}(), - true, Threads.Condition()) + StreamStore{T,B}(buffer_amount::Integer) where {T,B} = + new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, + true, Threads.Condition()) end tid() = Dagger.Sch.sch_handle().thunk_id.id function uid() @@ -18,19 +19,19 @@ function uid() end end end -function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T +function Base.put!(store::StreamStore{T,B}, value) where {T,B} @lock store.lock begin - while length(store.waiters) == 0 && isopen(store) - @dagdebug nothing :stream_put "[$(uid())] no waiters, not putting" - wait(store.lock) - end if !isopen(store) @dagdebug nothing :stream_put "[$(uid())] closed!" 
throw(InvalidStateException("Stream is closed", :closed)) end @dagdebug nothing :stream_put "[$(uid())] adding $value" for buffer in values(store.buffers) - push!(buffer, value) + while isfull(buffer) + @dagdebug nothing :stream_put "[$(uid())] buffer full, waiting" + wait(store.lock) + end + put!(buffer, value) end notify(store.lock) end @@ -38,7 +39,7 @@ end function Base.take!(store::StreamStore, id::UInt) @lock store.lock begin buffer = store.buffers[id] - while length(buffer) == 0 && isopen(store, id) + while isempty(buffer) && isopen(store, id) @dagdebug nothing :stream_take "[$(uid())] no elements, not taking" wait(store.lock) end @@ -47,11 +48,19 @@ function Base.take!(store::StreamStore, id::UInt) @dagdebug nothing :stream_take "[$(uid())] closed!" throw(InvalidStateException("Stream is closed", :closed)) end - value = popfirst!(buffer) + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end @dagdebug nothing :stream_take "[$(uid())] value accepted" + notify(store.lock) return value end end +Base.isempty(store::StreamStore, id::UInt) = isempty(store.buffers[id]) +isfull(store::StreamStore, id::UInt) = isfull(store.buffers[id]) "Returns whether the store is actively open. Only check this when deciding if new values can be pushed." Base.isopen(store::StreamStore) = store.open "Returns whether the store is actively open, or if closing, still has remaining messages for `id`. Only check this when deciding if existing values can be taken." @@ -64,13 +73,16 @@ function Base.isopen(store::StreamStore, id::UInt) end end function Base.close(store::StreamStore) - store.open = false - @lock store.lock notify(store.lock) + if store.open + store.open = false + @lock store.lock notify(store.lock) + end end -function add_waiters!(store::StreamStore, waiters::Vector{Int}) +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B} @lock store.lock begin for w in waiters - store.buffers[w] = Any[] + buffer = initialize_stream_buffer(B, T, store.buffer_amount) + store.buffers[w] = buffer end append!(store.waiters, waiters) notify(store.lock) @@ -87,51 +99,52 @@ function remove_waiters!(store::StreamStore, waiters::Vector{Int}) end end -mutable struct Stream{T} <: AbstractChannel{T} - ref::Chunk - function Stream{T}() where T - store = tochunk(StreamStore{T}()) - return new{T}(store) +mutable struct Stream{T,B} + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + input_buffer::Union{B,Nothing} + buffer_amount::Int + function Stream{T,B}(buffer_amount::Integer=0) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(buffer_amount) + store_ref = tochunk(store) + return new{T,B}(store, store_ref, nothing, buffer_amount) end -end -Stream() = Stream{Any}() - -function Base.put!(stream::Stream, @nospecialize(value)) - tls = Dagger.get_tls() - remotecall_wait(stream.ref.handle.owner, stream.ref.handle, value) do ref, value - Dagger.set_tls!(tls) - @nospecialize value - store = MemPool.poolget(ref)::StreamStore - put!(store, value) + function Stream{B}(stream::Stream{T}, buffer_amount::Integer=0) where {T,B} + # References an existing output stream + return new{T,B}(nothing, stream.store_ref, nothing, buffer_amount) end end -function Base.take!(stream::Stream{T}, id::UInt) where T - tls = Dagger.get_tls() - return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref - Dagger.set_tls!(tls) - store = MemPool.poolget(ref)::StreamStore - return take!(store, id)::T - end +function 
initialize_input_stream!(stream::Stream{T,B}) where {T,B} + stream.input_buffer = initialize_stream_buffer(B, T, stream.buffer_amount) +end + +Base.put!(stream::Stream, @nospecialize(value)) = + put!(stream.store, value) +function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B} + @warn "Make remote fetcher configurable" maxlog=1 + stream_fetch_values!(RemoteFetcher, T, stream.store_ref, stream.input_buffer, id) + return take!(stream.input_buffer) end function Base.isopen(stream::Stream, id::UInt)::Bool - return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref + return remotecall_fetch(stream.store_ref.handle.owner, stream.store_ref.handle) do ref return isopen(MemPool.poolget(ref)::StreamStore, id) end end function Base.close(stream::Stream) - remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref close(MemPool.poolget(ref)::StreamStore) end end function add_waiters!(stream::Stream, waiters::Vector{Int}) - remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref add_waiters!(MemPool.poolget(ref)::StreamStore, waiters) end end add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream::Stream, Int[waiter]) function remove_waiters!(stream::Stream, waiters::Vector{Int}) - remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref remove_waiters!(MemPool.poolget(ref)::StreamStore, waiters) end end @@ -139,37 +152,17 @@ remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream::Stream, Int[waiter]) function migrate_stream!(stream::Stream, w::Integer=myid()) - # Take lock to prevent any further modifications - # N.B. Serialization automatically unlocks - remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref - lock((MemPool.poolget(ref)::StreamStore).lock) - end - # Perform migration of the StreamStore # MemPool will block access to the new ref until the migration completes - if stream.ref.handle.owner != w - MemPool.migrate!(stream.ref.handle, w) - end -end - -struct NullStream end -Base.put!(ns::NullStream, x) = nothing -Base.take!(ns::NullStream) = throw(ConcurrencyViolationError("Cannot `take!` from a `NullStream`")) + if stream.store_ref.handle.owner != w + # Take lock to prevent any further modifications + # N.B. Serialization automatically unlocks + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + lock((MemPool.poolget(ref)::StreamStore).lock) + end -mutable struct StreamWrapper{S} - stream::S - open::Bool - StreamWrapper(stream::S) where S = new{S}(stream, true) -end -Base.isopen(sw::StreamWrapper) = sw.open -Base.close(sw::StreamWrapper) = (sw.open = false;) -function Base.put!(sw::StreamWrapper, x) - isopen(sw) || throw(InvalidStateException("Stream is closed.", :closed)) - put!(sw.stream, x) -end -function Base.take!(sw::StreamWrapper) - isopen(sw) || throw(InvalidStateException("Stream is closed.", :closed)) - take!(sw.stream) + MemPool.migrate!(stream.store_ref.handle, w) + end end struct StreamingTaskQueue <: AbstractTaskQueue @@ -193,24 +186,19 @@ function initialize_streaming!(self_streams, spec, task) if !isa(spec.f, StreamingFunction) # Adapt called function for streaming and generate output Streams T_old = Base.uniontypes(task.metadata.return_type) - T_old = map(t->(t !== Union{} && t <: FinishedStreaming) ? 
only(t.parameters) : t, T_old) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old) # We treat non-dominating error paths as unreachable T_old = filter(t->t !== Union{}, T_old) T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any - if haskey(spec.options, :stream) - if spec.options.stream !== nothing - # Use the user-provided stream - @warn "Replace StreamWrapper with Stream" maxlog=1 - stream = StreamWrapper(spec.options.stream) - else - # Use a non-readable, non-writing stream - stream = StreamWrapper(NullStream()) - end - spec.options = NamedTuple(filter(opt -> opt[1] != :stream, Base.pairs(spec.options))) - else - # Create a built-in Stream object - stream = Stream{T}() + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) end + output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer) + stream = Stream{T,output_buffer}(output_buffer_amount) + spec.options = NamedTuple(filter(opt -> opt[1] != :stream_output_buffer && + opt[1] != :stream_output_buffer_amount, + Base.pairs(spec.options))) self_streams[task.uid] = stream spec.f = StreamingFunction(spec.f, stream) @@ -235,21 +223,31 @@ function spawn_streaming(f::Base.Callable) return result end -struct FinishedStreaming{T} +struct FinishStream{T,R} value::Union{Some{T},Nothing} + result::R +end +finish_stream(value::T; result::R=nothing) where {T,R} = + FinishStream{T,R}(Some{T}(value), result) +finish_stream(; result::R=nothing) where R = + FinishStream{Union{},R}(nothing, result) + +function cancel_stream!(t::EagerThunk) + stream = task_to_stream(t.uid) + if stream !== nothing + close(stream) + end end -finish_streaming(value) = FinishedStreaming{Any}(Some{T}(value)) -finish_streaming() = FinishedStreaming{Union{}}(nothing) struct StreamingFunction{F, S} f::F stream::S end +chunktype(sf::StreamingFunction{F}) where F = F function (sf::StreamingFunction)(args...; kwargs...) @nospecialize sf args kwargs result = nothing thunk_id = tid() - @warn "Fetch from worker 1 more efficiently" maxlog=1 uid = remotecall_fetch(1, thunk_id) do thunk_id lock(Sch.EAGER_ID_MAP) do id_map for (uid, otid) in id_map @@ -259,12 +257,22 @@ function (sf::StreamingFunction)(args...; kwargs...) end end end + + # Migrate our output stream to this worker if sf.stream isa Stream migrate_stream!(sf.stream) end + try + # TODO: This kwarg song-and-dance is required to ensure that we don't + # allocate boxes within `stream!`, when possible kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) kwarg_values = map(last, (kwargs...,)) + for arg in args + if arg isa Stream + initialize_input_stream!(arg) + end + end return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) finally # Remove ourself as a waiter for upstream Streams @@ -282,6 +290,7 @@ function (sf::StreamingFunction)(args...; kwargs...) 
for stream in streams @dagdebug nothing :stream_close "[$uid] dropping waiter" remove_waiters!(stream, uid) + @dagdebug nothing :stream_close "[$uid] dropped waiter" end # Ensure downstream tasks also terminate @@ -292,39 +301,34 @@ end # N.B We specialize to minimize/eliminate allocations function stream!(sf::StreamingFunction, uid, args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + f = move(thunk_processor(), sf.f) while true - #@time begin # Get values from Stream args/kwargs - stream_args = _stream_take_values!(args) - stream_kwarg_values = _stream_take_values!(kwarg_values) + stream_args = _stream_take_values!(args, uid) + stream_kwarg_values = _stream_take_values!(kwarg_values, uid) stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) # Run a single cycle of f - stream_result = sf.f(stream_args...; stream_kwargs...) + stream_result = f(stream_args...; stream_kwargs...) # Exit streaming on graceful request - if stream_result isa FinishedStreaming - @info "Terminating!" + if stream_result isa FinishStream if stream_result.value !== nothing value = something(stream_result.value) put!(sf.stream, value) - return value end - return nothing + return stream_result.result end # Put the result into the output stream put!(sf.stream, stream_result) - #end end end -function _stream_take_values!(args) +function _stream_take_values!(args, uid) return ntuple(length(args)) do idx arg = args[idx] if arg isa Stream take!(arg, uid) - elseif arg isa Union{AbstractChannel,RemoteChannel,StreamWrapper} # FIXME: Use trait query - take!(arg) else arg end @@ -336,6 +340,7 @@ end NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) return :($NT(stream_kwarg_values)) end +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) function task_to_stream(uid::UInt) @@ -354,22 +359,31 @@ function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self stream_waiter_changes = Dict{UInt,Vector{Int}}() for (spec, task) in tasks - if !haskey(self_streams, task.uid) - continue - end + @assert haskey(self_streams, task.uid) # Adapt args to accept Stream output of other streaming tasks for (idx, (pos, arg)) in enumerate(spec.args) if arg isa EagerThunk + # Check if this is a streaming task if haskey(self_streams, arg.uid) other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Get input stream configs and configure input stream + @warn "Support no input buffering" maxlog=1 + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + input_buffer = get(spec.options, :stream_input_buffer, ProcessRingBuffer) + # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher) + @warn "Accept custom input fetcher" maxlog=1 + input_stream = Stream{input_buffer}(other_stream, input_buffer_amount) + + # Replace the EagerThunk with the input Stream spec.args[idx] = pos => other_stream - changes = get!(stream_waiter_changes, arg.uid) do - Int[] - end - push!(changes, task.uid) - elseif (other_stream = task_to_stream(arg.uid)) !== nothing - spec.args[idx] = pos => other_stream + + # Add this task as a waiter for the associated output Stream changes = get!(stream_waiter_changes, arg.uid) do Int[] end @@ -377,16 +391,22 @@ function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self end end end + + # Filter out all streaming options + to_filter = (:stream_input_buffer, :stream_input_buffer_amount, + 
:stream_output_buffer, :stream_output_buffer_amount) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end end # Adjust waiter count of Streams with dependencies for (uid, waiters) in stream_waiter_changes stream = task_to_stream(uid) - if stream isa Stream # FIXME: Use trait query - add_waiters!(stream, waiters) - end + add_waiters!(stream, waiters) end end - -# TODO: Allow stopping arbitrary tasks -kill!(t::EagerThunk) = close(task_to_stream(t.uid)) From 1eea436834ef1e6f1528fab8786124ed2380f5f3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 11 Mar 2024 09:39:29 -0700 Subject: [PATCH 09/14] fixup! fixup! fixup! fixup! Add streaming API --- src/stream-buffers.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index 9c242dea1..753f8c11c 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -1,5 +1,3 @@ -using Mmap - """ A buffer that drops all elements put into it. Only to be used as the output buffer for a task - will throw if attached as an input. From 6b9a1af8448b308d52b9730eb8893acf7075d9d2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 11 Mar 2024 09:44:56 -0700 Subject: [PATCH 10/14] fixup! fixup! fixup! fixup! fixup! Add streaming API --- src/stream.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/stream.jl b/src/stream.jl index 67f3143d2..34e3abcb9 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -152,6 +152,10 @@ remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream::Stream, Int[waiter]) function migrate_stream!(stream::Stream, w::Integer=myid()) + if !isdefined(MemPool, :migrate!) + @warn "MemPool migration support not enabled!" maxlog=1 + return + end # Perform migration of the StreamStore # MemPool will block access to the new ref until the migration completes if stream.store_ref.handle.owner != w From 329e8be6a66c4fcd74a4dc17665ccf75b9dde73b Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Wed, 13 Mar 2024 22:43:44 +0100 Subject: [PATCH 11/14] Reference Dagger.EAGER_THUNK_STREAMS explicitly --- src/sch/eager.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index ee58ccd7d..40f63ad6c 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -117,7 +117,7 @@ function eager_cleanup(state, uid) delete!(state.thunk_dict, tid) end remotecall_wait(1, uid) do uid - lock(EAGER_THUNK_STREAMS) do global_streams + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams if haskey(global_streams, uid) delete!(global_streams, uid) end From 73c62dc7650a24689ffac4c7d7aa5784fef23d53 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Thu, 14 Mar 2024 10:42:33 +0100 Subject: [PATCH 12/14] Use Base.promote_op() instead of Base._return_type() return_type() is kinda broken in v1.10, see: https://github.com/JuliaLang/julia/issues/52385 In any case Base.promote_op() is the official public API for this operation so we should use it anyway. 
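
For illustration, with a hypothetical task function (`my_task` is made up,
not part of this change), both calls infer the return type from the argument
types without running the function:

    my_task(x::Int, y::Float64) = x * y

    Base._return_type(my_task, Tuple{Int,Float64})  # Float64 (internal API)
    Base.promote_op(my_task, Int, Float64)          # Float64 (public API)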
---
 src/submission.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/submission.jl b/src/submission.jl
index 776315f65..9627c77a3 100644
--- a/src/submission.jl
+++ b/src/submission.jl
@@ -186,7 +186,7 @@ end
 function EagerThunkMetadata(spec::EagerTaskSpec)
     f = chunktype(spec.f).instance
     arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args))
-    return_type = Base._return_type(f, Base.to_tuple_type(arg_types))
+    return_type = Base.promote_op(f, arg_types...)
     return EagerThunkMetadata(return_type)
 end
 chunktype(t::EagerThunk) = t.metadata.return_type

From 09aedeea0234226c967d442f5a2abb6caeecbc17 Mon Sep 17 00:00:00 2001
From: JamesWrigley
Date: Sun, 17 Mar 2024 11:55:17 +0100
Subject: [PATCH 13/14] Special-case StreamingFunction in EagerThunkMetadata() constructor

This allows us to handle all the other kinds of task specs.
---
 src/submission.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/submission.jl b/src/submission.jl
index 9627c77a3..8d5bc8473 100644
--- a/src/submission.jl
+++ b/src/submission.jl
@@ -184,7 +184,7 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple)
     end
 end
 function EagerThunkMetadata(spec::EagerTaskSpec)
-    f = chunktype(spec.f).instance
+    f = spec.f isa StreamingFunction ? spec.f.f : spec.f
     arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args))
     return_type = Base.promote_op(f, arg_types...)
     return EagerThunkMetadata(return_type)

From 366cb377eac5e3f6e09c05d0f15f24bab63591da Mon Sep 17 00:00:00 2001
From: Davide Ferretti
Date: Fri, 29 Mar 2024 10:21:35 -0700
Subject: [PATCH 14/14] Add streaming throughput monitor

Co-authored-by: Julian Samaroo
---
 src/Dagger.jl       |  1 +
 src/stream-utils.jl | 24 ++++++++++++++++++++++++
 2 files changed, 25 insertions(+)
 create mode 100644 src/stream-utils.jl

diff --git a/src/Dagger.jl b/src/Dagger.jl
index cb95ce7a3..391e4af66 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -45,6 +45,7 @@ include("sch/Sch.jl"); using .Sch
 # Streaming
 include("stream-buffers.jl")
 include("stream-fetchers.jl")
+include("stream-utils.jl")
 include("stream.jl")
 
 # Array computations

diff --git a/src/stream-utils.jl b/src/stream-utils.jl
new file mode 100644
index 000000000..a610c989d
--- /dev/null
+++ b/src/stream-utils.jl
@@ -0,0 +1,24 @@
+function throughput_monitor(ctr, x)
+    time_start = time_ns()
+
+    Dagger.spawn(ctr, time_start) do ctr, time_start
+        # Measure throughput
+        elapsed_time_ns = time_ns() - time_start
+        elapsed_time_s = elapsed_time_ns / 1e9
+        elem_size = sizeof(x)
+        throughput = (ctr[] * elem_size) / elapsed_time_s
+
+        # Print measured throughput
+        print("\e[1K\e[100D")
+        print("Throughput: $(round(throughput; digits=3)) bytes/second")
+
+        # Sleep for a bit
+        sleep(0.1)
+    end
+    function measure_throughput(ctr, x)
+        ctr[] += 1
+        return x
+    end
+
+    return Dagger.@spawn measure_throughput(ctr, x)
+end
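
For reference, the pieces introduced in this series might compose roughly as
follows. This is a minimal single-process sketch, not taken from the patches:
`consume` is a made-up consumer, a plain `Ref` counter is only meaningful
within a single process, and fetching the stream's `result` assumes that
`spawn_streaming` returns the last task spawned in its block:

    using Dagger

    # Hypothetical consumer: pass values through until 100 elements have been
    # counted, then gracefully finish the stream via finish_stream, returning
    # the final count as the task's result.
    consume(x, ctr) = ctr[] < 100 ? x : Dagger.finish_stream(x; result=ctr[])

    ctr = Ref(0)
    t = Dagger.spawn_streaming() do
        xs = Dagger.@spawn rand(Float64)          # producer: one value per cycle
        xs = Dagger.throughput_monitor(ctr, xs)   # pass-through counter/reporter
        Dagger.@spawn consume(xs, ctr)            # consumer: terminates the stream
    end
    fetch(t)  # the `result` passed to finish_stream (here, the count)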