diff --git a/lib/DaggerMPI/Project.toml b/lib/DaggerMPI/Project.toml
new file mode 100644
index 000000000..33b4dc92b
--- /dev/null
+++ b/lib/DaggerMPI/Project.toml
@@ -0,0 +1,8 @@
+name = "DaggerMPI"
+uuid = "37bfb287-2338-4693-8557-581796463535"
+authors = ["Julian P Samaroo "]
+version = "0.1.0"
+
+[deps]
+Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
+MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl
new file mode 100644
index 000000000..7f185ed2f
--- /dev/null
+++ b/lib/DaggerMPI/src/DaggerMPI.jl
@@ -0,0 +1,171 @@
+module DaggerMPI
+
+using Dagger
+using MPI
+
+struct MPIProcessor{P,C} <: Dagger.Processor
+    proc::P
+    comm::MPI.Comm
+    color_algo::C
+end
+
+struct SimpleColoring end
+function (sc::SimpleColoring)(comm, key)
+    return UInt64(rem(key, MPI.Comm_size(comm)))
+end
+
+const MPI_PROCESSORS = Ref{Int}(-1)
+
+const PREVIOUS_PROCESSORS = Set()
+
+function initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring())
+    @assert MPI_PROCESSORS[] == -1 "DaggerMPI already initialized"
+
+    # Force eager_thunk to run
+    fetch(Dagger.@spawn 1+1)
+
+    MPI.Init(; finalize_atexit=false)
+    procs = Dagger.get_processors(OSProc())
+    # Remember the current processors so that `finalize` can restore them
+    union!(PREVIOUS_PROCESSORS, procs)
+    i = 0
+    empty!(Dagger.PROCESSOR_CALLBACKS)
+    empty!(Dagger.OSPROC_PROCESSOR_CACHE)
+    for proc in procs
+        Dagger.add_processor_callback!("mpiprocessor_$i") do
+            return MPIProcessor(proc, comm, color_algo)
+        end
+        i += 1
+    end
+    MPI_PROCESSORS[] = i
+
+    # FIXME: Hack to populate new processors
+    Dagger.get_processors(OSProc())
+
+    return nothing
+end
+
+function finalize()
+    @assert MPI_PROCESSORS[] > -1 "DaggerMPI not yet initialized"
+    # Callbacks were registered as "mpiprocessor_0" .. "mpiprocessor_$(N-1)"
+    for i in 0:(MPI_PROCESSORS[]-1)
+        Dagger.delete_processor_callback!("mpiprocessor_$i")
+    end
+    empty!(Dagger.PROCESSOR_CALLBACKS)
+    empty!(Dagger.OSPROC_PROCESSOR_CACHE)
+    i = 1
+    for proc in PREVIOUS_PROCESSORS
+        Dagger.add_processor_callback!("old_processor_$i") do
+            return proc
+        end
+        i += 1
+    end
+    empty!(PREVIOUS_PROCESSORS)
+    MPI.Finalize()
+    MPI_PROCESSORS[] = -1
+end
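For illustration only (not part of this patch): any callable of the form `(comm, key) -> UInt64` can serve as a coloring algorithm, provided every rank computes the same color for the same key. `OffsetColoring` below is a hypothetical variant.

```julia
using MPI

# Hypothetical coloring algorithm: shift the round-robin assignment by a fixed offset.
struct OffsetColoring
    offset::Int
end
(oc::OffsetColoring)(comm, key) =
    UInt64(rem(key + oc.offset, MPI.Comm_size(comm)))

# DaggerMPI.initialize(MPI.COMM_WORLD; color_algo=OffsetColoring(1))
```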
+
+"References a value stored on some MPI rank."
+struct MPIColoredValue{T}
+    color::UInt64
+    value::T
+    comm::MPI.Comm
+end
+
+Dagger.get_parent(proc::MPIProcessor) = Dagger.OSProc()
+Dagger.default_enabled(proc::MPIProcessor) = true
+
+"Busy-loop Irecv that yields to other tasks."
+function recv_yield(src, tag, comm)
+    while true
+        (got, msg, stat) = MPI.Improbe(src, tag, comm, MPI.Status)
+        if got
+            count = MPI.Get_count(stat, UInt8)
+            buf = Array{UInt8}(undef, count)
+            req = MPI.Imrecv!(MPI.Buffer(buf), msg)
+            while true
+                finish = MPI.Test(req)
+                if finish
+                    value = MPI.deserialize(buf)
+                    return value
+                end
+                yield()
+            end
+        end
+        # TODO: Sigmoidal backoff
+        yield()
+    end
+end
+
+function Dagger.execute!(proc::MPIProcessor, f, args...)
+    rank = MPI.Comm_rank(proc.comm)
+    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash() >> 32))
+    color = proc.color_algo(proc.comm, tag)
+    if rank == color
+        @debug "[$rank] Executing $f on $tag"
+        return MPIColoredValue(color, Dagger.execute!(proc.proc, f, args...), proc.comm)
+    end
+    # Return nothing, we won't use this value anyway
+    @debug "[$rank] Skipped $f on $tag"
+    return MPIColoredValue(color, nothing, proc.comm)
+end
+
+function Dagger.move(from_proc::MPIProcessor, to_proc::MPIProcessor, x::Dagger.Chunk)
+    @assert from_proc.comm == to_proc.comm "Mixing different MPI communicators is not supported"
+    @assert Dagger.chunktype(x) <: MPIColoredValue
+    x_value = fetch(x)
+    rank = MPI.Comm_rank(from_proc.comm)
+    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
+    other_tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:self) >> 32))
+    other = from_proc.color_algo(from_proc.comm, other_tag)
+    if x_value.color == rank == other
+        # We generated and will use this input
+        return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
+    elseif x_value.color == rank
+        # We generated this input
+        @debug "[$rank] Starting P2P send to [$other] from $tag to $other_tag"
+        MPI.isend(x_value.value, other, tag, from_proc.comm)
+        @debug "[$rank] Finished P2P send to [$other] from $tag to $other_tag"
+        return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
+    elseif other == rank
+        # We will use this input
+        @debug "[$rank] Starting P2P recv from $tag to $other_tag"
+        value = recv_yield(x_value.color, tag, from_proc.comm)
+        @debug "[$rank] Finished P2P recv from $tag to $other_tag"
+        return Dagger.move(from_proc.proc, to_proc.proc, value)
+    else
+        # We didn't generate and will not use this input
+        return nothing
+    end
+end
+
+function Dagger.move(from_proc::MPIProcessor, to_proc::Dagger.Processor, x::Dagger.Chunk)
+    @assert Dagger.chunktype(x) <: MPIColoredValue
+    x_value = fetch(x)
+    rank = MPI.Comm_rank(from_proc.comm)
+    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
+    if rank == x_value.color
+        # FIXME: Broadcast send
+        @sync for other in 0:(MPI.Comm_size(from_proc.comm)-1)
+            other == rank && continue
+            @async begin
+                @debug "[$rank] Starting bcast send to [$other] on $tag"
+                MPI.isend(x_value.value, other, tag, from_proc.comm)
+                @debug "[$rank] Finished bcast send to [$other] on $tag"
+            end
+        end
+        return Dagger.move(from_proc.proc, to_proc, x_value.value)
+    else
+        @debug "[$rank] Starting bcast recv on $tag"
+        value = recv_yield(x_value.color, tag, from_proc.comm)
+        @debug "[$rank] Finished bcast recv on $tag"
+        return Dagger.move(from_proc.proc, to_proc, value)
+    end
+end
+
+function Dagger.move(from_proc::Dagger.Processor, to_proc::MPIProcessor, x::Dagger.Chunk)
+    @assert !(Dagger.chunktype(x) <: MPIColoredValue)
+    rank = MPI.Comm_rank(to_proc.comm)
+    return MPIColoredValue(rank, Dagger.move(from_proc, to_proc.proc, x), to_proc.comm)
+end
+
+end # module
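A hypothetical usage sketch (not part of this patch): every rank runs the same script, e.g. via `mpiexec -n 4 julia --project script.jl`, and must reach the same Dagger calls in the same order.

```julia
using Dagger, MPI, DaggerMPI

DaggerMPI.initialize()          # wrap each local processor in an MPIProcessor

a = Dagger.@spawn rand(100)     # executes only on the rank matching its color
s = Dagger.@spawn sum(a)        # the input moves between ranks via isend/recv_yield
@show fetch(s)                  # the owning rank broadcasts the result on fetch

DaggerMPI.finalize()
```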
diff --git a/src/Dagger.jl b/src/Dagger.jl
index 63b0834b1..a9b947921 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -36,6 +36,7 @@ include("chunks.jl")
 include("compute.jl")
 include("utils/clock.jl")
 include("utils/system_uuid.jl")
+include("utils/uhash.jl")
 include("sch/Sch.jl"); using .Sch

 # Array computations
diff --git a/src/chunks.jl b/src/chunks.jl
index 9745cb598..5c3d1533f 100644
--- a/src/chunks.jl
+++ b/src/chunks.jl
@@ -53,6 +53,7 @@ mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope}
     processor::P
     scope::S
     persist::Bool
+    hash::UInt
 end

 domain(c::Chunk) = c.domain
@@ -242,7 +243,7 @@ be used.

 All other kwargs are passed directly to `MemPool.poolset`.
 """
-function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S}
+function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, hash=UInt(0), kwargs...) where {X,P,S}
     if device === nothing
         device = if Sch.walk_storage_safe(x)
             MemPool.GLOBAL_DEVICE[]
@@ -250,8 +251,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac
             MemPool.CPURAMDevice()
         end
     end
-    ref = poolset(x; device, kwargs...)
-    Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist)
+    ref = poolset(move(OSProc(), proc, x); device, kwargs...)
+    Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist, hash)
 end

 tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x
diff --git a/src/processor.jl b/src/processor.jl
index 91091a190..e99399dc0 100644
--- a/src/processor.jl
+++ b/src/processor.jl
@@ -288,19 +288,40 @@ end
 # In-Thunk Helpers

 """
-    thunk_processor()
+    thunk_processor() -> Dagger.Processor

 Get the current processor executing the current thunk.
 """
 thunk_processor() = task_local_storage(:_dagger_processor)::Processor

 """
-    in_thunk()
+    in_thunk() -> Bool

 Returns `true` if currently in a [`Thunk`](@ref) process, else `false`.
 """
 in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid)

+"""
+    get_task_hash(kind::Symbol=:self) -> UInt
+
+Returns the unified hash of the current task or of an input to the current
+task. If `kind == :self`, then the hash is for the current task; if `kind ==
+:input`, then the hash is for the current input to the task that is being
+processed. The `:self` hash is available during `Dagger.execute!` and
+`Dagger.move`, whereas the `:input` hash is only available during
+`Dagger.move`. This hash is consistent across Julia processes (if all
+processes are running the same Julia version on the same architecture).
+"""
+function get_task_hash(kind::Symbol=:self)::UInt
+    if kind == :self
+        return task_local_storage(:_dagger_task_hash)::UInt
+    elseif kind == :input
+        return task_local_storage(:_dagger_input_hash)::UInt
+    else
+        throw(ArgumentError("Invalid task hash kind: $kind"))
+    end
+end
+
 """
     get_tls()

@@ -309,6 +330,8 @@ Gets all Dagger TLS variable as a `NamedTuple`.
 get_tls() = (
     sch_uid=task_local_storage(:_dagger_sch_uid),
     sch_handle=task_local_storage(:_dagger_sch_handle),
+    task_hash=task_local_storage(:_dagger_task_hash),
+    input_hash=get(task_local_storage(), :_dagger_input_hash, nothing),
     processor=thunk_processor(),
     time_utilization=task_local_storage(:_dagger_time_utilization),
     alloc_utilization=task_local_storage(:_dagger_alloc_utilization),
@@ -320,9 +343,13 @@ get_tls() = (
 Sets all Dagger TLS variables from the `NamedTuple` `tls`.
 """
 function set_tls!(tls)
-    task_local_storage(:_dagger_sch_uid, tls.sch_uid)
-    task_local_storage(:_dagger_sch_handle, tls.sch_handle)
-    task_local_storage(:_dagger_processor, tls.processor)
-    task_local_storage(:_dagger_time_utilization, tls.time_utilization)
-    task_local_storage(:_dagger_alloc_utilization, tls.alloc_utilization)
+    task_local_storage(:_dagger_sch_uid, get(tls, :sch_uid, nothing))
+    task_local_storage(:_dagger_sch_handle, get(tls, :sch_handle, nothing))
+    task_local_storage(:_dagger_task_hash, get(tls, :task_hash, nothing))
+    if haskey(tls, :input_hash) && tls.input_hash !== nothing
+        task_local_storage(:_dagger_input_hash, tls.input_hash)
+    end
+    task_local_storage(:_dagger_processor, get(tls, :processor, nothing))
+    task_local_storage(:_dagger_time_utilization, get(tls, :time_utilization, nothing))
+    task_local_storage(:_dagger_alloc_utilization, get(tls, :alloc_utilization, nothing))
 end
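A sketch of how a processor backend might consume these TLS values; `MyProcessor` is a hypothetical wrapper processor, not part of this patch.

```julia
using Dagger

# Hypothetical wrapper processor that forwards work to an inner processor.
struct MyProcessor <: Dagger.Processor
    proc::Dagger.Processor
end

function Dagger.execute!(proc::MyProcessor, f, args...)
    h = Dagger.get_task_hash()      # :self hash of the thunk being executed
    @debug "Executing task with unified hash $h"
    tls = Dagger.get_tls()          # now includes task_hash (and input_hash in `move`)
    return fetch(Threads.@spawn begin
        Dagger.set_tls!(tls)        # make the hashes visible in this helper task too
        Dagger.execute!(proc.proc, f, args...)
    end)
end
```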
""" function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_time_utilization, tls.time_utilization) - task_local_storage(:_dagger_alloc_utilization, tls.alloc_utilization) + task_local_storage(:_dagger_sch_uid, get(tls, :sch_uid, nothing)) + task_local_storage(:_dagger_sch_handle, get(tls, :sch_handle, nothing)) + task_local_storage(:_dagger_task_hash, get(tls, :task_hash, nothing)) + if haskey(tls, :input_hash) && tls.input_hash !== nothing + task_local_storage(:_dagger_input_hash, tls.input_hash) + end + task_local_storage(:_dagger_processor, get(tls, :processor, nothing)) + task_local_storage(:_dagger_time_utilization, get(tls, :time_utilization, nothing)) + task_local_storage(:_dagger_alloc_utilization, get(tls, :alloc_utilization, nothing)) end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index b68a519c8..bee0717ef 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -9,7 +9,7 @@ import Random: randperm import ..Dagger import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, ThunkFailedException, Chunk, OSProc, AnyScope -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime, uhash const OneToMany = Dict{Thunk, Set{Thunk}} @@ -613,6 +613,10 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) @assert !haskey(state.cache, task) opts = merge(ctx.options, task.options) sig = signature(task, state) + if task.hash == UInt(0) + # Compute the hash and cache it in the task + uhash(task, UInt(0); sig) + end # Calculate scope scope = if task.f isa Chunk @@ -672,7 +676,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util push!(get!(()->Vector{Tuple{Thunk,<:Any,<:Any}}(), to_fire, (gproc, proc)), (task, est_time_util, est_alloc_util)) - state.worker_time_pressure[gproc.pid][proc] += est_time_util + state.worker_time_pressure[gproc.pid][proc] = get(state.worker_time_pressure[gproc.pid], proc, UInt64(0)) + est_time_util @goto pop_task end end @@ -893,10 +897,12 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) ids = Int[0] data = Any[thunk.f] + hashes = Union{UInt,Nothing}[uhash(thunk.f, UInt(0))] for (idx, x) in enumerate(thunk.inputs) x = unwrap_weak_checked(x) push!(ids, istask(x) ? x.id : -idx) push!(data, istask(x) ? state.cache[x] : x) + push!(hashes, uhash(x, UInt(0))) end toptions = thunk.options !== nothing ? thunk.options : ThunkOptions() options = merge(ctx.options, toptions) @@ -906,9 +912,10 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) 
-        push!(to_send, Any[thunk.id, time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
+        push!(to_send, Any[thunk.id, thunk.hash,
+                           time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
                            thunk.persist, thunk.cache, thunk.meta, options,
-                           propagated, ids,
+                           propagated, ids, hashes,
                            (log_sink=ctx.log_sink, profile=ctx.profile),
                            sch_handle, state.uid])
     end
@@ -964,7 +971,7 @@
 end
 "Executes a single task on `to_proc`."
 function do_task(to_proc, comm)
-    thunk_id, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, ctx_vars, sch_handle, uid = comm
+    thunk_id, task_hash, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, hashes, ctx_vars, sch_handle, sch_uid = comm

     ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)
     from_proc = OSProc()
@@ -1002,7 +1009,7 @@
     lock(TASK_SYNC) do
         while true
             # Get current time utilization for the selected processor
-            time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, uid)
+            time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, sch_uid)
             real_time_util = get!(()->Ref{UInt64}(UInt64(0)), time_dict, to_proc)

             # Get current allocation utilization and capacity
@@ -1043,14 +1050,19 @@
     # Initiate data transfers for function and arguments
     transfer_time = Threads.Atomic{UInt64}(0)
     transfer_size = Threads.Atomic{UInt64}(0)
-    _data, _ids = if meta
-        (Any[first(data)], Int[first(ids)]) # always fetch function
+    _data, _ids, _hashes = if meta
+        (Any[first(data)], Int[first(ids)], Union{UInt,Nothing}[first(hashes)]) # always fetch function
     else
-        (data, ids)
+        (data, ids, hashes)
     end
-    fetch_tasks = map(Iterators.zip(_data,_ids)) do (x, id)
+    fetch_tasks = map(Iterators.zip(_data, _ids, _hashes)) do (x, id, hash)
         @async begin
             timespan_start(ctx, :move, (;thunk_id, id), (;f, id, data=x))
+            Dagger.set_tls!((
+                sch_uid=sch_uid,
+                input_hash=hash,
+                task_hash,
+            ))
             x = if x isa Chunk
                 value = lock(TASK_SYNC) do
                     if haskey(CHUNK_CACHE, x)
@@ -1123,8 +1135,9 @@
     result_meta = try
         # Set TLS variables
         Dagger.set_tls!((
-            sch_uid=uid,
+            sch_uid,
             sch_handle=sch_handle,
+            task_hash,
            processor=to_proc,
             time_utilization=est_time_util,
             alloc_utilization=est_alloc_util,
@@ -1149,7 +1162,7 @@

         # Construct result
         # TODO: We should cache this locally
-        send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache)
+        send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache, hash=task_hash)
     catch ex
         bt = catch_backtrace()
         RemoteException(myid(), CapturedException(ex, bt))
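These TLS hashes are what communication backends key their transfers on. As a sketch, `tag_from_hash` below is a hypothetical helper showing how a backend can derive a deterministic 32-bit message tag from the 64-bit unified hash (mirroring what DaggerMPI does above); both peers compute the same tag without extra coordination.

```julia
# Hypothetical helper: fold a 64-bit unified hash into a non-negative Int32 tag.
tag_from_hash(h::UInt) = abs(Base.unsafe_trunc(Int32, h >> 32))

# e.g. inside a custom `Dagger.move`:
#   recv_tag = tag_from_hash(Dagger.get_task_hash(:input))
#   self_tag = tag_from_hash(Dagger.get_task_hash(:self))
```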
diff --git a/src/sch/util.jl b/src/sch/util.jl
index f6fffa360..5ac2f0eb6 100644
--- a/src/sch/util.jl
+++ b/src/sch/util.jl
@@ -394,7 +394,7 @@ function estimate_task_costs(state, procs, task, inputs)
     transfer_costs = Dict(proc=>impute_sum([affinity(chunk)[2] for chunk in filter(c->get_parent(processor(c))!=get_parent(proc), chunks)]) for proc in procs)

     # Estimate total cost to move data and get task running after currently-scheduled tasks
-    costs = Dict(proc=>state.worker_time_pressure[get_parent(proc).pid][proc]+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs)
+    costs = Dict(proc=>get(state.worker_time_pressure[get_parent(proc).pid], proc, UInt64(0))+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs)

     # Shuffle procs around, so equally-costly procs are equally considered
     P = randperm(length(procs))
diff --git a/src/thunk.jl b/src/thunk.jl
index 72fb2576f..603f53ac3 100644
--- a/src/thunk.jl
+++ b/src/thunk.jl
@@ -52,6 +52,7 @@ mutable struct Thunk
     f::Any # usually a Function, but could be any callable
     inputs::Vector{Any} # TODO: Use `ImmutableArray` in 1.8
     id::Int
+    hash::UInt
     get_result::Bool # whether the worker should send the result or only the metadata
     meta::Bool
     persist::Bool # don't `free!` result after computing
@@ -64,6 +65,7 @@ mutable struct Thunk
     propagates::Tuple # which options we'll propagate
     function Thunk(f, xs...;
                    id::Int=next_id(),
+                   hash=UInt(0),
                    get_result::Bool=false,
                    meta::Bool=false,
                    persist::Bool=false,
@@ -85,10 +87,10 @@ mutable struct Thunk
         xs = Any[xs...]
         if options !== nothing
             @assert isempty(kwargs)
-            new(f, xs, id, get_result, meta, persist, cache, cache_ref,
+            new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref,
                 affinity, eager_ref, options, propagates)
         else
-            new(f, xs, id, get_result, meta, persist, cache, cache_ref,
+            new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref,
                 affinity, eager_ref, Sch.ThunkOptions(;kwargs...), propagates)
         end
     end
@@ -96,6 +98,8 @@ end
 Serialization.serialize(io::AbstractSerializer, t::Thunk) =
     throw(ArgumentError("Cannot serialize a Thunk"))

+get_task_hash(t::Thunk) = t.hash
+
 function affinity(t::Thunk)
     if t.affinity !== nothing
         return t.affinity
@@ -183,6 +187,7 @@ end
 unwrap_weak_checked(t) = t
 Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value))
 Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t)
+get_task_hash(t::WeakThunk) = unwrap_weak_checked(t).hash

 struct ThunkFailedException{E<:Exception} <: Exception
     thunk::WeakThunk
@@ -223,12 +228,23 @@ function Base.fetch(t::EagerThunk; raw=false)
     if raw
         fetch(t.future; raw=true)
     else
-        move(OSProc(), fetch(t.future))
+        value = fetch(t.future)
+        if value isa Chunk
+            return fetch(@async begin
+                Dagger.set_tls!((input_hash=value.hash,
+                                 task_hash=value.hash))
+                return move(OSProc(), value)
+            end)
+        else
+            return move(OSProc(), value)
+        end
     end
 end
 function Base.show(io::IO, t::EagerThunk)
     print(io, "EagerThunk ($(isready(t) ? "finished" : "running"))")
 end
+get_task_hash(t::EagerThunk) =
+    remotecall_fetch(d->get_task_hash(poolget(d)), t.thunk_ref.owner, t.thunk_ref)

 "When finalized, cleans-up the associated `EagerThunk`."
 mutable struct EagerThunkFinalizer
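For reference, the per-task hash is now also reachable from user code; a minimal sketch (the returned value may still be `UInt(0)` if the scheduler has not yet computed the hash):

```julia
using Dagger

t = Dagger.@spawn sum([1, 2, 3])
wait(t)                          # ensure the scheduler has run and cached the hash
h = Dagger.get_task_hash(t)      # same value the worker saw via get_task_hash(:self)
```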
diff --git a/src/utils/uhash.jl b/src/utils/uhash.jl
new file mode 100644
index 000000000..9a625906a
--- /dev/null
+++ b/src/utils/uhash.jl
@@ -0,0 +1,61 @@
+# unified hash algorithm
+
+using Dagger
+
+uhash(x, h::UInt)::UInt = hash(x, h)
+function uhash(x::Dagger.Thunk, h::UInt; sig=nothing)::UInt
+    value = hash(0xdead7453, h)
+    if x.hash != UInt(0)
+        return uhash(x.hash, value)
+    end
+    @assert sig !== nothing
+    tt = Any[]
+    for input in x.inputs
+        input = unwrap_weak_checked(input)
+        if input isa Dagger.Thunk && input.hash != UInt(0)
+            value = uhash(input.hash, value)
+        else
+            value = uhash(input, value)
+            push!(tt, typeof(input))
+        end
+    end
+    sig = (typeof(x.f), tt)
+    value = uhash_sig(sig, value)
+    x.hash = value
+    return value
+end
+uhash(x::Dagger.WeakThunk, h::UInt)::UInt =
+    uhash(Dagger.unwrap_weak_checked(x), h)
+function uhash_sig((f, tt), h::UInt)::UInt
+    value = hash(0xdead5160, h)
+    ci_list = Base.code_typed(f, tt)
+    if length(ci_list) == 0
+        return hash(Union{}, hash(typeof(f), hash(tt, value)))
+    end
+    # tt must be concrete
+    ci = first(only(ci_list))::Core.CodeInfo
+    return uhash_code(ci, hash(typeof(f), hash(tt, value)))
+end
+function uhash_code(ci::Core.CodeInfo, h::UInt)::UInt
+    value = hash(0xdeadc0de, h)
+    for insn in ci.code
+        # Accumulate the hash of every instruction
+        value = uhash_insn(insn, value)
+    end
+    return value
+end
+function uhash_insn(insn::Expr, h::UInt)::UInt
+    value = hash(0xdeadeec54, h)
+    value = hash(insn.head, value)
+    for arg in insn.args
+        value = uhash_insn(arg, value)
+    end
+    return value
+end
+function uhash_insn(insn::GlobalRef, h::UInt)::UInt
+    value = hash(0xdead6147, h)
+    return hash(nameof(insn.mod), hash(insn.name, value))
+end
+uhash_insn(insn, h::UInt)::UInt = hash(insn, h)
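A minimal illustration of the fallback path: values that are not `Thunk`s hash via `Base.hash`, so equal plain inputs contribute identical hash material on every process running the same Julia version and architecture, while `Thunk`s are hashed structurally from their inputs and the lowered code of the call signature.

```julia
using Dagger

x = [1, 2, 3]
# Non-Thunk values fall back to Base.hash with the accumulated seed.
@assert Dagger.uhash(x, UInt(0)) == hash(x, UInt(0))
```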