SparseArray in-place send/recv #624

Open · wants to merge 2 commits into yg/faster-mpi

90 changes: 62 additions & 28 deletions src/mpi.jl
@@ -404,6 +404,19 @@ const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0)
const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->60.0)
const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}())

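# Header message announcing an in-place transfer of a dense bitstype array:
# its concrete type and shape.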
struct InplaceInfo
type::DataType
shape::Tuple
end
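# Header message announcing an in-place SparseMatrixCSC transfer: the matrix
# dimensions and the lengths of its colptr, rowval, and nzval buffers.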
struct InplaceSparseInfo
type::DataType
m::Int
n::Int
colptr::Int
rowval::Int
nzval::Int
end

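# Whether a value can be sent/received as a raw MPI buffer, without serialization.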
function supports_inplace_mpi(value)
if value isa DenseArray && isbitstype(eltype(value))
return true
@@ -412,16 +425,17 @@ function supports_inplace_mpi(value)
end
end
function recv_yield!(buffer, comm, src, tag)
println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
#println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
if !supports_inplace_mpi(buffer)
return recv_yield(comm, src, tag), false
end

time_start = time_ns()
detect = DEADLOCK_DETECT[]
warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
rank = MPI.Comm_rank(comm)
Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")

# Ensure no other receiver is waiting
our_event = Base.Event()
@@ -460,7 +474,7 @@ function recv_yield!(buffer, comm, src, tag)
notify(our_event)
end
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Released lock")
return value, true
return buffer, true
end
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, "recv", src)
yield()
@@ -470,17 +484,10 @@ function recv_yield!(buffer, comm, src, tag)
yield()
end
end
struct InplaceInfo
type::DataType
shape::Tuple
end

function recv_yield(comm, src, tag)
time_start = time_ns()
detect = DEADLOCK_DETECT[]
warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
rank = MPI.Comm_rank(comm)
Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]")
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]")

# Ensure no other receiver is waiting
our_event = Base.Event()
@@ -494,7 +501,7 @@ function recv_yield(comm, src, tag)
end
end
if other_event !== nothing
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...")
Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...")
wait(other_event)
@goto retry
end
@@ -503,7 +510,7 @@ function recv_yield(comm, src, tag)
type = nothing
@label receive
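# The first message is either the serialized value itself, or an
# Inplace*Info header followed by the raw buffer(s).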
value = recv_yield_serialized(comm, rank, src, tag)
if value isa InplaceInfo
if value isa InplaceInfo || value isa InplaceSparseInfo
value = recv_yield_inplace(value, comm, rank, src, tag)
end
lock(RECV_WAITING) do waiting
@@ -512,55 +519,73 @@ function recv_yield(comm, src, tag)
end
return value
end
function recv_yield_serialized(comm, my_rank, their_rank, tag)

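# Receive raw bytes directly into the preallocated bitstype array,
# bypassing deserialization; the byte count must match exactly.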
function recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
time_start = time_ns()
detect = DEADLOCK_DETECT[]
warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)

while true
(got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status)
if got
if MPI.Get_error(stat) != MPI.SUCCESS
error("recv_yield failed with error $(MPI.Get_error(stat))")
end
count = MPI.Get_count(stat, UInt8)
buf = Array{UInt8}(undef, count)
req = MPI.Imrecv!(MPI.Buffer(buf), msg)
@assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count"
buf = MPI.Buffer(array)
req = MPI.Imrecv!(buf, msg)
__wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv")
return MPI.deserialize(buf)
break
end
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank)
yield()
end
return array
end

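# Allocate a dense array matching the announced type and shape, then receive into it.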
function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag)
T = _value.type
@assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays"
array = Array{eltype(T)}(undef, _value.shape)
return recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
end

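# The three underlying buffers of a SparseMatrixCSC arrive as three
# consecutive in-place messages, in the order the sender emits them:
# colptr, rowval, nzval.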
function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag)
T = _value.type
@assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC"

colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag)
rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag)
nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag)

return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval)
end

function recv_yield_serialized(comm, my_rank, their_rank, tag)
time_start = time_ns()
detect = DEADLOCK_DETECT[]
warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)

T = _value.type
@assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays"
array = Array{eltype(T)}(undef, _value.shape)

while true
(got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status)
if got
if MPI.Get_error(stat) != MPI.SUCCESS
error("recv_yield failed with error $(MPI.Get_error(stat))")
end
count = MPI.Get_count(stat, UInt8)
@assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count"
buf = MPI.Buffer(array)
req = MPI.Imrecv!(buf, msg)
buf = Array{UInt8}(undef, count)
req = MPI.Imrecv!(MPI.Buffer(buf), msg)
__wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv")
break
return MPI.deserialize(buf)
end
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank)
yield()
end

return array
end

const SEEN_TAGS = Dict{Int32, Type}()
@@ -584,19 +609,27 @@ function _send_yield(value, comm, dest, tag; check_seen::Bool=true, inplace::Boo
send_yield_serialized(value, comm, rank, dest, tag)
end
end

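# Serialization-free send of a raw bitstype buffer.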
function send_yield_inplace(value, comm, my_rank, their_rank, tag)
req = MPI.Isend(value, comm; dest=their_rank, tag)
__wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send")
end

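# Dense bitstype arrays and SparseMatrixCSC are announced with a small
# serialized header (InplaceInfo/InplaceSparseInfo) and then sent in place;
# all other values fall back to serialized MPI.isend.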
function send_yield_serialized(value, comm, my_rank, their_rank, tag)
if value isa Array && isbitstype(eltype(value))
send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag)
send_yield_inplace(value, comm, my_rank, their_rank, tag)
elseif value isa SparseMatrixCSC && isbitstype(eltype(value))
send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag)
send_yield!(value.colptr, comm, their_rank, tag; check_seen=false)
send_yield!(value.rowval, comm, their_rank, tag; check_seen=false)
send_yield!(value.nzval, comm, their_rank, tag; check_seen=false)
else
req = MPI.isend(value, comm; dest=their_rank, tag)
__wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send")
end
end

function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String)
time_start = time_ns()
detect = DEADLOCK_DETECT[]
@@ -623,6 +656,7 @@ function bcast_send_yield(value, comm, root, tag)
send_yield(value, comm, other_rank, tag)
end
end

function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest)
time_elapsed = (time_ns() - time_start)
if detect && time_elapsed > warn_period
56 changes: 43 additions & 13 deletions test/mpi.jl
@@ -1,33 +1,63 @@
using Dagger
using MPI
using SparseArrays

Dagger.accelerate!(:mpi)
#=


if MPI.Comm_rank(MPI.COMM_WORLD) == 0
B = rand(4, 4)
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 0)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 0)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield! Array: B: $B")
else
B = zeros(4, 4)
Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 0)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! Array: B: $B")
end

MPI.Barrier(MPI.COMM_WORLD)

if MPI.Comm_rank(MPI.COMM_WORLD) == 0
B = "hello"
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 1)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B: $B")
Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 2)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield String: B: $B")
else
B = "Goodbye"
B1, _ = Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 1)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) B1: $B1")
B1, _ = Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 2)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! String: B1: $B1")
end
=#
A = rand(Blocks(2,2), 4, 4)
Ac = collect(A)
println(Ac)

MPI.Barrier(MPI.COMM_WORLD)

if MPI.Comm_rank(MPI.COMM_WORLD) == 0
B = sprand(4,4,0.2)
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 1)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Sparse: B: $B")
else
B1 = Dagger.recv_yield(MPI.COMM_WORLD, 0, 1)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Sparse: B1: $B1")
end

#move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])
MPI.Barrier(MPI.COMM_WORLD)

if MPI.Comm_rank(MPI.COMM_WORLD) == 0
B = rand(4, 4)
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 0)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Dense: B: $B")
else

B = Dagger.recv_yield( MPI.COMM_WORLD, 0, 0)
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Dense: B: $B")
end

MPI.Barrier(MPI.COMM_WORLD)



#=
A = rand(Blocks(2,2), 4, 4)
Ac = collect(A)
println(Ac)
=#

#move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])