diff --git a/Manifest.toml b/Manifest.toml
index 0cf9adc65..3b0885f3c 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -2,7 +2,7 @@
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "5333a6c200b6e6add81c46547527f66ddc0dc16c"
+project_hash = "1e12d6aa088ae431916872c11d09544380c7a130"

 [[deps.Artifacts]]
 uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -12,9 +12,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

 [[deps.ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "b66b8f8e3db5d7835fb8cbe2589ffd1cd456e491"
+git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.17.0"
+version = "1.23.0"

 [[deps.ChangesOfVariables]]
 deps = ["InverseFunctions", "LinearAlgebra", "Test"]
@@ -23,10 +23,10 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 version = "0.1.8"

 [[deps.Compat]]
-deps = ["Dates", "LinearAlgebra", "UUIDs"]
-git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c"
+deps = ["Dates", "LinearAlgebra", "TOML", "UUIDs"]
+git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "4.10.0"
+version = "4.14.0"

 [[deps.CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
@@ -34,15 +34,15 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
 version = "1.0.1+0"

 [[deps.DataAPI]]
-git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
+git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
 uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.15.0"
+version = "1.16.0"

 [[deps.DataStructures]]
 deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
-git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
+git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317"
 uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
-version = "0.18.15"
+version = "0.18.18"

 [[deps.Dates]]
 deps = ["Printf"]
@@ -91,28 +91,28 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

 [[deps.LogExpFunctions]]
 deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
-git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
+git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-version = "0.3.26"
+version = "0.3.27"

 [[deps.Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

 [[deps.MacroTools]]
 deps = ["Markdown", "Random"]
-git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
+git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
 uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.11"
+version = "0.5.13"

 [[deps.Markdown]]
 deps = ["Base64"]
 uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

 [[deps.MemPool]]
-deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"]
-git-tree-sha1 = "b9c1a032c3c1310a857c061ce487c632eaa1faa4"
+deps = ["DataStructures", "Distributed", "Mmap", "Random", "ScopedValues", "Serialization", "Sockets"]
+git-tree-sha1 = "60dd4ac427d39e0b3f15b193845324523ee71c03"
 uuid = "f9f48841-c794-520a-933b-121f7ba6ed94"
-version = "0.4.4"
+version = "0.4.6"

 [[deps.Missings]]
 deps = ["DataAPI"]
@@ -133,9 +133,9 @@ uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
 version = "0.3.20+0"

 [[deps.OrderedCollections]]
-git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
+git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.6.2"
+version = "1.6.3"

 [[deps.PrecompileTools]]
 deps = ["Preferences"]
@@ -145,9 +145,9 @@ version = "1.2.0"

 [[deps.Preferences]]
 deps = ["TOML"]
-git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
+git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
 uuid = "21216c6a-2e73-6563-6e65-726566657250"
-version = "1.4.1"
+version = "1.4.3"

 [[deps.Printf]]
 deps = ["Unicode"]
@@ -173,9 +173,9 @@ version = "0.7.0"

 [[deps.ScopedValues]]
 deps = ["HashArrayMappedTries", "Logging"]
-git-tree-sha1 = "e3b5e4ccb1702db2ae9ac2a660d4b6b2a8595742"
+git-tree-sha1 = "c27d546a4749c81f70d1fabd604da6aa5054e3d2"
 uuid = "7e506255-f358-4e82-b7e4-beb19740aa63"
-version = "1.1.0"
+version = "1.2.0"

 [[deps.Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -189,9 +189,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

 [[deps.SortingAlgorithms]]
 deps = ["DataStructures"]
-git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
+git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
 uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "1.1.1"
+version = "1.2.1"

 [[deps.SparseArrays]]
 deps = ["LinearAlgebra", "Random"]
diff --git a/docs/make.jl b/docs/make.jl
index 679c3b9e9..d6a86bcaa 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -19,6 +19,7 @@ makedocs(;
         "Task Spawning" => "task-spawning.md",
         "Data Management" => "data-management.md",
         "Distributed Arrays" => "darray.md",
+        "Streaming Tasks" => "streaming.md",
         "Scopes" => "scopes.md",
         "Processors" => "processors.md",
         "Task Queues" => "task-queues.md",
diff --git a/docs/src/streaming.md b/docs/src/streaming.md
new file mode 100644
index 000000000..0a13a1472
--- /dev/null
+++ b/docs/src/streaming.md
@@ -0,0 +1,105 @@
+# Streaming Tasks
+
+Dagger tasks have a limited lifetime - they are created, execute, finish, and
+are eventually destroyed when they're no longer needed. Thus, if one wants to
+run the same kind of computations over and over, one might re-create a similar
+set of tasks for each unit of data that needs processing.
+
+This might be fine for computations which take a long time to run (thus
+dwarfing the cost of task creation, which is quite small), or when working
+with a limited set of data, but this approach is not great for doing lots of
+small computations on a large (or endless) amount of data. For example:
+processing image frames from a webcam, reacting to messages from a message
+bus, or reading samples from a software radio. All of these workloads are
+better suited to a "streaming" model of data processing, where data is simply
+piped into a continuously-running task (or DAG of tasks) forever, or until the
+data runs out.
+
+Thankfully, if you have a problem which is best modeled as a streaming system
+of tasks, Dagger has you covered! Building on its support for
+["Task Queues"](@ref), Dagger provides a means to convert an entire DAG of
+tasks into a streaming DAG, where data flows into and out of each task
+asynchronously, using the `spawn_streaming` function:
+
+```julia
+Dagger.spawn_streaming() do # enters a streaming region
+    vals = Dagger.@spawn rand()
+    print_vals = Dagger.@spawn println(vals)
+end # exits the streaming region, and starts the DAG running
+```
+
+In the above example, `vals` is a Dagger task which has been transformed to
+run in a streaming manner - instead of calling `rand()` once and returning its
+result, it will re-run `rand()` endlessly, continuously producing new random
+values. In typical Dagger style, `print_vals` is a Dagger task which depends
+on `vals`, but in streaming form - it will continuously `println` the random
+values produced by `vals`. Both tasks will run forever, and will run
+efficiently, only doing the work necessary to generate, transfer, and consume
+values.
+
+As the comments point out, `spawn_streaming` creates a streaming region,
+during which `vals` and `print_vals` are created and configured. Neither task
+starts running until `spawn_streaming` returns, allowing large DAGs to be
+built all at once, without any task losing a single value. If desired,
+streaming regions can be connected, although some values might be lost while
+the tasks are being connected:
+
+```julia
+vals = Dagger.spawn_streaming() do
+    Dagger.@spawn rand()
+end
+
+# Some values might be generated by `vals` but thrown away
+# before `print_vals` is fully set up and connected to it
+
+print_vals = Dagger.spawn_streaming() do
+    Dagger.@spawn println(vals)
+end
+```
+
+More complicated streaming DAGs can be easily constructed, without doing
+anything differently. For example, we can generate multiple streams of random
+numbers, write them all to their own files, and print the combined results:
+
+```julia
+Dagger.spawn_streaming() do
+    all_vals = [Dagger.spawn(rand) for i in 1:4]
+    all_vals_written = map(1:4) do i
+        Dagger.spawn(all_vals[i]) do val
+            open("results_$i.txt"; write=true, create=true, append=true) do io
+                println(io, repr(val))
+            end
+            return val
+        end
+    end
+    Dagger.spawn(all_vals_written...) do all_vals_written...
+        vals_sum = sum(all_vals_written)
+        println(vals_sum)
+    end
+end
+```
+
+If you want to stop the streaming DAG and tear it all down, you can call
+`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc.; the
+kill will propagate throughout the DAG).
+
+Alternatively, tasks can stop themselves from the inside with `finish_stream`,
+optionally providing a final result that can be `fetch`'d. Let's do this when
+our randomly-drawn number falls within some arbitrary range:
+
+```julia
+vals = Dagger.spawn_streaming() do
+    Dagger.spawn() do
+        x = rand()
+        if x < 0.001
+            # That's good enough, let's be done
+            return Dagger.finish_stream(; result="Finished!")
+        end
+        return x
+    end
+end
+fetch(vals)
+```
+
+In this example, the call to `fetch` will block (while random numbers continue
+to be drawn), until a drawn number is less than 0.001; at that point, `fetch`
+will return with `"Finished!"`, and the task `vals` will have terminated.
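+
+The streaming implementation also reads a few task options to control how
+values are buffered between tasks: `stream_output_buffer` and
+`stream_output_buffer_amount` configure a task's output stream, while
+`stream_input_buffer` and `stream_input_buffer_amount` configure how a task
+consumes its inputs. Both amounts currently default to 1, with
+`ProcessRingBuffer` as the default buffer type. As a sketch of how these
+options can be passed with `Dagger.@spawn`'s option syntax (this interface is
+still experimental and may change):
+
+```julia
+Dagger.spawn_streaming() do
+    # Allow up to 128 values to accumulate between producer and consumer
+    vals = Dagger.@spawn stream_output_buffer_amount=128 rand()
+    print_vals = Dagger.@spawn stream_input_buffer_amount=128 println(vals)
+end
+```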
diff --git a/src/Dagger.jl b/src/Dagger.jl
index be6ee3075..391e4af66 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -42,6 +42,12 @@ include("utils/system_uuid.jl")
 include("utils/caching.jl")
 include("sch/Sch.jl"); using .Sch

+# Streaming
+include("stream-buffers.jl")
+include("stream-fetchers.jl")
+include("stream-utils.jl")
+include("stream.jl")
+
 # Array computations
 include("array/darray.jl")
 include("array/alloc.jl")
diff --git a/src/eager_thunk.jl b/src/eager_thunk.jl
index 8a021737e..bb0766be1 100644
--- a/src/eager_thunk.jl
+++ b/src/eager_thunk.jl
@@ -29,6 +29,16 @@ end
 Options(;options...) = Options((;options...))
 Options(options...) = Options((;options...))

+"""
+    EagerThunkMetadata
+
+Represents metadata pertaining to an `EagerThunk`:
+- `return_type::Type` - The inferred return type of the task
+"""
+mutable struct EagerThunkMetadata
+    return_type::Type
+end
+
 """
     EagerThunk
@@ -39,9 +49,11 @@ be `fetch`'d or `wait`'d on at any time.
 mutable struct EagerThunk
     uid::UInt
     future::ThunkFuture
+    metadata::EagerThunkMetadata
     finalizer_ref::DRef
     thunk_ref::DRef
-    EagerThunk(uid, future, finalizer_ref) = new(uid, future, finalizer_ref)
+    EagerThunk(uid, future, metadata, finalizer_ref) =
+        new(uid, future, metadata, finalizer_ref)
 end

 Base.isready(t::EagerThunk) = isready(t.future)
diff --git a/src/sch/eager.jl b/src/sch/eager.jl
index 3d62ed8bf..40f63ad6c 100644
--- a/src/sch/eager.jl
+++ b/src/sch/eager.jl
@@ -116,6 +116,13 @@ function eager_cleanup(state, uid)
         # N.B. cache and errored expire automatically
         delete!(state.thunk_dict, tid)
     end
+    remotecall_wait(1, uid) do uid
+        lock(Dagger.EAGER_THUNK_STREAMS) do global_streams
+            if haskey(global_streams, uid)
+                delete!(global_streams, uid)
+            end
+        end
+    end
 end

 function _find_thunk(e::Dagger.EagerThunk)
diff --git a/src/sch/util.jl b/src/sch/util.jl
index c9b72f2c9..71e62a531 100644
--- a/src/sch/util.jl
+++ b/src/sch/util.jl
@@ -362,12 +362,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig)
     else
         get(state.signature_alloc_cost, sig, UInt64(0))
     end::UInt64
-    est_occupancy = if occupancy !== nothing && haskey(occupancy, T)
-        # Clamp to 0-1, and scale between 0 and `typemax(UInt32)`
-        Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32))
-    else
-        typemax(UInt32)
-    end::UInt32
+    est_occupancy::UInt32 = typemax(UInt32)
+    if occupancy !== nothing
+        occ = nothing
+        if haskey(occupancy, T)
+            occ = occupancy[T]
+        elseif haskey(occupancy, Any)
+            occ = occupancy[Any]
+        end
+        if occ !== nothing
+            # Clamp to 0-1, and scale between 0 and `typemax(UInt32)`
+            est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32))
+        end
+    end
     #= FIXME: Estimate if cached data can be swapped to storage
     storage = storage_resource(p)
     real_alloc_util = state.worker_storage_pressure[gp][storage]
diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl
new file mode 100644
index 000000000..753f8c11c
--- /dev/null
+++ b/src/stream-buffers.jl
@@ -0,0 +1,202 @@
+"""
+A buffer that drops all elements put into it. Only to be used as the output
+buffer for a task - `take!` will throw if it is attached as an input.
+"""
+struct DropBuffer{T} end
+DropBuffer{T}(_) where T = DropBuffer{T}()
+Base.isempty(::DropBuffer) = true
+isfull(::DropBuffer) = false
+Base.put!(::DropBuffer, _) = nothing
+Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer")
+
+"A process-local buffer backed by a `Channel{T}`."
+struct ChannelBuffer{T}
+    channel::Channel{T}
+    len::Int
+    count::Threads.Atomic{Int}
+    ChannelBuffer{T}(len::Int=1024) where T =
+        new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0))
+end
+Base.isempty(cb::ChannelBuffer) = isempty(cb.channel)
+isfull(cb::ChannelBuffer) = cb.count[] == cb.len
+function Base.put!(cb::ChannelBuffer{T}, x) where T
+    put!(cb.channel, convert(T, x))
+    Threads.atomic_add!(cb.count, 1)
+end
+function Base.take!(cb::ChannelBuffer)
+    # N.B. Capture the element before decrementing the counter, so that the
+    # taken value (not the old counter value) is what gets returned
+    value = take!(cb.channel)
+    Threads.atomic_sub!(cb.count, 1)
+    return value
+end
+
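+# Illustrative sketch (not part of the API surface): every buffer type in
+# this file implements the same informal interface - `isempty`, `isfull`,
+# `put!`, and `take!` - which is all that the streaming machinery relies on:
+#   cb = ChannelBuffer{Int}(2)
+#   put!(cb, 1); put!(cb, 2)
+#   @assert isfull(cb)
+#   @assert take!(cb) == 1 && !isfull(cb)
+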
+"A cross-worker buffer backed by a `RemoteChannel{T}`."
+struct RemoteChannelBuffer{T}
+    channel::RemoteChannel{Channel{T}}
+    len::Int
+    count::Threads.Atomic{Int}
+    RemoteChannelBuffer{T}(len::Int=1024) where T =
+        new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0))
+end
+Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel)
+isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len
+function Base.put!(cb::RemoteChannelBuffer{T}, x) where T
+    put!(cb.channel, convert(T, x))
+    Threads.atomic_add!(cb.count, 1)
+end
+function Base.take!(cb::RemoteChannelBuffer)
+    # N.B. As with `ChannelBuffer`, return the taken value, not the counter
+    value = take!(cb.channel)
+    Threads.atomic_sub!(cb.count, 1)
+    return value
+end
+
+"A process-local ring buffer."
+mutable struct ProcessRingBuffer{T}
+    read_idx::Int
+    write_idx::Int
+    @atomic count::Int
+    buffer::Vector{T}
+    function ProcessRingBuffer{T}(len::Int=1024) where T
+        buffer = Vector{T}(undef, len)
+        return new{T}(1, 1, 0, buffer)
+    end
+end
+Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0
+isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer)
+function Base.put!(rb::ProcessRingBuffer{T}, x) where T
+    len = length(rb.buffer)
+    while (@atomic rb.count) == len
+        yield()
+    end
+    to_write_idx = mod1(rb.write_idx, len)
+    rb.buffer[to_write_idx] = convert(T, x)
+    rb.write_idx += 1
+    @atomic rb.count += 1
+end
+function Base.take!(rb::ProcessRingBuffer)
+    while (@atomic rb.count) == 0
+        yield()
+    end
+    to_read_idx = rb.read_idx
+    rb.read_idx += 1
+    @atomic rb.count -= 1
+    to_read_idx = mod1(to_read_idx, length(rb.buffer))
+    return rb.buffer[to_read_idx]
+end
+
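+# Illustrative usage of `ProcessRingBuffer` (hypothetical): `put!` spins with
+# `yield()` while the buffer is full, and `take!` spins while it is empty, so
+# a producer task and a consumer task on the same process can exchange values
+# without locks:
+#   rb = ProcessRingBuffer{Float64}(16)
+#   @async for _ in 1:100; put!(rb, rand()); end
+#   vals = [take!(rb) for _ in 1:100]
+# Note that only `count` is accessed atomically; the design assumes a single
+# producer and a single consumer.
+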
+""" +mutable struct UniBuffer{T} + buffer::Union{ProcessRingBuffer{T}, Nothing} +end +function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T + if buffer_amount == 0 + error("Return NullBuffer") + end + send_osproc = get_parent(send_proc) + recv_osproc = get_parent(recv_proc) + if send_osproc.pid == recv_osproc.pid + inner = RingBuffer{T}(buffer_amount) + elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid) + inner = ProcessBuffer{T}(buffer_amount) + else + inner = RemoteBuffer{T}(buffer_amount) + end + return UniBuffer{T}(buffer_amount) +end + +struct LocalPuller{T,B} + buffer::B{T} + id::UInt + function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B} + buffer = initialize_stream_buffer!(B, T, buffer_amount) + return new{T,B}(buffer, id) + end +end +function Base.take!(pull::LocalPuller{T,B}) where {T,B} + if pull.buffer === nothing + pull.buffer = + error("Return NullBuffer") + end + value = take!(pull.buffer) +end +function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B} + local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id + local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount) + ref.buffers[id] = remote_buffer + return local_buffer + end + stream.buffer = local_buffer + return stream +end +=# diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl new file mode 100644 index 000000000..f8660cdf1 --- /dev/null +++ b/src/stream-fetchers.jl @@ -0,0 +1,24 @@ +struct RemoteFetcher end +function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} + if store_ref.handle.owner == myid() + store = fetch(store_ref)::Store_remote + while !isfull(buffer) + value = take!(store, id)::T + put!(buffer, value) + end + else + tls = Dagger.get_tls() + values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote + store = MemPool.poolget(store_ref)::Store_remote + values = T[] + while !isempty(store, id) + value = take!(store, id)::T + push!(values, value) + end + return values + end::Vector{T} + for value in values + put!(buffer, value) + end + end +end diff --git a/src/stream-utils.jl b/src/stream-utils.jl new file mode 100644 index 000000000..a610c989d --- /dev/null +++ b/src/stream-utils.jl @@ -0,0 +1,24 @@ +function throughput_monitor(ctr, x) + time_start = time_ns() + + Dagger.spawn(ctr, time_start) do ctr, time_start + # Measure throughput + elapsed_time_ns = time_ns() - time_start + elapsed_time_s = elapsed_time_ns / 1e9 + elem_size = sizeof(x) + throughput = (ctr[] * elem_size) / elapsed_time_s + + # Print measured throughput + print("\e[1K\e[100D") + print("Throughput: $(round(throughput; digits=3)) bytes/second") + + # Sleep for a bit + sleep(0.1) + end + function measure_throughput(ctr, x) + ctr[] += 1 + return x + end + + return Dagger.@spawn measure_throughput(ctr, x) +end diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..34e3abcb9 --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,416 @@ +mutable struct StreamStore{T,B} + waiters::Vector{Int} + buffers::Dict{Int,B} + buffer_amount::Int + open::Bool + lock::Threads.Condition + StreamStore{T,B}(buffer_amount::Integer) where {T,B} = + new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, + true, Threads.Condition()) +end 
+tid() = Dagger.Sch.sch_handle().thunk_id.id
+"Returns the UID of the currently-executing task, via reverse lookup in `Sch.EAGER_ID_MAP`."
+function uid()
+    thunk_id = tid()
+    lock(Sch.EAGER_ID_MAP) do id_map
+        for (uid, otid) in id_map
+            if thunk_id == otid
+                return uid
+            end
+        end
+    end
+end
+function Base.put!(store::StreamStore{T,B}, value) where {T,B}
+    @lock store.lock begin
+        if !isopen(store)
+            @dagdebug nothing :stream_put "[$(uid())] closed!"
+            throw(InvalidStateException("Stream is closed", :closed))
+        end
+        @dagdebug nothing :stream_put "[$(uid())] adding $value"
+        # Fan the value out to every waiter's buffer
+        for buffer in values(store.buffers)
+            while isfull(buffer)
+                @dagdebug nothing :stream_put "[$(uid())] buffer full, waiting"
+                wait(store.lock)
+            end
+            put!(buffer, value)
+        end
+        notify(store.lock)
+    end
+end
+function Base.take!(store::StreamStore, id::UInt)
+    @lock store.lock begin
+        buffer = store.buffers[id]
+        while isempty(buffer) && isopen(store, id)
+            @dagdebug nothing :stream_take "[$(uid())] no elements, not taking"
+            wait(store.lock)
+        end
+        @dagdebug nothing :stream_take "[$(uid())] wait finished"
+        if !isopen(store, id)
+            @dagdebug nothing :stream_take "[$(uid())] closed!"
+            throw(InvalidStateException("Stream is closed", :closed))
+        end
+        # Temporarily release the lock while taking, since `take!` on the
+        # buffer may block, and producers need the lock to make progress
+        unlock(store.lock)
+        value = try
+            take!(buffer)
+        finally
+            lock(store.lock)
+        end
+        @dagdebug nothing :stream_take "[$(uid())] value accepted"
+        notify(store.lock)
+        return value
+    end
+end
+Base.isempty(store::StreamStore, id::UInt) = isempty(store.buffers[id])
+isfull(store::StreamStore, id::UInt) = isfull(store.buffers[id])
+"Returns whether the store is actively open. Only check this when deciding if new values can be pushed."
+Base.isopen(store::StreamStore) = store.open
+"Returns whether the store is actively open or, if closing, whether it still has remaining messages for `id`. Only check this when deciding if existing values can be taken."
+function Base.isopen(store::StreamStore, id::UInt)
+    @lock store.lock begin
+        if !isempty(store.buffers[id])
+            return true
+        end
+        return store.open
+    end
+end
+function Base.close(store::StreamStore)
+    if store.open
+        store.open = false
+        @lock store.lock notify(store.lock)
+    end
+end
+function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B}
+    @lock store.lock begin
+        for w in waiters
+            buffer = initialize_stream_buffer(B, T, store.buffer_amount)
+            store.buffers[w] = buffer
+        end
+        append!(store.waiters, waiters)
+        notify(store.lock)
+    end
+end
+function remove_waiters!(store::StreamStore, waiters::Vector{Int})
+    @lock store.lock begin
+        for w in waiters
+            delete!(store.buffers, w)
+            idx = findfirst(wo->wo==w, store.waiters)
+            deleteat!(store.waiters, idx)
+        end
+        notify(store.lock)
+    end
+end
+
+mutable struct Stream{T,B}
+    store::Union{StreamStore{T,B},Nothing}
+    store_ref::Chunk
+    input_buffer::Union{B,Nothing}
+    buffer_amount::Int
+    function Stream{T,B}(buffer_amount::Integer=0) where {T,B}
+        # Creates a new output stream
+        store = StreamStore{T,B}(buffer_amount)
+        store_ref = tochunk(store)
+        return new{T,B}(store, store_ref, nothing, buffer_amount)
+    end
+    function Stream{B}(stream::Stream{T}, buffer_amount::Integer=0) where {T,B}
+        # References an existing output stream
+        return new{T,B}(nothing, stream.store_ref, nothing, buffer_amount)
+    end
+end
+function initialize_input_stream!(stream::Stream{T,B}) where {T,B}
+    stream.input_buffer = initialize_stream_buffer(B, T, stream.buffer_amount)
+end
+
+Base.put!(stream::Stream, @nospecialize(value)) =
+    put!(stream.store, value)
+function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B}
+    @warn "Make remote fetcher configurable" maxlog=1
+    stream_fetch_values!(RemoteFetcher, T, stream.store_ref, stream.input_buffer, id)
+    return take!(stream.input_buffer)
+end
+function Base.isopen(stream::Stream, id::UInt)::Bool
+    return remotecall_fetch(stream.store_ref.handle.owner, stream.store_ref.handle) do ref
+        return isopen(MemPool.poolget(ref)::StreamStore, id)
+    end
+end
+function Base.close(stream::Stream)
+    remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref
+        close(MemPool.poolget(ref)::StreamStore)
+    end
+end
+function add_waiters!(stream::Stream, waiters::Vector{Int})
+    remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref
+        add_waiters!(MemPool.poolget(ref)::StreamStore, waiters)
+    end
+end
+add_waiters!(stream::Stream, waiter::Integer) =
+    add_waiters!(stream::Stream, Int[waiter])
+function remove_waiters!(stream::Stream, waiters::Vector{Int})
+    remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref
+        remove_waiters!(MemPool.poolget(ref)::StreamStore, waiters)
+    end
+end
+remove_waiters!(stream::Stream, waiter::Integer) =
+    remove_waiters!(stream::Stream, Int[waiter])
+
+function migrate_stream!(stream::Stream, w::Integer=myid())
+    if !isdefined(MemPool, :migrate!)
+        @warn "MemPool migration support not enabled!" maxlog=1
+        return
+    end
+    # Perform migration of the StreamStore
+    # MemPool will block access to the new ref until the migration completes
+    if stream.store_ref.handle.owner != w
+        # Take lock to prevent any further modifications
+        # N.B. Serialization automatically unlocks the migrated copy
+        remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref
+            lock((MemPool.poolget(ref)::StreamStore).lock)
+        end
+
+        MemPool.migrate!(stream.store_ref.handle, w)
+    end
+end
+
+struct StreamingTaskQueue <: AbstractTaskQueue
+    tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}
+    self_streams::Dict{UInt,Any}
+    StreamingTaskQueue() = new(Pair{EagerTaskSpec,EagerThunk}[],
+                               Dict{UInt,Any}())
+end
+
+function enqueue!(queue::StreamingTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk})
+    push!(queue.tasks, spec)
+    initialize_streaming!(queue.self_streams, spec...)
+end
+function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
+    append!(queue.tasks, specs)
+    for (spec, task) in specs
+        initialize_streaming!(queue.self_streams, spec, task)
+    end
+end
+function initialize_streaming!(self_streams, spec, task)
+    if !isa(spec.f, StreamingFunction)
+        # Adapt called function for streaming and generate output Streams
+        T_old = Base.uniontypes(task.metadata.return_type)
+        T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old)
+        # We treat non-dominating error paths as unreachable
+        T_old = filter(t->t !== Union{}, T_old)
+        T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any
+        output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1)
+        if output_buffer_amount <= 0
+            throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0"))
+        end
+        output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer)
+        stream = Stream{T,output_buffer}(output_buffer_amount)
+        spec.options = NamedTuple(filter(opt -> opt[1] != :stream_output_buffer &&
+                                                opt[1] != :stream_output_buffer_amount,
+                                         Base.pairs(spec.options)))
+        self_streams[task.uid] = stream
+
+        spec.f = StreamingFunction(spec.f, stream)
+        spec.options = merge(spec.options, (;occupancy=Dict(Any=>0)))
+
+        # Register Stream globally
+        remotecall_wait(1, task.uid, stream) do uid, stream
+            lock(EAGER_THUNK_STREAMS) do global_streams
+                global_streams[uid] = stream
+            end
+        end
+    end
+end
+
+function spawn_streaming(f::Base.Callable)
+    queue = StreamingTaskQueue()
+    result = with_options(f; task_queue=queue)
+    if length(queue.tasks) > 0
+        finalize_streaming!(queue.tasks, queue.self_streams)
+        enqueue!(queue.tasks)
+    end
+    return result
+end
+
+struct FinishStream{T,R}
+    value::Union{Some{T},Nothing}
+    result::R
+end
+finish_stream(value::T; result::R=nothing) where {T,R} =
+    FinishStream{T,R}(Some{T}(value), result)
+finish_stream(; result::R=nothing) where R =
+    FinishStream{Union{},R}(nothing, result)
+
+function cancel_stream!(t::EagerThunk)
+    stream = task_to_stream(t.uid)
+    if stream !== nothing
+        close(stream)
+    end
+end
+
+struct StreamingFunction{F, S}
+    f::F
+    stream::S
+end
+chunktype(sf::StreamingFunction{F}) where F = F
+function (sf::StreamingFunction)(args...; kwargs...)
+    @nospecialize sf args kwargs
+    thunk_id = tid()
+    uid = remotecall_fetch(1, thunk_id) do thunk_id
+        lock(Sch.EAGER_ID_MAP) do id_map
+            for (uid, otid) in id_map
+                if thunk_id == otid
+                    return uid
+                end
+            end
+        end
+    end
+
+    # Migrate our output stream to this worker
+    if sf.stream isa Stream
+        migrate_stream!(sf.stream)
+    end
+
+    try
+        # TODO: This kwarg song-and-dance is required to ensure that we don't
+        # allocate boxes within `stream!`, when possible
+        kwarg_names = map(name->Val{name}(), map(first, (kwargs...,)))
+        kwarg_values = map(last, (kwargs...,))
+        for arg in args
+            if arg isa Stream
+                initialize_input_stream!(arg)
+            end
+        end
+        return stream!(sf, uid, (args...,), kwarg_names, kwarg_values)
+    finally
+        # Remove ourself as a waiter for upstream Streams
+        streams = Set{Stream}()
+        for (idx, arg) in enumerate(args)
+            if arg isa Stream
+                push!(streams, arg)
+            end
+        end
+        for (idx, (pos, arg)) in enumerate(kwargs)
+            if arg isa Stream
+                push!(streams, arg)
+            end
+        end
+        for stream in streams
+            @dagdebug nothing :stream_close "[$uid] dropping waiter"
+            remove_waiters!(stream, uid)
+            @dagdebug nothing :stream_close "[$uid] dropped waiter"
+        end
+
+        # Ensure downstream tasks also terminate
+        @dagdebug nothing :stream_close "[$uid] closed stream"
+        close(sf.stream)
+    end
+end
+# N.B. We specialize to minimize/eliminate allocations
+function stream!(sf::StreamingFunction, uid,
+                 args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple)
+    f = move(thunk_processor(), sf.f)
+    while true
+        # Get values from Stream args/kwargs
+        stream_args = _stream_take_values!(args, uid)
+        stream_kwarg_values = _stream_take_values!(kwarg_values, uid)
+        stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values)
+
+        # Run a single cycle of f
+        stream_result = f(stream_args...; stream_kwargs...)
+
+        # Exit streaming on graceful request
+        if stream_result isa FinishStream
+            if stream_result.value !== nothing
+                value = something(stream_result.value)
+                put!(sf.stream, value)
+            end
+            return stream_result.result
+        end
+
+        # Put the result into the output stream
+        put!(sf.stream, stream_result)
+    end
+end
+function _stream_take_values!(args, uid)
+    return ntuple(length(args)) do idx
+        arg = args[idx]
+        if arg isa Stream
+            take!(arg, uid)
+        else
+            arg
+        end
+    end
+end
+@inline @generated function _stream_namedtuple(kwarg_names::Tuple,
+                                               stream_kwarg_values::Tuple)
+    name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...)
+    NT = :(NamedTuple{$name_ex,$stream_kwarg_values})
+    return :($NT(stream_kwarg_values))
+end
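+# For illustration (hypothetical values, not part of the API): with the
+# definitions above,
+#   _stream_namedtuple((Val{:a}(), Val{:b}()), (1, 2.0))
+# is generated as `NamedTuple{(:a, :b),Tuple{Int64,Float64}}((1, 2.0))`, i.e.
+# it rebuilds `(a = 1, b = 2.0)` without allocating intermediate pairs.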
+initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount)
+
+const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}())
+function task_to_stream(uid::UInt)
+    if myid() != 1
+        return remotecall_fetch(task_to_stream, 1, uid)
+    end
+    lock(EAGER_THUNK_STREAMS) do global_streams
+        if haskey(global_streams, uid)
+            return global_streams[uid]
+        end
+        return
+    end
+end
+
+function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self_streams)
+    stream_waiter_changes = Dict{UInt,Vector{Int}}()
+
+    for (spec, task) in tasks
+        @assert haskey(self_streams, task.uid)
+
+        # Adapt args to accept Stream output of other streaming tasks
+        for (idx, (pos, arg)) in enumerate(spec.args)
+            if arg isa EagerThunk
+                # Check if this is a streaming task
+                if haskey(self_streams, arg.uid)
+                    other_stream = self_streams[arg.uid]
+                else
+                    other_stream = task_to_stream(arg.uid)
+                end
+
+                if other_stream !== nothing
+                    # Get input stream configs and configure input stream
+                    @warn "Support no input buffering" maxlog=1
+                    input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1)
+                    input_buffer = get(spec.options, :stream_input_buffer, ProcessRingBuffer)
+                    # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher)
+                    @warn "Accept custom input fetcher" maxlog=1
+                    input_stream = Stream{input_buffer}(other_stream, input_buffer_amount)
+
+                    # Replace the EagerThunk with the input Stream
+                    spec.args[idx] = pos => input_stream
+
+                    # Add this task as a waiter for the associated output Stream
+                    changes = get!(stream_waiter_changes, arg.uid) do
+                        Int[]
+                    end
+                    push!(changes, task.uid)
+                end
+            end
+        end
+
+        # Filter out all streaming options
+        to_filter = (:stream_input_buffer, :stream_input_buffer_amount,
+                     :stream_output_buffer, :stream_output_buffer_amount)
+        spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter),
+                                         Base.pairs(spec.options)))
+        if haskey(spec.options, :propagates)
+            propagates = filter(opt -> !(opt in to_filter),
+                                spec.options.propagates)
+            spec.options = merge(spec.options, (;propagates))
+        end
+    end
+
+    # Adjust waiter count of Streams with dependencies
+    for (uid, waiters) in stream_waiter_changes
+        stream = task_to_stream(uid)
+        add_waiters!(stream, waiters)
+    end
+end
diff --git a/src/submission.jl b/src/submission.jl
index 3898c922a..8d5bc8473 100644
--- a/src/submission.jl
+++ b/src/submission.jl
@@ -183,14 +183,22 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple)
         return options
     end
 end
+function EagerThunkMetadata(spec::EagerTaskSpec)
+    f = spec.f isa StreamingFunction ? spec.f.f : spec.f
+    arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args))
+    return_type = Base.promote_op(f, arg_types...)
+    return EagerThunkMetadata(return_type)
+end
+chunktype(t::EagerThunk) = t.metadata.return_type
 function eager_spawn(spec::EagerTaskSpec)
     # Generate new EagerThunk
     uid = eager_next_id()
     future = ThunkFuture()
+    metadata = EagerThunkMetadata(spec)
     finalizer_ref = poolset(EagerThunkFinalizer(uid);
                             device=MemPool.CPURAMDevice())

     # Return unlaunched EagerThunk
-    return EagerThunk(uid, future, finalizer_ref)
+    return EagerThunk(uid, future, metadata, finalizer_ref)
 end
 function eager_launch!((spec, task)::Pair{EagerTaskSpec,EagerThunk})
     # Lookup EagerThunk -> ThunkID