diff --git a/Project.toml b/Project.toml index 8735298dc..8423dc9ee 100644 --- a/Project.toml +++ b/Project.toml @@ -4,11 +4,14 @@ version = "0.18.14" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DistributedNext = "fab6aee4-877b-4bac-a744-3eca44acbb6f" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -54,6 +57,7 @@ Distributions = "0.25" GraphViz = "0.2" Graphs = "1" JSON3 = "1" +MPI = "0.20.22" MacroTools = "0.5" MemPool = "0.4.11" OnlineStats = "1" diff --git a/src/Dagger.jl b/src/Dagger.jl index fd6395a4b..5719a158a 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -10,9 +10,9 @@ import MemPool: DRef, FileRef, poolget, poolset import Base: collect, reduce import LinearAlgebra -import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric import Random import Random: AbstractRNG +import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric, chkstride1 import UUIDs: UUID, uuid4 @@ -64,8 +64,8 @@ include("utils/scopes.jl") include("queue.jl") include("thunk.jl") include("submission.jl") -include("chunks.jl") include("memory-spaces.jl") +include("chunks.jl") # Task scheduling include("compute.jl") @@ -125,6 +125,8 @@ function set_distributed_package!(value) @set_preferences!("distributed-package" => value) @info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!" end +# MPI +include("mpi.jl") # Precompilation import PrecompileTools: @compile_workload diff --git a/src/array/darray.jl b/src/array/darray.jl index 11feb53cb..1696e0504 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -492,9 +492,12 @@ function auto_blocks(dims::Dims{N}) where N end auto_blocks(A::AbstractArray{T,N}) where {T,N} = auto_blocks(size(A)) -distribute(A::AbstractArray) = distribute(A, AutoBlocks()) -distribute(A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} = +distribute(A::AbstractArray{T,N}, dist::Blocks{N}, ::DistributedAcceleration) where {T,N} = _to_darray(Distribute(dist, A)) + +distribute(A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} = + distribute(A::AbstractArray{T,N}, dist, current_acceleration()) +distribute(A::AbstractArray) = distribute(A, AutoBlocks()) distribute(A::AbstractArray, ::AutoBlocks) = distribute(A, auto_blocks(A)) function distribute(x::AbstractArray{T,N}, n::NTuple{N}) where {T,N} p = map((d, dn)->ceil(Int, d / dn), size(x), n) diff --git a/src/chunks.jl b/src/chunks.jl index 1eb56714e..a2a5fb6f6 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -26,33 +26,6 @@ domain(x::Any) = UnitDomain() ###### Chunk ###### -""" - Chunk - -A reference to a piece of data located on a remote worker. `Chunk`s are -typically created with `Dagger.tochunk(data)`, and the data can then be -accessed from any worker with `collect(::Chunk)`. 
`Chunk`s are -serialization-safe, and use distributed refcounting (provided by -`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, -as long as a reference exists on some worker. - -Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a -sense) the processor that "owns" or contains the data. Calling -`collect(::Chunk)` will perform data movement and conversions defined by that -processor to safely serialize the data to the calling worker. - -## Constructors -See [`tochunk`](@ref). -""" -mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} - chunktype::Type{T} - domain - handle::H - processor::P - scope::S - persist::Bool -end - domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype persist!(t::Chunk) = (t.persist=true; t) @@ -77,7 +50,7 @@ function collect(ctx::Context, chunk::Chunk; options=nothing) elseif chunk.handle isa FileRef return poolget(chunk.handle) else - return move(chunk.processor, OSProc(), chunk.handle) + return move(chunk.processor, default_processor(), chunk.handle) end end collect(ctx::Context, ref::DRef; options=nothing) = @@ -262,9 +235,18 @@ will be inspected to determine if it's safe to serialize; if so, the default MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will be used. +`type` can be specified manually to force the type to be `Chunk{type}`. + All other kwargs are passed directly to `MemPool.poolset`. """ -function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S} + +tochunk(x::X, proc::P, space::M; kwargs...) where {X,P<:Processor,M<:MemorySpace} = + tochunk(x, proc, space, AnyScope(); kwargs...) +function tochunk(x::X, proc::P, space::M, scope::S; persist=false, cache=false, device=nothing, type=X, kwargs...) where {X,P<:Processor,S,M<:MemorySpace} + if x isa Chunk + check_proc_space(x, proc, space) + return x + end if device === nothing device = if Sch.walk_storage_safe(x) MemPool.GLOBAL_DEVICE[] @@ -272,10 +254,56 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac MemPool.CPURAMDevice() end end - ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist) + ref = tochunk_pset(x, space; device, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space, persist) end -tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x + +function tochunk(x::X, proc::P, scope::S; persist=false, cache=false, device=nothing, type=X, kwargs...) where {X,P<:Processor,S} + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + check_proc_space(x, proc, x.space) + return x + end + space = default_memory_space(current_acceleration(), x) + ref = tochunk_pset(x, space; device, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space, persist) +end +function tochunk(x::X, space::M, scope::S; persist=false, cache=false, device=nothing, type=X, kwargs...) where {X,M<:MemorySpace,S} + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + check_proc_space(x, x.processor, space) + return x + end + proc = default_processor(current_acceleration(), x) + ref = tochunk_pset(x, space; device, kwargs...) 
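+    # The returned Chunk records both the derived processor and the requested
+    # memory space, so later placement decisions need not dereference `ref`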
+ return Chunk{type,typeof(ref),typeof(proc),S,M}(type, domain(x), ref, proc, scope, space, persist) +end +tochunk(x, procOrSpace; kwargs...) = tochunk(x, procOrSpace, AnyScope(); kwargs...) +tochunk(x; kwargs...) = tochunk(x, default_memory_space(current_acceleration(), x), AnyScope(); kwargs...) + +check_proc_space(x, proc, space) = nothing +function check_proc_space(x::Chunk, proc, space) + if x.space !== space + throw(ArgumentError("Memory space mismatch: Chunk=$(x.space) != Requested=$space")) + end +end +function check_proc_space(x::Thunk, proc, space) + # FIXME: Validate +end + +tochunk_pset(x, space::MemorySpace; device=nothing, kwargs...) = poolset(x; device, kwargs...) function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io @@ -292,10 +320,12 @@ struct WeakChunk wid::Int id::Int x::WeakRef - function WeakChunk(c::Chunk) - return new(c.handle.owner, c.handle.id, WeakRef(c)) - end end + +function WeakChunk(c::Chunk) + return WeakChunk(c.handle.owner, c.handle.id, WeakRef(c)) +end + unwrap_weak(c::WeakChunk) = c.x.value function unwrap_weak_checked(c::WeakChunk) cw = unwrap_weak(c) diff --git a/src/datadeps.jl b/src/datadeps.jl index 43c4c3848..33c85affa 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -1,4 +1,5 @@ import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv +using MPI export In, Out, InOut, Deps, spawn_datadeps @@ -21,6 +22,10 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) +struct MPIAcceleration <: Acceleration + comm::MPI.Comm +end + struct DataDepsTaskQueue <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue @@ -162,9 +167,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState end end -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) +function aliasing(astate::DataDepsAliasingState, accel::Acceleration, arg, dep_mod) return get!(astate.ainfo_cache, (arg, dep_mod)) do - return aliasing(arg, dep_mod) + return aliasing(accel, arg, dep_mod) end end @@ -202,7 +207,7 @@ function has_writedep(state::DataDepsState, arg, deps, task::DTask) for (readdep, writedep, other_ainfo, _, _) in other_taskdeps writedep || continue for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) + ainfo = aliasing(state.alias_state, current_acceleration(), arg, dep_mod) if will_alias(ainfo, other_ainfo) return true end @@ -251,7 +256,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) # Add all aliasing dependencies for (dep_mod, readdep, writedep) in deps if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) + ainfo = aliasing(state.alias_state, current_acceleration(), arg, dep_mod) else ainfo = UnknownAliasing() end @@ -272,8 +277,7 @@ end function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) astate = state.alias_state for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - + ainfo = aliasing(astate, current_acceleration(), arg, dep_mod) # Initialize owner and readers if !haskey(astate.ainfos_owner, ainfo) overlaps = Set{AbstractAliasing}() @@ -394,9 +398,48 @@ function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, push!(state.alias_state.args_readers[arg], task=>write_num) end +function remotecall_endpoint(::Dagger.DistributedAcceleration, w, from_proc, to_proc, orig_space, dest_space, data, task) + return remotecall_fetch(w.pid, from_proc, to_proc, data) do from_proc, to_proc, data + data_converted = 
move(from_proc, to_proc, data) + data_chunk = tochunk(data_converted, to_proc, dest_space) + @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" + return data_chunk + end +end + +const MPI_UID = ScopedValue{Int64}(0) + +function remotecall_endpoint(accel::Dagger.MPIAcceleration, w, from_proc, to_proc, orig_space, dest_space, data, task) + loc_rank = MPI.Comm_rank(accel.comm) + with(MPI_UID=>task.uid) do + if data isa Chunk + tag = abs(Base.unsafe_trunc(Int32, hash(data.handle.id))) + if loc_rank == from_proc.rank == to_proc.rank + data_converted = move(to_proc, data) + data_chunk = tochunk(data_converted, to_proc, dest_space) + elseif loc_rank == to_proc.rank + data_moved = Dagger.recv_yield(accel.comm, orig_space.rank, tag) + data_converted = move(to_proc, data_moved) + data_chunk = tochunk(data_converted, to_proc, dest_space) + elseif loc_rank == from_proc.rank + data_moved = move(from_proc, data) + Dagger.send_yield(data_moved, accel.comm, to_proc.rank, tag) + data_chunk = tochunk(data_moved, to_proc, dest_space) + else + T = move_type(from_proc, to_proc, chunktype(data)) + data_chunk = tochunk(nothing, to_proc, dest_space; type=T) + end + else + data_converted = move(from_proc, data) + data_chunk = tochunk(data_converted, to_proc, dest_space) + end + return data_chunk + end +end + # Make a copy of each piece of data on each worker # memory_space => {arg => copy_of_arg} -function generate_slot!(state::DataDepsState, dest_space, data) +function generate_slot!(state::DataDepsState, dest_space, data, task) if data isa DTask data = fetch(data; raw=true) end @@ -404,26 +447,22 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + w = only(unique(map(get_parent, collect(processors(dest_space))))) if orig_space == dest_space - data_chunk = tochunk(data, from_proc) + data_chunk = tochunk(data, from_proc, dest_space) dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc @assert memory_space(data_chunk) == orig_space + @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && (processor(data) isa Dagger.OSProc || processor(data) isa Dagger.MPIOSProc) else - w = only(unique(map(get_parent, collect(processors(dest_space))))).pid ctx = Sch.eager_context() id = rand(Int) timespan_start(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data)) - dest_space_args[data] = remotecall_fetch(w, from_proc, to_proc, data) do from_proc, to_proc, data - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - return data_chunk - end + dest_space_args[data] = remotecall_endpoint(current_acceleration(), w, from_proc, to_proc, orig_space, dest_space, data, task) timespan_finish(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data=dest_space_args[data])) end + check_uniform(memory_space(dest_space_args[data])) + check_uniform(processor(dest_space_args[data])) + check_uniform(dest_space_args[data].handle) return dest_space_args[data] end @@ -457,9 +496,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Get the set of all processors to be scheduled on all_procs = Processor[] scope = get_options(:scope, DefaultScope()) - for w in procs() - append!(all_procs, get_processors(OSProc(w))) + accel = current_acceleration() + accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc + Dagger.accel_matches_proc(accel, proc) end + all_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) + # FIXME: This is an unreliable way to ensure processor uniformity + sort!(all_procs, by=short_name) filter!(proc->!isa(constrain(ExactScope(proc), scope), InvalidScope), all_procs) @@ -467,8 +510,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue) throw(Sch.SchedulingException("No processors available, try widening scope")) end exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + #=if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end=# + for proc in all_procs + check_uniform(proc) end # Round-robin assign tasks to processors @@ -665,8 +711,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue) our_space = only(memory_spaces(our_proc)) our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) our_scope = UnionScope(map(ExactScope, our_procs)...) + check_uniform(our_proc) + check_uniform(our_space) - spec.f = move(ThreadProc(myid(), 1), our_proc, spec.f) + # FIXME: May not be correct to move this under uniformity + spec.f = move(default_processor(), our_proc, spec.f) @dagdebug nothing :spawn_datadeps "($(repr(spec.f))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis @@ -677,34 +726,34 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Is the data written previously or now? arg, deps = unwrap_inout(arg) arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)" + if !type_may_alias(typeof(arg)) + @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)" spec.args[idx] = pos => arg continue end # Is the source of truth elsewhere? 
arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) + generate_slot!(state, our_space, arg, task) end if queue.aliasing for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) + ainfo = aliasing(astate, current_acceleration(), arg, dep_mod) data_space = astate.data_locality[ainfo] nonlocal = our_space != data_space if nonlocal # Add copy-to operation (depends on latest owner of arg) @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) + generate_slot!(state, data_space, arg, task) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) + #@dagdebug nothing :mpi "[$(MPI.Comm_rank(current_acceleration().comm))] Scheduled move from $(arg_local.handle.id) into $(arg_remote.handle.id)\n" + copy_to = Dagger.@spawn scope=copy_to_scope occupancy=Dict(Any=>0) syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) add_writer!(state, ainfo, copy_to, write_num) - astate.data_locality[ainfo] = our_space else @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Skipped copy-to (local): $data_space" @@ -717,13 +766,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Add copy-to operation (depends on latest owner of arg) @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) + generate_slot!(state, data_space, arg, task) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) + copy_to = Dagger.@spawn scope=copy_to_scope occupancy=Dict(Any=>0) syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) add_writer!(state, arg, copy_to, write_num) astate.data_locality[arg] = our_space @@ -752,7 +801,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) type_may_alias(typeof(arg)) || continue if queue.aliasing for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) + ainfo = aliasing(astate, current_acceleration(), arg, dep_mod) if writedep @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as writer" get_write_deps!(state, ainfo, task, write_num, syncdeps) @@ -775,7 +824,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Launch user's task task_scope = our_scope - spec.options = merge(spec.options, (;syncdeps, scope=task_scope)) + spec.options = merge(spec.options, (;syncdeps, scope=task_scope, occupancy=Dict(Any=>0))) enqueue!(upper_queue, spec=>task) # Update read/write tracking for arguments @@ -785,7 +834,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) type_may_alias(typeof(arg)) || 
continue if queue.aliasing for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) + ainfo = aliasing(astate, current_acceleration(), arg, dep_mod) if writedep @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Set as owner" add_writer!(state, ainfo, task, write_num) @@ -817,7 +866,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # in the correct order # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}}() + arg_writes = IdDict{Any,Vector{Tuple{AbstractAliasing,<:Any,MemorySpace,DTask}}}() for (task, taskdeps) in state.dependencies for (_, writedep, ainfo, dep_mod, arg) in taskdeps writedep || continue @@ -831,7 +880,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg) + ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace,DTask}}, arg_writes, arg) #= FIXME: If we fully overlap any writer, evict them idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) @@ -839,7 +888,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) + push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo], task)) end end @@ -851,14 +900,14 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # FIXME: Remove me deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes + for (ainfo, dep_mod, data_remote_space, task) in ainfo_writes # Is the source of truth elsewhere? data_local_space = astate.data_origin[ainfo] if data_local_space != data_remote_space # Add copy-from operation @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) + generate_slot!(state, data_local_space, arg, task) end arg_remote = state.remote_args[data_remote_space][arg] @assert arg_remote !== arg_local @@ -867,7 +916,8 @@ function distribute_tasks!(queue::DataDepsTaskQueue) copy_from_syncdeps = Set() get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) + #@dagdebug nothing :mpi "[$(MPI.Comm_rank(current_acceleration().comm))] Scheduled move from $(arg_remote.handle.id) into $(arg_local.handle.id)\n" + copy_from = Dagger.@spawn scope=copy_from_scope occupancy=Dict(Any=>0) syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) else @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" end @@ -895,7 +945,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) copy_from_syncdeps = Set() get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) + copy_from = Dagger.@spawn scope=copy_from_scope 
occupancy=Dict(Any=>0) syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) else @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index b0aa248ce..d027bd62e 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,25 +1,89 @@ +abstract type Acceleration end + +struct DistributedAcceleration <: Acceleration end + +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + abstract type MemorySpace end +""" + Chunk + +A reference to a piece of data located on a remote worker. `Chunk`s are +typically created with `Dagger.tochunk(data)`, and the data can then be +accessed from any worker with `collect(::Chunk)`. `Chunk`s are +serialization-safe, and use distributed refcounting (provided by +`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, +as long as a reference exists on some worker. + +Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a +sense) the processor that "owns" or contains the data. Calling +`collect(::Chunk)` will perform data movement and conversions defined by that +processor to safely serialize the data to the calling worker. + +## Constructors +See [`tochunk`](@ref). 
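+
+Each `Chunk` additionally records the `MemorySpace` that holds its data (the
+new `space` field); acceleration backends such as MPI use it to locate and
+move the underlying data.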
+""" + +mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope, M<:MemorySpace} + chunktype::Type{T} + domain + handle::H + processor::P + scope::S + space::M + persist::Bool +end + struct CPURAMMemorySpace <: MemorySpace owner::Int end root_worker_id(space::CPURAMMemorySpace) = space.owner -memory_space(x) = CPURAMMemorySpace(myid()) -function memory_space(x::Chunk) - proc = processor(x) - if proc isa OSProc - # TODO: This should probably be programmable - return CPURAMMemorySpace(proc.pid) - else - return only(memory_spaces(proc)) - end -end -memory_space(x::DTask) = - memory_space(fetch(x; raw=true)) +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) + +memory_space(x) = first(memory_spaces(default_processor())) +memory_space(x::Processor) = first(memory_spaces(x)) +memory_space(x::Chunk) = x.space +memory_space(x::DTask) = memory_space(fetch(x; raw=true)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) + +function memory_spaces(proc::OSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end memory_spaces(proc::ThreadProc) = Set([CPURAMMemorySpace(proc.owner)]) processors(::S) where {S<:MemorySpace} = @@ -70,6 +134,16 @@ function move!(::Type{<:Tridiagonal}, to_space::MemorySpace, from_space::MemoryS return end +# FIXME: Take MemorySpace instead +function move_type(from_proc::Processor, to_proc::Processor, ::Type{T}) where T + if from_proc == to_proc + return T + end + return Base._return_type(move, Tuple{typeof(from_proc), typeof(to_proc), T}) +end +move_type(from_proc::Processor, to_proc::Processor, ::Type{<:Chunk{T}}) where T = + move_type(from_proc, to_proc, T) + ### Aliasing and Memory Spans type_may_alias(::Type{String}) = false @@ -200,6 +274,7 @@ aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x end aliasing(x::DTask, T) = aliasing(fetch(x; raw=true), T) aliasing(x::DTask) = aliasing(fetch(x; raw=true)) +aliasing(accel::DistributedAcceleration, x::Chunk, T) = aliasing(x, T) struct ContiguousAliasing{S} <: AbstractAliasing span::MemorySpan{S} diff --git a/src/mpi.jl b/src/mpi.jl new file mode 100644 index 000000000..15b679dea --- /dev/null +++ b/src/mpi.jl @@ -0,0 +1,635 @@ +using MPI + +const CHECK_UNIFORMITY = TaskLocalValue{Bool}(()->false) +function check_uniformity!(check::Bool=true) + CHECK_UNIFORMITY[] = check +end +function check_uniform(value::Integer) + CHECK_UNIFORMITY[] || return + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + Core.print("[$rank] Starting check_uniform...\n") + all_min = MPI.Allreduce(value, MPI.Op(min, typeof(value)), comm) + all_max = MPI.Allreduce(value, MPI.Op(max, typeof(value)), comm) + Core.print("[$rank] Fetched min ($all_min)/max ($all_max) for check_uniform\n") + if all_min != all_max + if rank == 0 + Core.print("Found non-uniform value!\n") + end + Core.print("[$rank] value=$value\n") + exit(1) + end + flush(stdout) + MPI.Barrier(comm) +end + +MPIAcceleration() = MPIAcceleration(MPI.COMM_WORLD) + +#default_processor(accel::MPIAcceleration) = 
MPIOSProc(accel.comm) + +function aliasing(accel::MPIAcceleration, x::Chunk, T) + @assert x.handle isa MPIRef "MPIRef expected" + #print("[$(MPI.Comm_rank(x.handle.comm))] Hit probable hang on aliasing \n") + if x.handle.rank == MPI.Comm_rank(accel.comm) + ainfo = aliasing(x, T) + MPI.bcast(ainfo, x.handle.rank, x.handle.comm) + else + ainfo = MPI.bcast(nothing, x.handle.rank, x.handle.comm) + end + #print("[$(MPI.Comm_rank(x.handle.comm))] Left aliasing hang \n") + return ainfo +end +default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) +default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) + +#TODO: Add a lock +const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() + +struct MPIClusterProc <: Processor + comm::MPI.Comm + function MPIClusterProc(comm::MPI.Comm) + populate_children(comm) + return new(comm) + end +end + +Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIOSProc(proc.comm), log_sink) + +MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) + +function populate_children(comm::MPI.Comm) + children = get_processors(OSProc()) + MPIClusterProcChildren[comm] = children +end + +struct MPIOSProc <: Processor + comm::MPI.Comm + rank::Int +end + +function MPIOSProc(comm::MPI.Comm) + rank = MPI.Comm_rank(comm) + return MPIOSProc(comm, rank) +end + +function MPIOSProc() + return MPIOSProc(MPI.COMM_WORLD) +end +#Sch.init_proc(state, proc::MPIOSProc, log_sink) = Sch.init_proc(state, OSProc(), log_sink) + +function check_uniform(proc::MPIOSProc) + check_uniform(hash(MPIOSProc)) + check_uniform(proc.rank) +end + +function memory_spaces(proc::MPIOSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end + +struct MPIProcessScope <: AbstractScope + comm::MPI.Comm + rank::Int +end + +Base.isless(::MPIProcessScope, ::MPIProcessScope) = false +Base.isless(::MPIProcessScope, ::NodeScope) = true +Base.isless(::MPIProcessScope, ::UnionScope) = true +Base.isless(::MPIProcessScope, ::TaintScope) = true +Base.isless(::MPIProcessScope, ::AnyScope) = true +constrain(x::MPIProcessScope, y::MPIProcessScope) = + x == y ? y : InvalidScope(x, y) +constrain(x::NodeScope, y::MPIProcessScope) = + x == y.parent ? y : InvalidScope(x, y) + +Base.isless(::ExactScope, ::MPIProcessScope) = true +constrain(x::MPIProcessScope, y::ExactScope) = + x == y.parent ? 
y : InvalidScope(x, y) + +function enclosing_scope(proc::MPIOSProc) + return MPIProcessScope(proc.comm, proc.rank) +end + +struct MPIProcessor{P<:Processor} <: Processor + innerProc::P + comm::MPI.Comm + rank::Int +end + +function check_uniform(proc::MPIProcessor) + check_uniform(hash(MPIProcessor)) + check_uniform(proc.rank) + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(proc.innerProc)) +end + +Dagger.iscompatible_func(::MPIProcessor, opts, ::Any) = true +Dagger.iscompatible_arg(::MPIProcessor, opts, ::Any) = true + +default_enabled(proc::MPIProcessor) = default_enabled(proc.innerProc) + +root_worker_id(proc::MPIProcessor) = myid() +root_worker_id(proc::MPIOSProc) = myid() +root_worker_id(proc::MPIClusterProc) = myid() + +get_parent(proc::MPIClusterProc) = proc +get_parent(proc::MPIOSProc) = MPIClusterProc(proc.comm) +get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) + +short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" + +function get_processors(mosProc::MPIOSProc) + populate_children(mosProc.comm) + children = MPIClusterProcChildren[mosProc.comm] + mpiProcs = Set{Processor}() + for proc in children + push!(mpiProcs, MPIProcessor(proc, mosProc.comm, mosProc.rank)) + end + return mpiProcs +end + +#TODO: non-uniform ranking through MPI groups +#TODO: use a lazy iterator +function get_processors(proc::MPIClusterProc) + children = Set{Processor}() + for i in 0:(MPI.Comm_size(proc.comm)-1) + for innerProc in MPIClusterProcChildren[proc.comm] + push!(children, MPIProcessor(innerProc, proc.comm, i)) + end + end + return children +end + +struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace + innerSpace::S + comm::MPI.Comm + rank::Int +end + +function check_uniform(space::MPIMemorySpace) + check_uniform(space.rank) + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(space.innerSpace)) +end + +default_processor(space::MPIMemorySpace) = MPIOSProc(space.comm, space.rank) +default_memory_space(accel::MPIAcceleration) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) + +default_memory_space(accel::MPIAcceleration, x) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) +default_memory_space(accel::MPIAcceleration, x::Chunk) = MPIMemorySpace(CPURAMMemorySpace(myid()), x.handle.comm, x.handle.rank) +default_memory_space(accel::MPIAcceleration, x::Function) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) +default_memory_space(accel::MPIAcceleration, T::Type) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) + +function memory_spaces(proc::MPIClusterProc) + rawMemSpace = Set{MemorySpace}() + for rnk in 0:(MPI.Comm_size(proc.comm) - 1) + for innerSpace in memory_spaces(OSProc()) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, rnk)) + end + end + return rawMemSpace +end + +function memory_spaces(proc::MPIProcessor) + rawMemSpace = Set{MemorySpace}() + for innerSpace in memory_spaces(proc.innerProc) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, proc.rank)) + end + return rawMemSpace +end + +root_worker_id(mem_space::MPIMemorySpace) = myid() + +function processors(memSpace::MPIMemorySpace) + rawProc = Set{Processor}() + for innerProc in processors(memSpace.innerSpace) + push!(rawProc, MPIProcessor(innerProc, memSpace.comm, memSpace.rank)) + end + return rawProc +end + +struct MPIRefID + tid::Int + uid::UInt + id::Int +end + +function check_uniform(ref::MPIRefID) + 
check_uniform(ref.tid) + check_uniform(ref.uid) + check_uniform(ref.id) +end + +const MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +const MPIREF_UID = Dict{Int, Threads.Atomic{Int}}() + +mutable struct MPIRef + comm::MPI.Comm + rank::Int + size::Int + innerRef::Union{DRef, Nothing} + id::MPIRefID +end + +function check_uniform(ref::MPIRef) + check_uniform(ref.rank) + check_uniform(ref.id) +end + +move(from_proc::Processor, to_proc::Processor, x::MPIRef) = move(from_proc, to_proc, poolget(x.innerRef)) + +function affinity(x::MPIRef) + if x.innerRef === nothing + return MPIOSProc(x.comm, x.rank)=>0 + else + return MPIOSProc(x.comm, x.rank)=>x.innerRef.size + end +end + +peek_ref_id() = get_ref_id(false) +take_ref_id!() = get_ref_id(true) +function get_ref_id(take::Bool) + tid = 0 + uid = 0 + id = 0 + if Dagger.in_task() + tid = sch_handle().thunk_id.id + uid = 0 + counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) + id = if take + Threads.atomic_add!(counter, 1) + else + counter[] + end + end + if MPI_UID[] != 0 + tid = 0 + uid = MPI_UID[] + counter = get!(MPIREF_UID, uid, Threads.Atomic{Int}(1)) + id = if take + Threads.atomic_add!(counter, 1) + else + counter[] + end + end + return MPIRefID(tid, uid, id) +end + +#TODO: partitioned scheduling with comm bifurcation +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, kwargs...) + @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" + local_rank = MPI.Comm_rank(space.comm) + Mid = take_ref_id!() + if local_rank != space.rank + return MPIRef(space.comm, space.rank, 0, nothing, Mid) + else + return MPIRef(space.comm, space.rank, sizeof(x), poolset(x; device, kwargs...), Mid) + end +end + +function recv_yield(comm, src, tag) + while true + (got, stat) = MPI.Iprobe(comm, MPI.Status; source=src, tag=tag) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield (Iprobe) failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = MPI.Irecv!(MPI.Buffer(buf), comm; source=src, tag=tag) + while true + finish, stat = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield (Test) failed with error $(MPI.Get_error(stat))") + end + value = MPI.deserialize(buf) + rnk = MPI.Comm_rank(comm) + return value + end + yield() + end + end + yield() + end +end +function send_yield(value, comm, dest, tag) + req = MPI.isend(value, comm; dest, tag) + while true + finish, status = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(status) != MPI.SUCCESS + error("send_yield (Test) failed with error $(MPI.Get_error(status))") + end + return + end + yield() + end +end +function bcast_send_yield(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + for other_rank in 0:(sz-1) + rank == other_rank && continue + send_yield(value, comm, other_rank, tag) + end +end + +#discuss this with julian +WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) + +function poolget(ref::MPIRef) + @assert ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch" + poolget(ref.innerRef) +end + +function move!(dep_mod, dst::MPIMemorySpace, src::MPIMemorySpace, dstarg::Chunk, srcarg::Chunk) + @assert dstarg.handle isa MPIRef && srcarg.handle isa MPIRef "MPIRef expected" + @assert dstarg.handle.comm == srcarg.handle.comm "MPIRef comm mismatch" + @assert dstarg.handle.rank == dst.rank && srcarg.handle.rank == src.rank "MPIRef rank mismatch" + local_rank = 
MPI.Comm_rank(srcarg.handle.comm)
+    h = abs(Base.unsafe_trunc(Int32, hash(dep_mod, hash(srcarg.handle.id, hash(dstarg.handle.id)))))
+    @dagdebug nothing :mpi "[$local_rank][$h] Moving from $(src.rank) to $(dst.rank)\n"
+    if src.rank == dst.rank == local_rank
+        move!(dep_mod, dst.innerSpace, src.innerSpace, dstarg, srcarg)
+    else
+        if local_rank == src.rank
+            send_yield(poolget(srcarg.handle), dst.comm, dst.rank, h)
+        end
+        if local_rank == dst.rank
+            val = recv_yield(src.comm, src.rank, h)
+            move!(dep_mod, dst.innerSpace, src.innerSpace, poolget(dstarg.handle), val)
+        end
+    end
+    @dagdebug nothing :mpi "[$local_rank][$h] Finished moving from $(src.rank) to $(dst.rank) successfully\n"
+end
+
+move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x
+move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle)
+
+# TODO: out-of-place MPI move
+function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk)
+    @assert src.comm == dst.comm "Multi-comm move not supported"
+    if Sch.SCHED_MOVE[]
+        if dst.rank == MPI.Comm_rank(dst.comm)
+            return poolget(x.handle)
+        end
+        # NOTE: Non-owning ranks fall through and implicitly return `nothing`
+    else
+        @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted"
+        @assert src.rank == x.handle.rank == dst.rank
+        return poolget(x.handle)
+    end
+end
+
+# TODO: Discuss this with Julian
+
+move(src::Processor, dst::MPIProcessor, x::Chunk) = error("MPI move not supported")
+move(to_proc::MPIProcessor, chunk::Chunk) =
+    move(chunk.processor, to_proc, chunk)
+move(to_proc::Processor, d::MPIRef) =
+    move(MPIOSProc(d.rank), to_proc, d)
+move(to_proc::MPIProcessor, x) =
+    move(MPIOSProc(), to_proc, x)
+
+move(::MPIProcessor, ::MPIProcessor, x::Union{Function,Type}) = x
+move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle)
+
+function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk)
+    @assert src.rank == dst.rank "Unwrapping not permitted"
+    if Sch.SCHED_MOVE[]
+        if dst.rank == MPI.Comm_rank(dst.comm)
+            return poolget(x.handle)
+        end
+    else
+        @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted"
+        @assert src.rank == x.handle.rank == dst.rank
+        return poolget(x.handle)
+    end
+end
+
+# FIXME: Try to think of a better move! scheme
+function execute!(proc::MPIProcessor, f, args...; kwargs...)
+    local_rank = MPI.Comm_rank(proc.comm)
+    tag = abs(Base.unsafe_trunc(Int32, hash(peek_ref_id())))
+    tid = sch_handle().thunk_id.id
+    if local_rank == proc.rank || f === move!
+        result = execute!(proc.innerProc, f, args...; kwargs...)
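+        # Broadcast the result's concrete type and memory space so that
+        # non-owning ranks can build a placeholder Chunk without the data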
+        bcast_send_yield(typeof(result), proc.comm, proc.rank, tag)
+        space = memory_space(result)::MPIMemorySpace
+        bcast_send_yield(space.innerSpace, proc.comm, proc.rank, tag)
+        return tochunk(result, proc, space)
+    else
+        T = recv_yield(proc.comm, proc.rank, tag)
+        innerSpace = recv_yield(proc.comm, proc.rank, tag)
+        space = MPIMemorySpace(innerSpace, proc.comm, proc.rank)
+        #= FIXME: If we get a bad result (something non-concrete, or Union{}),
+        # we should bcast the actual type
+        @warn "FIXME: Kwargs" maxlog=1
+        T = Base._return_type(f, Tuple{typeof.(args)...})
+        return tochunk(nothing, proc, memory_space(proc); type=T)
+        =#
+        return tochunk(nothing, proc, space; type=T)
+    end
+end
+
+accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration())
+
+function initialize_acceleration!(a::MPIAcceleration)
+    if !MPI.Initialized()
+        MPI.Init(; threadlevel=:multiple)
+    end
+    ctx = Dagger.Sch.eager_context()
+    sz = MPI.Comm_size(a.comm)
+    for i in 0:(sz-1)
+        push!(ctx.procs, MPIOSProc(a.comm, i))
+    end
+    unique!(ctx.procs)
+end
+
+accel_matches_proc(accel::MPIAcceleration, proc::MPIOSProc) = true
+accel_matches_proc(accel::MPIAcceleration, proc::MPIClusterProc) = true
+accel_matches_proc(accel::MPIAcceleration, proc::MPIProcessor) = true
+accel_matches_proc(accel::MPIAcceleration, proc) = false
+
+distribute(A::AbstractArray{T,N}, dist::Blocks{N}, root::Int; comm::MPI.Comm=MPI.COMM_WORLD) where {T,N} =
+    distribute(A, dist; comm, root)
+distribute(A::AbstractArray, root::Int; comm::MPI.Comm=MPI.COMM_WORLD) = distribute(A, AutoBlocks(), root; comm)
+distribute(A::AbstractArray, ::AutoBlocks, root::Int; comm::MPI.Comm=MPI.COMM_WORLD) = distribute(A, auto_blocks(A), root; comm)
+function distribute(x::AbstractArray{T,N}, n::NTuple{N}, root::Int; comm::MPI.Comm=MPI.COMM_WORLD) where {T,N}
+    p = map((d, dn)->ceil(Int, d / dn), size(x), n)
+    distribute(x, Blocks(p), root; comm)
+end
+distribute(x::AbstractVector, n::Int, root::Int; comm::MPI.Comm=MPI.COMM_WORLD) = distribute(x, (n,), root; comm)
+distribute(x::AbstractVector, n::Vector{<:Integer}, root::Int; comm::MPI.Comm) =
+    distribute(x, DomainBlocks((1,), (cumsum(n),)); comm, root)
+
+distribute(A::AbstractArray{T,N}, dist::Blocks{N}, comm::MPI.Comm; root::Int=0) where {T,N} =
+    distribute(A, dist; comm, root)
+distribute(A::AbstractArray, comm::MPI.Comm; root::Int=0) = distribute(A, AutoBlocks(), comm; root)
+distribute(A::AbstractArray, ::AutoBlocks, comm::MPI.Comm; root::Int=0) = distribute(A, auto_blocks(A), comm; root)
+function distribute(x::AbstractArray{T,N}, n::NTuple{N}, comm::MPI.Comm; root::Int=0) where {T,N}
+    p = map((d, dn)->ceil(Int, d / dn), size(x), n)
+    distribute(x, Blocks(p), comm; root)
+end
+distribute(x::AbstractVector, n::Int, comm::MPI.Comm; root::Int=0) = distribute(x, (n,), comm; root)
+distribute(x::AbstractVector, n::Vector{<:Integer}, comm::MPI.Comm; root::Int=0) =
+    distribute(x, DomainBlocks((1,), (cumsum(n),)), comm; root)
+
+function distribute(x::AbstractArray{T,N}, dist::Blocks{N}, ::MPIAcceleration) where {T,N}
+    return distribute(x, dist; comm=MPI.COMM_WORLD, root=0)
+end
+
+distribute(A::Nothing, dist::Blocks{N}) where N = distribute(nothing, dist; comm=MPI.COMM_WORLD, root=0)
+
+function distribute(A::Union{AbstractArray{T,N}, Nothing}, dist::Blocks{N}; comm::MPI.Comm, root::Int) where {T,N}
+    rnk = MPI.Comm_rank(comm)
+    isroot = rnk == root
+    csz = MPI.Comm_size(comm)
+    d = MPI.bcast(domain(A), comm; root)
+    sd = partition(dist, d)
+    type = MPI.bcast(eltype(A), comm; root)
+    # TODO: Make better load balancing
+    cs = Array{Any}(undef, size(sd))
+    if prod(size(sd)) < csz
+        @warn "Number of chunks is less than number of ranks, performance may be suboptimal"
+    end
+    AT = MPI.bcast(typeof(A), comm; root)
+    if isroot
+        dst = 0
+        for (idx, part) in enumerate(sd)
+            if dst != root
+                h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0))))
+                send_yield(A[part], comm, dst, h)
+                data = nothing
+            else
+                data = A[part]
+            end
+            with(MPI_UID=>Dagger.eager_next_id()) do
+                p = MPIOSProc(comm, dst)
+                s = first(memory_spaces(p))
+                cs[idx] = tochunk(data, p, s; type=AT)
+                dst += 1
+                if dst == csz
+                    dst = 0
+                end
+            end
+        end
+        Core.print("[$rnk] Sent all chunks\n")
+    else
+        dst = 0
+        for (idx, part) in enumerate(sd)
+            data = nothing
+            if rnk == dst
+                h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0))))
+                data = recv_yield(comm, root, h)
+            end
+            with(MPI_UID=>Dagger.eager_next_id()) do
+                p = MPIOSProc(comm, dst)
+                s = first(memory_spaces(p))
+                cs[idx] = tochunk(data, p, s; type=AT)
+                dst += 1
+                if dst == csz
+                    dst = 0
+                end
+            end
+            #MPI.Scatterv!(nothing, data, comm; root=root)
+        end
+    end
+    MPI.Barrier(comm)
+    return Dagger.DArray(type, d, sd, cs, dist)
+end
+
+function Base.collect(x::Dagger.DMatrix{T};
+                      comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T}
+    csz = MPI.Comm_size(comm)
+    rank = MPI.Comm_rank(comm)
+    sd = x.subdomains
+    if !acrossranks
+        if isempty(x.chunks)
+            return Matrix{T}(undef, size(x)...)
+        end
+        localarr = []
+        localparts = []
+        curpart = rank + 1
+        while curpart <= length(x.chunks)
+            print("[$rank] Collecting chunk $curpart\n")
+            push!(localarr, fetch(x.chunks[curpart]))
+            push!(localparts, sd[curpart])
+            curpart += csz
+        end
+        return localarr, localparts
+    else
+        reqs = Vector{MPI.Request}()
+        dst = 0
+        if root === nothing
+            data = Matrix{T}(undef, size(x))
+            localarr, localparts = collect(x; acrossranks=false)
+            for (idx, part) in enumerate(localparts)
+                for i in 0:(csz - 1)
+                    if i != rank
+                        h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0))))
+                        print("[$rank] Sent chunk $idx to rank $i with tag $h \n")
+                        push!(reqs, MPI.isend(localarr[idx], comm; dest = i, tag = h))
+                    else
+                        data[part.indexes...] = localarr[idx]
+                    end
+                end
+            end
+            for (idx, part) in enumerate(sd)
+                h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0))))
+                if dst != rank
+                    print("[$rank] Waiting for chunk $idx from rank $dst with tag $h\n")
+                    data[part.indexes...] = recv_yield(comm, dst, h)
+                end
+                dst += 1
+                if dst == MPI.Comm_size(comm)
+                    dst = 0
+                end
+            end
+            MPI.Waitall(reqs)
+            return data
+        else
+            if rank == root
+                data = Matrix{T}(undef, size(x))
+                for (idx, part) in enumerate(sd)
+                    h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0))))
+                    if dst == rank
+                        localdata = fetch(x.chunks[idx])
+                        data[part.indexes...] = localdata
+                    else
+                        data[part.indexes...] 
= recv_yield(comm, dst, h) + end + dst += 1 + if dst == MPI.Comm_size(comm) + dst = 0 + end + end + return fetch.(data) + else + for (idx, part) in enumerate(sd) + h = abs(Base.unsafe_trunc(Int32, hash(part, UInt(0)))) + if rank == dst + localdata = fetch(x.chunks[idx]) + push!(reqs, MPI.isend(localdata, comm; dest = root, tag = h)) + end + dst += 1 + if dst == MPI.Comm_size(comm) + dst = 0 + end + end + MPI.Waitall(reqs) + return nothing + end + end + end +end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index b894f4526..c943996ed 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -15,11 +15,10 @@ import Base: @invokelatest import ..Dagger import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, LockedObject -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, root_worker_id import ..Dagger: @dagdebug, @safe_lock_spin1 import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek - -import ..Dagger +import ScopedValues: ScopedValue, with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -43,7 +42,7 @@ function Base.show(io::IO, entry::ProcessorCacheEntry) entries += 1 next = next.next end - print(io, "ProcessorCacheEntry(pid $(entry.gproc.pid), $(entry.proc), $entries entries)") + print(io, "ProcessorCacheEntry(pid $(root_worker_id(entry.gproc)), $(entry.proc), $entries entries)") end const Signature = Vector{Any} @@ -91,11 +90,11 @@ struct ComputeState running_on::Dict{Thunk,OSProc} thunk_dict::Dict{Int, WeakThunk} node_order::Any - worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} - worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_loadavg::Dict{Int,NTuple{3,Float64}} - worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} + worker_time_pressure::Dict{Processor,Dict{Processor,UInt64}} + worker_storage_pressure::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_storage_capacity::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_loadavg::Dict{Processor,NTuple{3,Float64}} + worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} @@ -251,6 +250,7 @@ Base.@kwdef struct ThunkOptions storage_root_tag = nothing storage_leaf_tag::Union{MemPool.Tag,Nothing} = nothing storage_retain::Bool = false + acceleration::Union{Nothing, Dagger.Acceleration} = nothing end """ @@ -275,7 +275,8 @@ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) topts.storage, topts.storage_root_tag, topts.storage_leaf_tag, - topts.storage_retain) + topts.storage_retain, + topts.acceleration) end Base.merge(sopts::SchedulerOptions, ::Nothing) = ThunkOptions(sopts.single, @@ -312,7 +313,10 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) maybe_default(:storage_root_tag), maybe_default(:storage_leaf_tag), 
maybe_default(:storage_retain), - ) + maybe_default(:acceleration)) +end + +function cleanup(ctx) end # Eager scheduling @@ -323,14 +327,14 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + pid = Dagger.root_worker_id(p) + timespan_start(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) # Initialize pressure and capacity - gproc = OSProc(p.pid) lock(state.lock) do - state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}() + state.worker_time_pressure[p] = Dict{Processor,UInt64}() - state.worker_storage_pressure[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() - state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_pressure[p] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_capacity[p] = Dict{Union{StorageResource,Nothing},UInt64}() #= FIXME for storage in get_storage_resources(gproc) pressure, capacity = remotecall_fetch(gproc.pid, storage) do storage @@ -341,11 +345,11 @@ function init_proc(state, p, log_sink) end =# - state.worker_loadavg[p.pid] = (0.0, 0.0, 0.0) + state.worker_loadavg[p] = (0.0, 0.0, 0.0) end - if p.pid != 1 + if pid != 1 lock(WORKER_MONITOR_LOCK) do - wid = p.pid + wid = pid if !haskey(WORKER_MONITOR_TASKS, wid) t = Threads.@spawn begin try @@ -379,16 +383,16 @@ function init_proc(state, p, log_sink) end # Setup worker-to-scheduler channels - inp_chan = RemoteChannel(p.pid) - out_chan = RemoteChannel(p.pid) + inp_chan = RemoteChannel(pid) + out_chan = RemoteChannel(pid) lock(state.lock) do - state.worker_chans[p.pid] = (inp_chan, out_chan) + state.worker_chans[pid] = (inp_chan, out_chan) end # Setup dynamic listener - dynamic_listener!(ctx, state, p.pid) + dynamic_listener!(ctx, state, pid) - timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -403,7 +407,7 @@ function _cleanup_proc(uid, log_sink) end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - wid = p.pid + wid = root_worker_id(p) timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) @@ -470,7 +474,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order, chan) - master = OSProc(myid()) + master = Dagger.default_processor() timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try @@ -559,8 +563,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end pid, proc, thunk_id, (res, metadata) = chan_value @dagdebug thunk_id :take "Got finished task" - gproc = OSProc(pid) safepoint(state) + gproc = proc != nothing ? 
get_parent(proc) : OSProc(pid) lock(state.lock) do thunk_failed = false if res isa Exception @@ -587,11 +591,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end node = unwrap_weak_checked(state.thunk_dict[thunk_id]) if metadata !== nothing - state.worker_time_pressure[pid][proc] = metadata.time_pressure + state.worker_time_pressure[gproc][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - state.worker_loadavg[pid] = metadata.loadavg + state.worker_loadavg[gproc] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -692,7 +696,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) populate_processor_cache_list!(state, procs) # Schedule tasks - to_fire = Dict{Tuple{OSProc,<:Processor},Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}}() + to_fire = Dict{Tuple{<:Processor,<:Processor},Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}}() failed_scheduling = Thunk[] # Select a new task and get its options @@ -760,7 +764,11 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) if length(procs) > fallback_threshold @goto fallback end - local_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in procs]...)) + accel = something(opts.acceleration, Dagger.DistributedAcceleration()) + accel_procs = filter(procs) do proc + Dagger.accel_matches_proc(accel, proc) + end + local_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) if length(local_procs) > fallback_threshold @goto fallback end @@ -785,7 +793,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) can_use, scope = can_use_proc(state, task, gproc, proc, opts, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) + has_capacity(state, proc, gproc, opts.time_util, opts.alloc_util, opts.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? 
@@ -793,8 +801,8 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
                         Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}()
                     end
                     push!(proc_tasks, (task, scope, est_time_util, est_alloc_util, est_occupancy))
-                    state.worker_time_pressure[gproc.pid][proc] =
-                        get(state.worker_time_pressure[gproc.pid], proc, 0) +
+                    state.worker_time_pressure[gproc][proc] =
+                        get(state.worker_time_pressure[gproc], proc, 0) +
                         est_time_util
                     @dagdebug task :schedule "Scheduling to $gproc -> $proc"
                     @goto pop_task
@@ -817,7 +825,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
                 can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope)
                 if can_use
                     has_cap, est_time_util, est_alloc_util, est_occupancy =
-                        has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
+                        has_capacity(state, entry.proc, entry.gproc, opts.time_util, opts.alloc_util, opts.occupancy, sig)
                     if has_cap
                         selected_entry = entry
                     else
@@ -843,7 +851,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
                 can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope)
                 if can_use
                     has_cap, est_time_util, est_alloc_util, est_occupancy =
-                        has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
+                        has_capacity(state, entry.proc, entry.gproc, opts.time_util, opts.alloc_util, opts.occupancy, sig)
                     if has_cap
                         # Select this processor
                         selected_entry = entry
@@ -929,13 +937,13 @@ function monitor_procs_changed!(ctx, state)
 end
 
 function remove_dead_proc!(ctx, state, proc, options=ctx.options)
-    @assert options.single !== proc.pid "Single worker failed, cannot continue."
+    @assert options.single !== root_worker_id(proc) "Single worker failed, cannot continue."
     rmprocs!(ctx, [proc])
-    delete!(state.worker_time_pressure, proc.pid)
-    delete!(state.worker_storage_pressure, proc.pid)
-    delete!(state.worker_storage_capacity, proc.pid)
-    delete!(state.worker_loadavg, proc.pid)
-    delete!(state.worker_chans, proc.pid)
+    delete!(state.worker_time_pressure, proc)
+    delete!(state.worker_storage_pressure, proc)
+    delete!(state.worker_storage_capacity, proc)
+    delete!(state.worker_loadavg, proc)
+    delete!(state.worker_chans, root_worker_id(proc))
     state.procs_cache_list[] = nothing
 end
@@ -997,7 +1005,7 @@ end
 
 function evict_all_chunks!(ctx, to_evict)
     if !isempty(to_evict)
-        @sync for w in map(p->p.pid, procs_to_use(ctx))
+        @sync for w in map(p->root_worker_id(p), procs_to_use(ctx))
             Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict)
         end
     end
@@ -1021,10 +1029,11 @@ fire_task!(ctx, thunk::Thunk, p, state; scope=AnyScope(), time_util=10^9, alloc_
 fire_task!(ctx, (thunk, scope, time_util, alloc_util, occupancy)::Tuple{Thunk,<:Any}, p, state) =
     fire_tasks!(ctx, [(thunk, scope, time_util, alloc_util, occupancy)], p, state)
 function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
+    pid = root_worker_id(gproc)
     to_send = []
     for (thunk, scope, time_util, alloc_util, occupancy) in thunks
         push!(state.running, thunk)
-        state.running_on[thunk] = gproc
+        state.running_on[thunk] = OSProc(pid)
         if thunk.cache && thunk.cache_ref !== nothing
             # the result might be already cached
             data = thunk.cache_ref
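With `to_fire` keyed on `<:Processor` pairs and the pressure tables keyed by the parent processor object, per-worker bookkeeping no longer assumes `OSProc`; only the channel table stays keyed by the integer worker id. A self-contained illustration of that keying discipline (plain Julia, not the scheduler's code):

    using Dagger: Dagger, Processor, OSProc

    # Per-processor state is keyed by the owning (parent) processor object...
    time_pressure = Dict{Processor,Dict{Processor,UInt64}}()
    gproc = OSProc(1)
    time_pressure[gproc] = Dict{Processor,UInt64}()
    for proc in Dagger.get_processors(gproc)
        time_pressure[gproc][proc] = UInt64(0)
    end
    # ...while worker channels remain keyed by the integer root worker id.
    @assert Dagger.root_worker_id(gproc) == 1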
@@ -1076,9 +1085,9 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
         toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
         options = merge(ctx.options, toptions)
         propagated = get_propagated_options(thunk)
-        @assert (options.single === nothing) || (gproc.pid == options.single)
+        @assert (options.single === nothing) || (pid == options.single)
         # TODO: Set `sch_handle.tid.ref` to the right `DRef`
-        sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...)
+        sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[pid]...)
 
         # TODO: De-dup common fields (log_sink, uid, etc.)
         push!(to_send, Any[thunk.id, time_util, alloc_util, occupancy,
@@ -1095,15 +1104,15 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
     for ts in to_send
         # TODO: errormonitor
        task = Threads.@spawn begin
-            timespan_start(ctx, :fire, (;uid=state.uid, worker=gproc.pid), nothing)
+            timespan_start(ctx, :fire, (;uid=state.uid, worker=pid), nothing)
            try
-                remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts]);
+                remotecall_wait(do_tasks, pid, proc, state.chan, [ts]);
            catch err
                bt = catch_backtrace()
                thunk_id = ts[1]
-                put!(state.chan, (gproc.pid, proc, thunk_id, (CapturedException(err, bt), nothing)))
+                put!(state.chan, (pid, proc, thunk_id, (CapturedException(err, bt), nothing)))
            finally
-                timespan_finish(ctx, :fire, (;uid=state.uid, worker=gproc.pid), nothing)
+                timespan_finish(ctx, :fire, (;uid=state.uid, worker=pid), nothing)
            end
        end
    end
@@ -1228,7 +1237,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
    proc_occupancy = istate.proc_occupancy
    time_pressure = istate.time_pressure

-    wid = get_parent(to_proc).pid
+    wid = root_worker_id(to_proc)
    work_to_do = false
    while isopen(return_queue)
        # Wait for new tasks
@@ -1477,6 +1486,8 @@ function do_tasks(to_proc, return_queue, tasks)
    end
    @dagdebug nothing :processor "Kicked processors"
 end
+
+const SCHED_MOVE = ScopedValue{Bool}(false)

 """
     do_task(to_proc, task_desc) -> Any
@@ -1490,8 +1501,9 @@ function do_task(to_proc, task_desc)
         options, propagated, ids, positions,
         ctx_vars, sch_handle, sch_uid = task_desc
     ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)
+    Dagger.accelerate!(options.acceleration)

-    from_proc = OSProc()
+    from_proc = Dagger.default_processor()
     Tdata = Any[]
     for x in data
         push!(Tdata, chunktype(x))
@@ -1618,12 +1630,13 @@ function do_task(to_proc, task_desc)
            end
        else
        =#
-            new_x = @invokelatest move(to_proc, x)
-        #end
+            new_x = with(SCHED_MOVE=>true) do
+                @invokelatest move(to_proc, x)
+            end
        if new_x !== x
            @dagdebug thunk_id :move "Moved argument $position to $to_proc: $(typeof(x)) -> $(typeof(new_x))"
        end
-        timespan_finish(ctx, :move, (;thunk_id, id, position, processor=to_proc), (;f, data=new_x); tasks=[Base.current_task()])
+        timespan_finish(ctx, :move, (;thunk_id, id, processor=to_proc), (;f, data=new_x); tasks=[Base.current_task()])
        return new_x
    end
 end
diff --git a/src/sch/eager.jl b/src/sch/eager.jl
index aea0abbf6..17d47581a 100644
--- a/src/sch/eager.jl
+++ b/src/sch/eager.jl
@@ -26,7 +26,8 @@ function init_eager()
     errormonitor_tracked("eager compute()", Threads.@spawn try
         sopts = SchedulerOptions(;allow_errors=true)
         opts = Dagger.Options((;scope=Dagger.ExactScope(Dagger.ThreadProc(1, 1)),
-                                occupancy=Dict(Dagger.ThreadProc=>0)))
+                                occupancy=Dict(Dagger.ThreadProc=>0),
+                                acceleration=Dagger.DistributedAcceleration()))
         Dagger.compute(ctx, Dagger._delayed(eager_thunk, opts)(); options=sopts)
     catch err
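`SCHED_MOVE` (added above) is a scoped value: `move` implementations can detect that they are running under the scheduler's argument-move path in `do_task` without a flag threaded through every signature. A standalone illustration of the pattern; `describe_move` is made up here and is not Dagger's `move`:

    using Base.ScopedValues  # Julia 1.11+; the ScopedValues.jl package backports it

    const SCHED_MOVE = ScopedValue{Bool}(false)

    # Code called inside the `with` block observes SCHED_MOVE[] == true.
    describe_move() = SCHED_MOVE[] ? "scheduler-driven move" : "user-driven move"

    @assert describe_move() == "user-driven move"
    with(SCHED_MOVE => true) do
        @assert describe_move() == "scheduler-driven move"
    end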
diff --git a/src/sch/util.jl b/src/sch/util.jl
index 2e090b26c..cf2655a02 100644
--- a/src/sch/util.jl
+++ b/src/sch/util.jl
@@ -328,6 +328,7 @@ end
 
 function can_use_proc(state, task, gproc, proc, opts, scope)
     # Check against proclist
+    pid = Dagger.root_worker_id(gproc)
     if opts.proclist !== nothing
         @warn "The `proclist` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1
         if opts.proclist isa Function
@@ -355,8 +356,8 @@ function can_use_proc(state, task, gproc, proc, opts, scope)
     # Check against single
     if opts.single !== nothing
         @warn "The `single` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1
-        if gproc.pid != opts.single
-            @dagdebug task :scope "Rejected $proc: gproc.pid ($(gproc.pid)) != single ($(opts.single))"
+        if pid != opts.single
+            @dagdebug task :scope "Rejected $proc: pid ($(pid)) != single ($(opts.single))"
             return false, scope
         end
         scope = constrain(scope, Dagger.ProcessScope(opts.single))
@@ -438,7 +439,7 @@ function populate_processor_cache_list!(state, procs)
     # Populate the cache if empty
     if state.procs_cache_list[] === nothing
         current = nothing
-        for p in map(x->x.pid, procs)
+        for p in map(x->Dagger.root_worker_id(x), procs)
             for proc in get_processors(OSProc(p))
                 next = ProcessorCacheEntry(OSProc(p), proc)
                 if current === nothing
@@ -514,7 +515,7 @@ function estimate_task_costs(state, procs, task, inputs)
         tx_cost = impute_sum(affinity(chunk)[2] for chunk in chunks_filt)
 
         # Estimate total cost to move data and get task running after currently-scheduled tasks
-        est_time_util = get(state.worker_time_pressure[get_parent(proc).pid], proc, 0)
+        est_time_util = get(state.worker_time_pressure[get_parent(proc)], proc, 0)
         costs[proc] = est_time_util + (tx_cost/tx_rate)
     end
diff --git a/src/scopes.jl b/src/scopes.jl
index 834993c9f..1e601371a 100644
--- a/src/scopes.jl
+++ b/src/scopes.jl
@@ -87,11 +87,13 @@ ProcessorTypeScope(T) = Set{AbstractScopeTaint}([ProcessorTypeTaint{T}()]))
 
 "Scoped to a specific processor."
-struct ExactScope <: AbstractScope
-    parent::ProcessScope
+struct ExactScope{P<:AbstractScope} <: AbstractScope
+    parent::P
     processor::Processor
 end
-ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc)
+ExactScope(proc) = ExactScope(enclosing_scope(get_parent(proc)), proc)
+
+enclosing_scope(proc::OSProc) = ProcessScope(proc.pid)
 
 "Indicates that the applied scopes `x` and `y` are incompatible."
 struct InvalidScope <: AbstractScope
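Widening `ExactScope` from a concrete `ProcessScope` parent to any `AbstractScope`, with `enclosing_scope` as the lookup, is what lets non-Distributed processors participate in exact scoping. A hypothetical extension under those assumptions; both types below are illustrative stand-ins, the real ones live in src/mpi.jl:

    # Hypothetical rank-addressed processor and scope using the new hook.
    struct MyMPIProc <: Dagger.Processor
        rank::Int
    end
    struct MyMPIProcessScope <: Dagger.AbstractScope
        rank::Int
    end
    Dagger.enclosing_scope(proc::MyMPIProc) = MyMPIProcessScope(proc.rank)
    # ExactScope(child) then carries a MyMPIProcessScope parent when the
    # child's parent processor is a MyMPIProc.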
diff --git a/src/thunk.jl b/src/thunk.jl
index dc961f303..61f6475d6 100644
--- a/src/thunk.jl
+++ b/src/thunk.jl
@@ -84,16 +84,25 @@ mutable struct Thunk
                    affinity=nothing,
                    eager_ref=nothing,
                    processor=nothing,
+                   memory_space=nothing,
                    scope=nothing,
                    options=nothing,
                    propagates=(),
                    kwargs...
                   )
-        if !isa(f, Chunk) && (!isnothing(processor) || !isnothing(scope))
-            f = tochunk(f,
-                        something(processor, OSProc()),
-                        something(scope, DefaultScope()))
-        end
+        #FIXME: don't force-unwrap with fetch
+        f = fetch(f)
+        if (!isnothing(processor) || !isnothing(scope) || !isnothing(memory_space))
+            if !isnothing(processor)
+                f = tochunk(f,
+                            processor,
+                            something(scope, DefaultScope()))
+            else
+                f = tochunk(f,
+                            something(memory_space, default_memory_space(f)),
+                            something(scope, DefaultScope()))
+            end
+        end
         xs = Base.mapany(identity, xs)
         syncdeps_set = Set{Any}(filterany(is_task_or_chunk, Base.mapany(last, xs)))
         if syncdeps !== nothing
@@ -467,12 +476,21 @@ function spawn(f, args...; kwargs...)
     # Wrap f in a Chunk if necessary
     processor = haskey(options, :processor) ? options.processor : nothing
     scope = haskey(options, :scope) ? options.scope : nothing
-    if !isnothing(processor) || !isnothing(scope)
-        f = tochunk(f,
-                    something(processor, get_options(:processor, OSProc())),
-                    something(scope, get_options(:scope, DefaultScope())))
+    memory_space = haskey(options, :memory_space) ? options.memory_space : nothing
+    #FIXME: don't force-unwrap with fetch
+    f = fetch(f)
+    if (!isnothing(processor) || !isnothing(scope) || !isnothing(memory_space))
+        if !isnothing(processor)
+            f = tochunk(f,
+                        processor,
+                        something(scope, DefaultScope()))
+        else
+            f = tochunk(f,
+                        something(memory_space, default_memory_space(f)),
+                        something(scope, DefaultScope()))
+        end
     end
-
+
     # Process the args and kwargs into Pair form
     args_kwargs = args_kwargs_to_pairs(args, kwargs)
@@ -481,6 +499,9 @@ function spawn(f, args...; kwargs...)
     options = NamedTuple(filter(opt->opt[1] != :task_queue, Base.pairs(options)))
     propagates = filter(prop->prop != :task_queue, propagates)
     options = merge(options, (;propagates))
+    if !haskey(options, :acceleration)
+        options = merge(options, (;acceleration=current_acceleration()))
+    end
 
     # Construct task spec and handle
     spec = DTaskSpec(f, args_kwargs, options)
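The final hunk records an acceleration on every eagerly-spawned task: an explicit `:acceleration` option wins, otherwise `current_acceleration()` is captured at spawn time (and eager tasks default to `DistributedAcceleration`, per the eager.jl hunk above). The defaulting rule in isolation, assuming this PR's `Dagger.current_acceleration`:

    # Sketch of the option-defaulting logic from spawn, run standalone.
    options = (;scope=Dagger.DefaultScope())
    if !haskey(options, :acceleration)
        options = merge(options, (;acceleration=Dagger.current_acceleration()))
    end
    options.acceleration  # DistributedAcceleration() in a plain Distributed session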