Skip to content

Commit ebb7ca9

Browse files
committed
datadeps: Reduce dynamic dispatch with AliasingWrapper
1 parent 51d6429 commit ebb7ca9

File tree

2 files changed

+36
-24
lines changed

2 files changed

+36
-24
lines changed

src/datadeps.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,26 @@ end
9090
struct DataDepsAliasingState
9191
# Track original and current data locations
9292
# We track data => space
93-
data_origin::Dict{AbstractAliasing,MemorySpace}
94-
data_locality::Dict{AbstractAliasing,MemorySpace}
93+
data_origin::Dict{AliasingWrapper,MemorySpace}
94+
data_locality::Dict{AliasingWrapper,MemorySpace}
9595

9696
# Track writers ("owners") and readers
97-
ainfos_owner::Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}
98-
ainfos_readers::Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}
99-
ainfos_overlaps::Dict{AbstractAliasing,Set{AbstractAliasing}}
97+
ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}
98+
ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}
99+
ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}}
100100

101101
# Cache ainfo lookups
102-
ainfo_cache::Dict{Tuple{Any,Any},AbstractAliasing}
102+
ainfo_cache::Dict{Tuple{Any,Any},AliasingWrapper}
103103

104104
function DataDepsAliasingState()
105-
data_origin = Dict{AbstractAliasing,MemorySpace}()
106-
data_locality = Dict{AbstractAliasing,MemorySpace}()
105+
data_origin = Dict{AliasingWrapper,MemorySpace}()
106+
data_locality = Dict{AliasingWrapper,MemorySpace}()
107107

108-
ainfos_owner = Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}()
109-
ainfos_readers = Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}()
110-
ainfos_overlaps = Dict{AbstractAliasing,Set{AbstractAliasing}}()
108+
ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}()
109+
ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}()
110+
ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}()
111111

112-
ainfo_cache = Dict{Tuple{Any,Any},AbstractAliasing}()
112+
ainfo_cache = Dict{Tuple{Any,Any},AliasingWrapper}()
113113

114114
return new(data_origin, data_locality,
115115
ainfos_owner, ainfos_readers, ainfos_overlaps,
@@ -142,7 +142,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
142142
aliasing::Bool
143143

144144
# The ordered list of tasks and their read/write dependencies
145-
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}}
145+
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}}
146146

147147
# The mapping of memory space to remote argument copies
148148
remote_args::Dict{MemorySpace,IdDict{Any,Any}}
@@ -154,7 +154,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
154154
alias_state::State
155155

156156
function DataDepsState(aliasing::Bool)
157-
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
157+
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[]
158158
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
159159
supports_inplace_cache = IdDict{Any,Bool}()
160160
if aliasing
@@ -168,7 +168,7 @@ end
168168

169169
function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
170170
return get!(astate.ainfo_cache, (arg, dep_mod)) do
171-
return aliasing(arg, dep_mod)
171+
return AliasingWrapper(aliasing(arg, dep_mod))
172172
end
173173
end
174174

@@ -245,7 +245,7 @@ end
245245
# Aliasing state setup
246246
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
247247
# Populate task dependencies
248-
dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}()
248+
dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}()
249249

250250
# Track the task's arguments and access patterns
251251
for (idx, _arg) in enumerate(spec.fargs)
@@ -263,7 +263,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
263263
if state.aliasing
264264
ainfo = aliasing(state.alias_state, arg, dep_mod)
265265
else
266-
ainfo = UnknownAliasing()
266+
ainfo = AliasingWrapper(UnknownAliasing())
267267
end
268268
push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg))
269269
end
@@ -274,7 +274,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
274274

275275
# Track the task result too
276276
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
277-
push!(dependencies_to_add, (false, false, UnknownAliasing(), identity, task))
277+
push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task))
278278

279279
# Record argument/result dependencies
280280
push!(state.dependencies, task => dependencies_to_add)
@@ -286,7 +286,7 @@ function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, ar
286286

287287
# Initialize owner and readers
288288
if !haskey(astate.ainfos_owner, ainfo)
289-
overlaps = Set{AbstractAliasing}()
289+
overlaps = Set{AliasingWrapper}()
290290
push!(overlaps, ainfo)
291291
for other_ainfo in keys(astate.ainfos_owner)
292292
ainfo == other_ainfo && continue
@@ -368,7 +368,7 @@ end
368368

369369
function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
370370
astate = state.alias_state
371-
ainfo isa NoAliasing && return
371+
ainfo.inner isa NoAliasing && return
372372
for other_ainfo in astate.ainfos_overlaps[ainfo]
373373
other_task_write_num = astate.ainfos_owner[other_ainfo]
374374
@dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo"
@@ -381,7 +381,7 @@ function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Ab
381381
end
382382
function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
383383
astate = state.alias_state
384-
ainfo isa NoAliasing && return
384+
ainfo.inner isa NoAliasing && return
385385
for other_ainfo in astate.ainfos_overlaps[ainfo]
386386
@dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo"
387387
other_tasks = astate.ainfos_readers[other_ainfo]
@@ -864,7 +864,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
864864
# in the correct order
865865

866866
# First, find the latest owners of each live ainfo
867-
arg_writes = IdDict{Any,Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}}()
867+
arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}()
868868
for (task, taskdeps) in state.dependencies
869869
for (_, writedep, ainfo, dep_mod, arg) in taskdeps
870870
writedep || continue
@@ -873,7 +873,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
873873

874874
# Skip virtual writes from task result aliasing
875875
# FIXME: Make this less bad
876-
if arg isa DTask && dep_mod === identity && ainfo isa UnknownAliasing
876+
if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing
877877
continue
878878
end
879879

@@ -884,7 +884,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
884884
end
885885

886886
# Get the set of writers
887-
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)
887+
ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg)
888888

889889
#= FIXME: If we fully overlap any writer, evict them
890890
idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes)

src/memory-spaces.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,18 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
122122
memory_spans(x) = memory_spans(aliasing(x))
123123
memory_spans(x, T) = memory_spans(aliasing(x, T))
124124

125+
struct AliasingWrapper <: AbstractAliasing
126+
inner::AbstractAliasing
127+
hash::UInt64
128+
129+
AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner))
130+
end
131+
memory_spans(x::AliasingWrapper) = memory_spans(x.inner)
132+
Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h)
133+
Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash
134+
will_alias(x::AliasingWrapper, y::AliasingWrapper) =
135+
will_alias(x.inner, y.inner)
136+
125137
struct NoAliasing <: AbstractAliasing end
126138
memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[]
127139
struct UnknownAliasing <: AbstractAliasing end

0 commit comments

Comments
 (0)