diff --git a/src/mpi.jl b/src/mpi.jl
index b2d1d2b6e..61363c6db 100644
--- a/src/mpi.jl
+++ b/src/mpi.jl
@@ -404,6 +404,19 @@ const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0)
 const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->60.0)
 const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}())
 
+struct InplaceInfo
+    type::DataType
+    shape::Tuple
+end
+struct InplaceSparseInfo
+    type::DataType
+    m::Int
+    n::Int
+    colptr::Int
+    rowval::Int
+    nzval::Int
+end
+
 function supports_inplace_mpi(value)
     if value isa DenseArray && isbitstype(eltype(value))
         return true
@@ -412,16 +425,17 @@ function supports_inplace_mpi(value)
     end
 end
 function recv_yield!(buffer, comm, src, tag)
-    println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
+    #println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
     if !supports_inplace_mpi(buffer)
         return recv_yield(comm, src, tag), false
     end
+
     time_start = time_ns()
     detect = DEADLOCK_DETECT[]
     warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
     timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
     rank = MPI.Comm_rank(comm)
-    Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")
+    #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")
 
     # Ensure no other receiver is waiting
     our_event = Base.Event()
@@ -460,7 +474,7 @@ function recv_yield!(buffer, comm, src, tag)
                 notify(our_event)
             end
             #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Released lock")
-            return value, true
+            return buffer, true
         end
         warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, "recv", src)
         yield()
@@ -470,17 +484,10 @@ function recv_yield!(buffer, comm, src, tag)
         yield()
     end
 end
-struct InplaceInfo
-    type::DataType
-    shape::Tuple
-end
+
 function recv_yield(comm, src, tag)
-    time_start = time_ns()
-    detect = DEADLOCK_DETECT[]
-    warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
-    timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
     rank = MPI.Comm_rank(comm)
-    Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]")
+    #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]")
 
     # Ensure no other receiver is waiting
     our_event = Base.Event()
@@ -503,7 +510,7 @@ function recv_yield(comm, src, tag)
     type = nothing
     @label receive
     value = recv_yield_serialized(comm, rank, src, tag)
-    if value isa InplaceInfo
+    if value isa InplaceInfo || value isa InplaceSparseInfo
         value = recv_yield_inplace(value, comm, rank, src, tag)
     end
     lock(RECV_WAITING) do waiting
@@ -512,11 +519,13 @@ function recv_yield(comm, src, tag)
     end
     return value
 end
-function recv_yield_serialized(comm, my_rank, their_rank, tag)
+
+function recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
     time_start = time_ns()
     detect = DEADLOCK_DETECT[]
     warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
     timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
+
     while true
         (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status)
         if got
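Above, `recv_yield` implements a two-message protocol: a small serialized header (`InplaceInfo` or `InplaceSparseInfo`) announces the payload's type and shape, then the raw bits are received straight into a preallocated buffer, skipping deserialization. A minimal standalone sketch of that handshake against plain MPI.jl — not Dagger's actual entry points; the tag and the local `InplaceInfo` stand-in are illustrative only:

```julia
using MPI

struct InplaceInfo        # local stand-in mirroring the header struct above
    type::DataType
    shape::Tuple
end

MPI.Init()
comm = MPI.COMM_WORLD
tag = 42                  # arbitrary tag for this sketch

if MPI.Comm_rank(comm) == 0
    A = rand(4, 4)
    # Message 1: serialized header describing the payload
    MPI.send(InplaceInfo(typeof(A), size(A)), comm; dest=1, tag)
    # Message 2: the raw bits, no serialization
    MPI.Send(A, comm; dest=1, tag)
elseif MPI.Comm_rank(comm) == 1
    hdr = MPI.recv(comm; source=0, tag)::InplaceInfo
    A = Array{eltype(hdr.type)}(undef, hdr.shape)   # preallocate from the header
    MPI.Recv!(A, comm; source=0, tag)               # fill it in place
end
MPI.Finalize()
```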
@@ -524,25 +533,44 @@ function recv_yield_serialized(comm, my_rank, their_rank, tag)
                 error("recv_yield failed with error $(MPI.Get_error(stat))")
             end
             count = MPI.Get_count(stat, UInt8)
-            buf = Array{UInt8}(undef, count)
-            req = MPI.Imrecv!(MPI.Buffer(buf), msg)
+            @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count"
+            buf = MPI.Buffer(array)
+            req = MPI.Imrecv!(buf, msg)
             __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv")
-            return MPI.deserialize(buf)
+            break
         end
         warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank)
         yield()
     end
+    return array
 end
+
 function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag)
+    T = _value.type
+    @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays"
+    array = Array{eltype(T)}(undef, _value.shape)
+    return recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
+end
+
+function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag)
+    T = _value.type
+    @assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC"
+
+    colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag)
+    rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag)
+    nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag)
+
+    sparse_array = SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval)
+    return sparse_array
+
+end
+
+function recv_yield_serialized(comm, my_rank, their_rank, tag)
     time_start = time_ns()
     detect = DEADLOCK_DETECT[]
     warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
     timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
-    T = _value.type
-    @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays"
-    array = Array{eltype(T)}(undef, _value.shape)
-
     while true
         (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status)
         if got
@@ -550,17 +578,14 @@ function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag)
                 error("recv_yield failed with error $(MPI.Get_error(stat))")
             end
             count = MPI.Get_count(stat, UInt8)
-            @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count"
-            buf = MPI.Buffer(array)
-            req = MPI.Imrecv!(buf, msg)
+            buf = Array{UInt8}(undef, count)
+            req = MPI.Imrecv!(MPI.Buffer(buf), msg)
             __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv")
-            break
+            return MPI.deserialize(buf)
         end
         warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank)
         yield()
     end
-
-    return array
 end
 
 const SEEN_TAGS = Dict{Int32, Type}()
@@ -584,19 +609,27 @@ function _send_yield(value, comm, dest, tag; check_seen::Bool=true, inplace::Boo
         send_yield_serialized(value, comm, rank, dest, tag)
     end
 end
+
 function send_yield_inplace(value, comm, my_rank, their_rank, tag)
     req = MPI.Isend(value, comm; dest=their_rank, tag)
     __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send")
 end
+
 function send_yield_serialized(value, comm, my_rank, their_rank, tag)
     if value isa Array && isbitstype(eltype(value))
         send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag)
         send_yield_inplace(value, comm, my_rank, their_rank, tag)
+    elseif value isa SparseMatrixCSC && isbitstype(eltype(value))
+        send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag)
+        send_yield!(value.colptr, comm, their_rank, tag; check_seen=false)
+        send_yield!(value.rowval, comm, their_rank, tag; check_seen=false)
+        send_yield!(value.nzval, comm, their_rank, tag; check_seen=false)
     else
         req = MPI.isend(value, comm; dest=their_rank, tag)
         __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send")
     end
 end
+
 function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String)
     time_start = time_ns()
     detect = DEADLOCK_DETECT[]
@@ -623,6 +656,7 @@ function bcast_send_yield(value, comm, root, tag)
         send_yield(value, comm, other_rank, tag)
     end
 end
+
 function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest)
     time_elapsed = (time_ns() - time_start)
     if detect && time_elapsed > warn_period
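The `InplaceSparseInfo` path works because a `SparseMatrixCSC` is fully determined by `(m, n, colptr, rowval, nzval)`: one serialized header plus three raw-buffer transfers rebuild the matrix exactly. A local sketch of that decomposition and reassembly (no MPI involved); note that the receiver above hardcodes `Int64` indices, which matches the SparseArrays default on 64-bit systems but would mismatch an `Int32`-indexed sender:

```julia
using SparseArrays

S = sprand(6, 6, 0.3)
m, n = size(S)
# The three buffers the sender ships as raw bits after the header:
colptr, rowval, nzval = S.colptr, S.rowval, S.nzval

# Receiver-side reassembly: wrap the buffers without copying them.
S2 = SparseMatrixCSC{eltype(S), Int64}(m, n, colptr, rowval, nzval)
@assert S2 == S
```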
diff --git a/test/mpi.jl b/test/mpi.jl
index a428e4256..5e15cad65 100644
--- a/test/mpi.jl
+++ b/test/mpi.jl
@@ -1,33 +1,63 @@
 using Dagger
 using MPI
+using SparseArrays
 
 Dagger.accelerate!(:mpi)
-#=
+
 if MPI.Comm_rank(MPI.COMM_WORLD) == 0
     B = rand(4, 4)
-    Dagger.send_yield(B, MPI.COMM_WORLD, 1, 0)
-    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
+    Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 0)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield! Array: B: $B")
 else
     B = zeros(4, 4)
     Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 0)
-    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! Array: B: $B")
 end
+MPI.Barrier(MPI.COMM_WORLD)
+
 if MPI.Comm_rank(MPI.COMM_WORLD) == 0
     B = "hello"
-    Dagger.send_yield(B, MPI.COMM_WORLD, 1, 1)
-    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
+    Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 2)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield! String: B: $B")
 else
     B = "Goodbye"
-    B1, _ = Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 1)
-    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B1: $B1")
+    B1, _ = Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 2)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! String: B1: $B1")
 end
-=#
-A = rand(Blocks(2,2), 4, 4)
-Ac = collect(A)
-println(Ac)
+MPI.Barrier(MPI.COMM_WORLD)
+
+if MPI.Comm_rank(MPI.COMM_WORLD) == 0
+    B = sprand(4, 4, 0.2)
+    Dagger.send_yield(B, MPI.COMM_WORLD, 1, 1)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Sparse: B: $B")
+else
+    B1 = Dagger.recv_yield(MPI.COMM_WORLD, 0, 1)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Sparse: B1: $B1")
+end
 
-#move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])
+MPI.Barrier(MPI.COMM_WORLD)
+
+if MPI.Comm_rank(MPI.COMM_WORLD) == 0
+    B = rand(4, 4)
+    Dagger.send_yield(B, MPI.COMM_WORLD, 1, 0)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Dense: B: $B")
+else
+    B = Dagger.recv_yield(MPI.COMM_WORLD, 0, 0)
+    println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Dense: B: $B")
+end
+
+MPI.Barrier(MPI.COMM_WORLD)
+
+#=
+A = rand(Blocks(2,2), 4, 4)
+Ac = collect(A)
+println(Ac)
+=#
+#move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])
\ No newline at end of file
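The updated test assumes exactly two ranks (rank 0 sends, rank 1 receives). One possible way to launch it, assuming MPI.jl's `mpiexec` wrapper and that the script is invoked from the repository root:

```julia
using MPI

MPI.mpiexec() do exe
    # Two ranks: the branches in test/mpi.jl assume ranks 0 and 1.
    run(`$exe -n 2 $(Base.julia_cmd()) --project test/mpi.jl`)
end
```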