From d9e4b3b44bc9bbf4b602deb1f87afbfca5e6a536 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 1 Sep 2022 15:29:37 -0500 Subject: [PATCH 1/3] Add unified hashing system Implements a "semantic" hashing algorithm which hashes Thunks based on the functional behavior of the code being executed. The intention is to have a hash which has an identical value across different Julia sessions for tasks which compute the same value. This is important for implementing a "headless" worker-worker cluster, where there is no coordinating head worker, and all workers can see the entire computational program. Hashes are computed automatically and can be queried with `get_task_hash()` while running in a task context, or directly as `get_task_hash(task)` for any Dagger task type. Hashes are also provided within `Dagger.move` calls, where the input task's hash is also available. --- src/Dagger.jl | 1 + src/chunks.jl | 7 +++--- src/processor.jl | 41 +++++++++++++++++++++++++------ src/sch/Sch.jl | 37 +++++++++++++++++++--------- src/sch/util.jl | 2 +- src/thunk.jl | 22 ++++++++++++++--- src/utils/uhash.jl | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 src/utils/uhash.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 63b0834b1..a9b947921 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -36,6 +36,7 @@ include("chunks.jl") include("compute.jl") include("utils/clock.jl") include("utils/system_uuid.jl") +include("utils/uhash.jl") include("sch/Sch.jl"); using .Sch # Array computations diff --git a/src/chunks.jl b/src/chunks.jl index 9745cb598..5c3d1533f 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -53,6 +53,7 @@ mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} processor::P scope::S persist::Bool + hash::UInt end domain(c::Chunk) = c.domain @@ -242,7 +243,7 @@ be used. All other kwargs are passed directly to `MemPool.poolset`. """ -function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S} +function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, hash=UInt(0), kwargs...) where {X,P,S} if device === nothing device = if Sch.walk_storage_safe(x) MemPool.GLOBAL_DEVICE[] @@ -250,8 +251,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac MemPool.CPURAMDevice() end end - ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist) + ref = poolset(move(OSProc(), proc, x); device, kwargs...) + Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist, hash) end tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x diff --git a/src/processor.jl b/src/processor.jl index 91091a190..e99399dc0 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -288,19 +288,40 @@ end # In-Thunk Helpers """ - thunk_processor() + thunk_processor() -> Dagger.Processor Get the current processor executing the current thunk. """ thunk_processor() = task_local_storage(:_dagger_processor)::Processor """ - in_thunk() + in_thunk() -> Bool Returns `true` if currently in a [`Thunk`](@ref) process, else `false`. """ in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid) +""" + get_task_hash(kind::Symbol=:self) -> UInt + +Returns the unified hash of the current task or of an input to the current +task. 
If `kind == :self`, then the hash is for the current task; if `kind == +:input`, then the hash is for the current input to the task that is being +processed. The `:self` hash is available during `Dagger.execute!` and +`Dagger.move`, whereas the `:input` hash is only available during +`Dagger.move`. This hash is consistent across Julia processes (if all +processes are running the same Julia version on the same architecture). +""" +function get_task_hash(kind::Symbol=:self)::UInt + if kind == :self + return task_local_storage(:_dagger_task_hash)::UInt + elseif kind == :input + return task_local_storage(:_dagger_input_hash)::UInt + else + throw(ArgumentError("Invalid task hash kind: $kind")) + end +end + """ get_tls() @@ -309,6 +330,8 @@ Gets all Dagger TLS variable as a `NamedTuple`. get_tls() = ( sch_uid=task_local_storage(:_dagger_sch_uid), sch_handle=task_local_storage(:_dagger_sch_handle), + task_hash=task_local_storage(:_dagger_task_hash), + input_hash=get(task_local_storage(), :_dagger_input_hash, nothing), processor=thunk_processor(), time_utilization=task_local_storage(:_dagger_time_utilization), alloc_utilization=task_local_storage(:_dagger_alloc_utilization), @@ -320,9 +343,13 @@ get_tls() = ( Sets all Dagger TLS variables from the `NamedTuple` `tls`. """ function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_time_utilization, tls.time_utilization) - task_local_storage(:_dagger_alloc_utilization, tls.alloc_utilization) + task_local_storage(:_dagger_sch_uid, get(tls, :sch_uid, nothing)) + task_local_storage(:_dagger_sch_handle, get(tls, :sch_handle, nothing)) + task_local_storage(:_dagger_task_hash, get(tls, :task_hash, nothing)) + if haskey(tls, :input_hash) && tls.input_hash !== nothing + task_local_storage(:_dagger_input_hash, tls.input_hash) + end + task_local_storage(:_dagger_processor, get(tls, :processor, nothing)) + task_local_storage(:_dagger_time_utilization, get(tls, :time_utilization, nothing)) + task_local_storage(:_dagger_alloc_utilization, get(tls, :alloc_utilization, nothing)) end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index b68a519c8..bee0717ef 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -9,7 +9,7 @@ import Random: randperm import ..Dagger import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, ThunkFailedException, Chunk, OSProc, AnyScope -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime, uhash const OneToMany = Dict{Thunk, Set{Thunk}} @@ -613,6 +613,10 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) @assert !haskey(state.cache, task) opts = merge(ctx.options, task.options) sig = signature(task, state) + if task.hash == UInt(0) + # Compute the hash and cache it in the task + uhash(task, UInt(0); sig) + end # Calculate scope scope = if task.f isa Chunk @@ -672,7 +676,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) # Schedule task onto 
proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util push!(get!(()->Vector{Tuple{Thunk,<:Any,<:Any}}(), to_fire, (gproc, proc)), (task, est_time_util, est_alloc_util)) - state.worker_time_pressure[gproc.pid][proc] += est_time_util + state.worker_time_pressure[gproc.pid][proc] = get(state.worker_time_pressure[gproc.pid], proc, UInt64(0)) + est_time_util @goto pop_task end end @@ -893,10 +897,12 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) ids = Int[0] data = Any[thunk.f] + hashes = Union{UInt,Nothing}[uhash(thunk.f, UInt(0))] for (idx, x) in enumerate(thunk.inputs) x = unwrap_weak_checked(x) push!(ids, istask(x) ? x.id : -idx) push!(data, istask(x) ? state.cache[x] : x) + push!(hashes, uhash(x, UInt(0))) end toptions = thunk.options !== nothing ? thunk.options : ThunkOptions() options = merge(ctx.options, toptions) @@ -906,9 +912,10 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) - push!(to_send, Any[thunk.id, time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result, + push!(to_send, Any[thunk.id, thunk.hash, + time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result, thunk.persist, thunk.cache, thunk.meta, options, - propagated, ids, + propagated, ids, hashes, (log_sink=ctx.log_sink, profile=ctx.profile), sch_handle, state.uid]) end @@ -964,7 +971,7 @@ function do_tasks(to_proc, chan, tasks) end "Executes a single task on `to_proc`." function do_task(to_proc, comm) - thunk_id, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, ctx_vars, sch_handle, uid = comm + thunk_id, task_hash, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, hashes, ctx_vars, sch_handle, sch_uid = comm ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) from_proc = OSProc() @@ -1002,7 +1009,7 @@ function do_task(to_proc, comm) lock(TASK_SYNC) do while true # Get current time utilization for the selected processor - time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, uid) + time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, sch_uid) real_time_util = get!(()->Ref{UInt64}(UInt64(0)), time_dict, to_proc) # Get current allocation utilization and capacity @@ -1043,14 +1050,19 @@ function do_task(to_proc, comm) # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) transfer_size = Threads.Atomic{UInt64}(0) - _data, _ids = if meta - (Any[first(data)], Int[first(ids)]) # always fetch function + _data, _ids, _hashes = if meta + (Any[first(data)], Int[first(ids)], Union{UInt,Nothing}[first(hashes)]) # always fetch function else - (data, ids) + (data, ids, hashes) end - fetch_tasks = map(Iterators.zip(_data,_ids)) do (x, id) + fetch_tasks = map(Iterators.zip(_data, _ids, _hashes)) do (x, id, hash) @async begin timespan_start(ctx, :move, (;thunk_id, id), (;f, id, data=x)) + Dagger.set_tls!(( + sch_uid=sch_uid, + input_hash=hash, + task_hash, + )) x = if x isa Chunk value = lock(TASK_SYNC) do if haskey(CHUNK_CACHE, x) @@ -1123,8 +1135,9 @@ function do_task(to_proc, comm) result_meta = try # Set TLS variables Dagger.set_tls!(( - sch_uid=uid, + sch_uid, sch_handle=sch_handle, + task_hash, processor=to_proc, time_utilization=est_time_util, 
alloc_utilization=est_alloc_util, @@ -1149,7 +1162,7 @@ function do_task(to_proc, comm) # Construct result # TODO: We should cache this locally - send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache) + send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache, hash=task_hash) catch ex bt = catch_backtrace() RemoteException(myid(), CapturedException(ex, bt)) diff --git a/src/sch/util.jl b/src/sch/util.jl index f6fffa360..5ac2f0eb6 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -394,7 +394,7 @@ function estimate_task_costs(state, procs, task, inputs) transfer_costs = Dict(proc=>impute_sum([affinity(chunk)[2] for chunk in filter(c->get_parent(processor(c))!=get_parent(proc), chunks)]) for proc in procs) # Estimate total cost to move data and get task running after currently-scheduled tasks - costs = Dict(proc=>state.worker_time_pressure[get_parent(proc).pid][proc]+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs) + costs = Dict(proc=>get(state.worker_time_pressure[get_parent(proc).pid], proc, UInt64(0))+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs) # Shuffle procs around, so equally-costly procs are equally considered P = randperm(length(procs)) diff --git a/src/thunk.jl b/src/thunk.jl index 72fb2576f..603f53ac3 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -52,6 +52,7 @@ mutable struct Thunk f::Any # usually a Function, but could be any callable inputs::Vector{Any} # TODO: Use `ImmutableArray` in 1.8 id::Int + hash::UInt get_result::Bool # whether the worker should send the result or only the metadata meta::Bool persist::Bool # don't `free!` result after computing @@ -64,6 +65,7 @@ mutable struct Thunk propagates::Tuple # which options we'll propagate function Thunk(f, xs...; id::Int=next_id(), + hash=UInt(0), get_result::Bool=false, meta::Bool=false, persist::Bool=false, @@ -85,10 +87,10 @@ mutable struct Thunk xs = Any[xs...] if options !== nothing @assert isempty(kwargs) - new(f, xs, id, get_result, meta, persist, cache, cache_ref, + new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref, affinity, eager_ref, options, propagates) else - new(f, xs, id, get_result, meta, persist, cache, cache_ref, + new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref, affinity, eager_ref, Sch.ThunkOptions(;kwargs...), propagates) end end @@ -96,6 +98,8 @@ end Serialization.serialize(io::AbstractSerializer, t::Thunk) = throw(ArgumentError("Cannot serialize a Thunk")) +get_task_hash(t::Thunk) = t.hash + function affinity(t::Thunk) if t.affinity !== nothing return t.affinity @@ -183,6 +187,7 @@ end unwrap_weak_checked(t) = t Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) +get_task_hash(t::WeakThunk) = unwrap_weak_checked(t).hash struct ThunkFailedException{E<:Exception} <: Exception thunk::WeakThunk @@ -223,12 +228,23 @@ function Base.fetch(t::EagerThunk; raw=false) if raw fetch(t.future; raw=true) else - move(OSProc(), fetch(t.future)) + value = fetch(t.future) + if value isa Chunk + return fetch(@async begin + Dagger.set_tls!((input_hash=value.hash, + task_hash=value.hash)) + return move(OSProc(), value) + end) + else + return move(OSProc(), value) + end end end function Base.show(io::IO, t::EagerThunk) print(io, "EagerThunk ($(isready(t) ? 
"finished" : "running"))") end +get_task_hash(t::EagerThunk) = + remotecall_fetch(d->get_task_hash(poolget(d)), t.thunk_ref.owner, t.thunk_ref) "When finalized, cleans-up the associated `EagerThunk`." mutable struct EagerThunkFinalizer diff --git a/src/utils/uhash.jl b/src/utils/uhash.jl new file mode 100644 index 000000000..9a625906a --- /dev/null +++ b/src/utils/uhash.jl @@ -0,0 +1,61 @@ +# unified hash algorithm + +using Dagger + +uhash(x, h::UInt)::UInt = hash(x, h) +function uhash(x::Dagger.Thunk, h::UInt; sig=nothing)::UInt + value = hash(0xdead7453, h) + if x.hash != UInt(0) + return uhash(x.hash, value) + end + @assert sig !== nothing + tt = Any[] + for input in x.inputs + input = unwrap_weak_checked(input) + if input isa Dagger.Thunk && input.hash != UInt(0) + value = uhash(input.hash, value) + else + value = uhash(input, value) + push!(tt, typeof(input)) + end + end + sig = (typeof(x.f), tt) + value = uhash_sig(sig, value) + x.hash = value + return value +end +uhash(x::Dagger.WeakThunk, h::UInt)::UInt = + uhash(Dagger.unwrap_weak_checked(x), h) +function uhash_sig((f, tt), h::UInt)::UInt + value = hash(0xdead5160, h) + ci_list = Base.code_typed(f, tt) + if length(ci_list) == 0 + return hash(Union{}, hash(typeof(f), hash(tt, value))) + end + # tt must be concrete + ci = first(only(ci_list))::Core.CodeInfo + return uhash_code(ci, hash(typeof(f), hash(tt, value))) +end +function uhash_code(ci::Core.CodeInfo, h::UInt)::UInt + value = hash(0xdeadc0de, h) + for insn in ci.code + dump(insn) + value = uhash_insn(insn, h) + end + return value +end +function uhash_insn(insn::Expr, h::UInt)::UInt + value = hash(0xdeadeec54, h) + value = hash(insn.head, value) + for arg in insn.args + dump(insn) + @show uhash_insn(arg, value) + value = uhash_insn(arg, value) + end + return value +end +function uhash_insn(insn::GlobalRef, h::UInt)::UInt + value = hash(0xdead6147, h) + return hash(nameof(insn.mod), hash(insn.name, value)) +end +uhash_insn(insn, h::UInt)::UInt = hash(insn, h) From 6e9eb36298b6ee7bf0752563fc6630eff787e29f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 1 Sep 2022 15:36:29 -0500 Subject: [PATCH 2/3] Add DaggerMPI subpackage for MPI integrations Building on Dagger's unified hashing framework, DaggerMPI.jl allows DAGs to execute efficiently under an MPI cluster. Per-task hashes are used to "color" the DAG, disabling execution of each task on all but one MPI worker. Data movement is typically peer-to-peer using MPI Send and Recv, and is coordinated by using tags computed from the same coloring scheme. This scheme allows Dagger's scheduler to remain unmodified and unaware of the existence of an MPI cluster, while still providing "exactly once" execution semantics for each task in the DAG. 
--- lib/DaggerMPI/Project.toml | 8 ++ lib/DaggerMPI/src/DaggerMPI.jl | 159 +++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 lib/DaggerMPI/Project.toml create mode 100644 lib/DaggerMPI/src/DaggerMPI.jl diff --git a/lib/DaggerMPI/Project.toml b/lib/DaggerMPI/Project.toml new file mode 100644 index 000000000..33b4dc92b --- /dev/null +++ b/lib/DaggerMPI/Project.toml @@ -0,0 +1,8 @@ +name = "DaggerMPI" +uuid = "37bfb287-2338-4693-8557-581796463535" +authors = ["Julian P Samaroo "] +version = "0.1.0" + +[deps] +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl new file mode 100644 index 000000000..727b375c8 --- /dev/null +++ b/lib/DaggerMPI/src/DaggerMPI.jl @@ -0,0 +1,159 @@ +module DaggerMPI + +using Dagger +using MPI + +struct MPIProcessor{P,C} <: Dagger.Processor + proc::P + comm::MPI.Comm + color_algo::C +end + +struct SimpleColoring end +function (sc::SimpleColoring)(comm, key) + return UInt64(rem(key, MPI.Comm_size(comm))) +end + +const MPI_PROCESSORS = Ref{Int}(-1) + +const PREVIOUS_PROCESSORS = Set() + +function initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring()) + @assert MPI_PROCESSORS[] == -1 "DaggerMPI already initialized" + + # Force eager_thunk to run + fetch(Dagger.@spawn 1+1) + + MPI.Init(; finalize_atexit=false) + procs = Dagger.get_processors(OSProc()) + i = 0 + empty!(Dagger.PROCESSOR_CALLBACKS) + empty!(Dagger.OSPROC_PROCESSOR_CACHE) + for proc in procs + Dagger.add_processor_callback!("mpiprocessor_$i") do + return MPIProcessor(proc, comm, color_algo) + end + i += 1 + end + MPI_PROCESSORS[] = i + + # FIXME: Hack to populate new processors + Dagger.get_processors(OSProc()) + + return nothing +end + +function finalize() + @assert MPI_PROCESSORS[] > -1 "DaggerMPI not yet initialized" + for i in 1:MPI_PROCESSORS[] + Dagger.delete_processor_callback!("mpiprocessor_$i") + end + empty!(Dagger.PROCESSOR_CALLBACKS) + empty!(Dagger.OSPROC_PROCESSOR_CACHE) + i = 1 + for proc in PREVIOUS_PROCESSORS + Dagger.add_processor_callback!("old_processor_$i") do + return proc + end + i += 1 + end + empty!(PREVIOUS_PROCESSORS) + MPI.Finalize() + MPI_PROCESSORS[] = -1 +end + +"References a value stored on some MPI rank." +struct MPIColoredValue{T} + color::UInt64 + value::T + comm::MPI.Comm +end + +Dagger.get_parent(proc::MPIProcessor) = Dagger.OSProc() +Dagger.default_enabled(proc::MPIProcessor) = true + +"Busy-loop Irecv that yields to other tasks." +function recv_yield(src, tag, comm) + while true + got, value, _ = MPI.irecv(src, tag, comm) + if got + return value + end + # TODO: Sigmoidal backoff + yield() + end +end + +function Dagger.execute!(proc::MPIProcessor, f, args...) 
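+    # The task's unified hash determines a tag, and `color_algo` maps that tag
+    # to a single rank (the task's "color"); only that rank runs `f`, while all
+    # other ranks return an empty placeholder MPIColoredValue.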
+ rank = MPI.Comm_rank(proc.comm) + tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash() >> 32)) + color = proc.color_algo(proc.comm, tag) + if rank == color + @debug "[$rank] Executing $f on $tag" + return MPIColoredValue(color, Dagger.execute!(proc.proc, f, args...), proc.comm) + end + # Return nothing, we won't use this value anyway + @debug "[$rank] Skipped $f on $tag" + return MPIColoredValue(color, nothing, proc.comm) +end + +function Dagger.move(from_proc::MPIProcessor, to_proc::MPIProcessor, x::Dagger.Chunk) + @assert from_proc.comm == to_proc.comm "Mixing different MPI communicators is not supported" + @assert Dagger.chunktype(x) <: MPIColoredValue + x_value = fetch(x) + rank = MPI.Comm_rank(from_proc.comm) + tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32)) + other_tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:self) >> 32)) + other = from_proc.color_algo(from_proc.comm, other_tag) + if x_value.color == rank == other + # We generated and will use this input + return Dagger.move(from_proc.proc, to_proc.proc, x_value.value) + elseif x_value.color == rank + # We generated this input + @debug "[$rank] Starting P2P send to [$other] from $tag to $other_tag" + MPI.isend(x_value.value, other, tag, from_proc.comm) + @debug "[$rank] Finished P2P send to [$other] from $tag to $other_tag" + return Dagger.move(from_proc.proc, to_proc.proc, x_value.value) + elseif other == rank + # We will use this input + @debug "[$rank] Starting P2P recv from $tag to $other_tag" + value = recv_yield(x_value.color, tag, from_proc.comm) + @debug "[$rank] Finished P2P recv from $tag to $other_tag" + return Dagger.move(from_proc.proc, to_proc.proc, value) + else + # We didn't generate and will not use this input + return nothing + end +end + +function Dagger.move(from_proc::MPIProcessor, to_proc::Dagger.Processor, x::Dagger.Chunk) + @assert Dagger.chunktype(x) <: MPIColoredValue + x_value = fetch(x) + rank = MPI.Comm_rank(from_proc.comm) + tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32)) + if rank == x_value.color + # FIXME: Broadcast send + @sync for other in 0:(MPI.Comm_size(from_proc.comm)-1) + other == rank && continue + @async begin + @debug "[$rank] Starting bcast send to [$other] on $tag" + MPI.isend(x_value.value, other, tag, from_proc.comm) + @debug "[$rank] Finished bcast send to [$other] on $tag" + end + end + return Dagger.move(from_proc.proc, to_proc, x_value.value) + else + @debug "[$rank] Starting bcast recv on $tag" + value = recv_yield(x_value.color, tag, from_proc.comm) + @debug "[$rank] Finished bcast recv on $tag" + return Dagger.move(from_proc.proc, to_proc, value) + end +end + +function Dagger.move(from_proc::Dagger.Processor, to_proc::MPIProcessor, x::Dagger.Chunk) + @assert !(Dagger.chunktype(x) <: MPIColoredValue) + rank = MPI.Comm_rank(to_proc.comm) + return MPIColoredValue(rank, Dagger.move(from_proc, to_proc.proc, x), from_proc.comm) +end + +end # module From 3ee8a42584283d03ddd2c78b9c0c231eea503c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20de=20Alc=C3=A2ntara=20Tom=C3=A9?= Date: Tue, 18 Apr 2023 14:36:11 -0300 Subject: [PATCH 3/3] Changing the receive and yield function to accomodate new MPI implementations --- lib/DaggerMPI/src/DaggerMPI.jl | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl index 727b375c8..7f185ed2f 100644 --- a/lib/DaggerMPI/src/DaggerMPI.jl +++ b/lib/DaggerMPI/src/DaggerMPI.jl @@ -18,6 +18,7 @@ 
const MPI_PROCESSORS = Ref{Int}(-1) const PREVIOUS_PROCESSORS = Set() + function initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring()) @assert MPI_PROCESSORS[] == -1 "DaggerMPI already initialized" @@ -72,12 +73,23 @@ end Dagger.get_parent(proc::MPIProcessor) = Dagger.OSProc() Dagger.default_enabled(proc::MPIProcessor) = true + "Busy-loop Irecv that yields to other tasks." function recv_yield(src, tag, comm) - while true - got, value, _ = MPI.irecv(src, tag, comm) + while true + (got, msg, stat) = MPI.Improbe(src, tag, comm, MPI.Status) if got - return value + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = MPI.Imrecv!(MPI.Buffer(buf), msg) + while true + finish = MPI.Test(req) + if finish + value = MPI.deserialize(buf) + return value + end + yield() + end end # TODO: Sigmoidal backoff yield()