Skip to content

Commit e33216c

Browse files
committed
datadeps: Add logic to compare similar DAGs
1 parent c52b7d6 commit e33216c

File tree

2 files changed

+195
-7
lines changed

2 files changed

+195
-7
lines changed

src/datadeps.jl

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Graphs: SimpleDiGraph, nv, add_edge!, add_vertex!
1+
import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv, ne
22

33
export In, Out, InOut, Deps, spawn_datadeps
44

@@ -78,6 +78,107 @@ function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}
7878
append!(queue.seen_tasks, specs)
7979
end
8080

81+
struct DatadepsArgSpec
82+
pos::Union{Int, Symbol}
83+
value_type::Type
84+
dep_mod::Any
85+
ainfo::AbstractAliasing
86+
end
87+
struct DTaskDAGID{id} end
88+
struct DAGSpec
89+
g::SimpleDiGraph{Int}
90+
id_to_uid::Dict{Int, UInt}
91+
uid_to_id::Dict{UInt, Int}
92+
id_to_functype::Dict{Int, Type} # FIXME: DatadepsArgSpec
93+
id_to_argtypes::Dict{Int, Vector{DatadepsArgSpec}}
94+
DAGSpec() = new(SimpleDiGraph{Int}(),
95+
Dict{Int, UInt}(), Dict{UInt, Int}(),
96+
Dict{Int, Type}(),
97+
Dict{Int, Vector{DatadepsArgSpec}}())
98+
end
99+
function Base.push!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
100+
add_vertex!(dspec.g)
101+
id = nv(dspec.g)
102+
103+
dspec.id_to_functype[id] = typeof(tspec.f)
104+
105+
dspec.id_to_argtypes[id] = DatadepsArgSpec[]
106+
for (idx, (kwpos, arg)) in enumerate(tspec.args)
107+
arg, deps = unwrap_inout(arg)
108+
pos = kwpos isa Symbol ? kwpos : idx
109+
for (dep_mod, readdep, writedep) in deps
110+
if arg isa DTask
111+
if arg.uid in keys(dspec.uid_to_id)
112+
# Within-DAG dependency
113+
arg_id = dspec.uid_to_id[arg.uid]
114+
push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
115+
add_edge!(dspec.g, arg_id, id)
116+
continue
117+
end
118+
119+
# External DTask, so fetch this and track it as a raw value
120+
arg = fetch(arg; raw=true)
121+
end
122+
ainfo = aliasing(arg, dep_mod)
123+
push!(dspec.id_to_argtypes[id], DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
124+
end
125+
end
126+
127+
# FIXME: Also record some portion of options
128+
# FIXME: Record syncdeps
129+
dspec.id_to_uid[id] = task.uid
130+
dspec.uid_to_id[task.uid] = id
131+
132+
return
133+
end
134+
function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
135+
# Are the graphs the same size?
136+
nv(dspec1.g) == nv(dspec2.g) || return false
137+
ne(dspec1.g) == ne(dspec2.g) || return false
138+
139+
for id in 1:nv(dspec1.g)
140+
# Are all the vertices the same?
141+
id in keys(dspec2.id_to_uid) || return false
142+
id in keys(dspec2.id_to_functype) || return false
143+
id in keys(dspec2.id_to_argtypes) || return false
144+
145+
# Are all the edges the same?
146+
inneighbors(dspec1.g, id) == inneighbors(dspec2.g, id) || return false
147+
outneighbors(dspec1.g, id) == outneighbors(dspec2.g, id) || return false
148+
149+
# Are function types the same?
150+
dspec1.id_to_functype[id] === dspec2.id_to_functype[id] || return false
151+
152+
# Are argument types/relative dependencies the same?
153+
for argspec1 in dspec1.id_to_argtypes[id]
154+
# Is this argument position present in both?
155+
argspec2_idx = findfirst(argspec2->argspec1.pos == argspec2.pos, dspec2.id_to_argtypes[id])
156+
argspec2_idx === nothing && return false
157+
argspec2 = dspec2.id_to_argtypes[id][argspec2_idx]
158+
159+
# Are the arguments the same?
160+
argspec1.value_type === argspec2.value_type || return false
161+
argspec1.dep_mod === argspec2.dep_mod || return false
162+
if !equivalent_structure(argspec1.ainfo, argspec2.ainfo)
163+
@show argspec1.ainfo argspec2.ainfo
164+
return false
165+
end
166+
end
167+
end
168+
169+
return true
170+
end
171+
172+
struct DAGSpecSchedule
173+
id_to_proc::Dict{Int, Processor}
174+
DAGSpecSchedule() = new(Dict{Int, Processor}())
175+
end
176+
177+
#const DAG_SPECS = Vector{DAGSpec}()
178+
const DAG_SPECS = Vector{Pair{DAGSpec, DAGSpecSchedule}}()
179+
180+
#const DAG_SCHEDULE_CACHE = Dict{DAGSpec, DAGSpecSchedule}()
181+
81182
struct DataDepsAliasingState
82183
# Track original and current data locations
83184
# We track data => space
@@ -152,6 +253,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
152253
# The aliasing analysis state
153254
alias_state::State
154255

256+
# The DAG specification
257+
dag_spec::DAGSpec
258+
155259
function DataDepsState(aliasing::Bool, all_procs::Vector{Processor})
156260
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
157261
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
@@ -160,7 +264,8 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
160264
else
161265
state = DataDepsNonAliasingState()
162266
end
163-
return new{typeof(state)}(aliasing, all_procs, dependencies, remote_args, state)
267+
spec = DAGSpec()
268+
return new{typeof(state)}(aliasing, all_procs, dependencies, remote_args, state, spec)
164269
end
165270
end
166271

@@ -522,18 +627,54 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
522627
state = DataDepsState(queue.aliasing, all_procs)
523628
astate = state.alias_state
524629

630+
schedule = Dict{DTask, Processor}()
631+
632+
if DATADEPS_SCHEDULE_REUSABLE[]
633+
# Compute DAG spec
634+
for (spec, task) in queue.seen_tasks
635+
push!(state.dag_spec, spec, task)
636+
end
637+
638+
# Find any matching DAG specs and reuse their schedule
639+
for (other_spec, spec_schedule) in DAG_SPECS
640+
if other_spec == state.dag_spec
641+
@info "Found matching DAG spec!"
642+
#spec_schedule = DAG_SCHEDULE_CACHE[other_spec]
643+
schedule = Dict{DTask, Processor}()
644+
for (id, proc) in spec_schedule.id_to_proc
645+
uid = state.dag_spec.id_to_uid[id]
646+
task_idx = findfirst(spec_task -> spec_task[2].uid == uid, queue.seen_tasks)
647+
task = queue.seen_tasks[task_idx][2]
648+
schedule[task] = proc
649+
end
650+
break
651+
end
652+
end
653+
end
654+
525655
# Populate all task dependencies
526656
write_num = 1
527657
for (spec, task) in queue.seen_tasks
528658
write_num = populate_task_info!(state, spec, task, write_num)
529659
end
530660

531-
# AOT scheduling
532-
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
533-
for (spec, task) in queue.seen_tasks
534-
println("Task $(spec.f) scheduled on $(schedule[task])")
661+
if isempty(schedule)
662+
# Run AOT scheduling
663+
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
664+
665+
if DATADEPS_SCHEDULE_REUSABLE[]
666+
# Cache the schedule
667+
spec_schedule = DAGSpecSchedule()
668+
for (task, proc) in schedule
669+
id = state.dag_spec.uid_to_id[task.uid]
670+
spec_schedule.id_to_proc[id] = proc
671+
end
672+
#DAG_SCHEDULE_CACHE[state.dag_spec] = spec_schedule
673+
push!(DAG_SPECS, state.dag_spec => spec_schedule)
674+
end
535675
end
536676

677+
# Clear out ainfo database (will be repopulated during task execution)
537678
clear_ainfo_owner_readers!(astate)
538679

539680
# Launch tasks and necessary copies
@@ -556,7 +697,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
556697
# Is the data written previously or now?
557698
arg, deps = unwrap_inout(arg)
558699
arg = arg isa DTask ? fetch(arg; raw=true) : arg
559-
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
700+
if !type_may_alias(typeof(arg))
560701
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
561702
spec.args[idx] = pos => arg
562703
continue
@@ -837,4 +978,5 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
837978
end
838979
end
839980
const DATADEPS_SCHEDULER = ScopedValue{Any}(nothing)
981+
const DATADEPS_SCHEDULE_REUSABLE = ScopedValue{Bool}(true)
840982
const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing)

src/memory-spaces.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ memory_spans(x, T) = memory_spans(aliasing(x, T))
124124

125125
struct NoAliasing <: AbstractAliasing end
126126
memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[]
127+
equivalent_structure(::NoAliasing, ::NoAliasing) = true
127128
struct UnknownAliasing <: AbstractAliasing end
128129
memory_spans(::UnknownAliasing) = [MemorySpan{CPURAMMemorySpace}(C_NULL, typemax(UInt))]
130+
equivalent_structure(::UnknownAliasing, ::UnknownAliasing) = true
129131

130132
warn_unknown_aliasing(T) =
131133
@warn "Cannot resolve aliasing for object of type $T\nExecution may become sequential"
@@ -141,6 +143,18 @@ function memory_spans(ca::CombinedAliasing)
141143
end
142144
return all_spans
143145
end
146+
function equivalent_structure(ainfo1::CombinedAliasing,
147+
ainfo2::CombinedAliasing)
148+
for sub_ainfo1 in ainfo1.sub_ainfos
149+
for sub_ainfo2 in ainfo2.sub_ainfos
150+
if equivalent_structure(sub_ainfo1, sub_ainfo2)
151+
break
152+
end
153+
end
154+
return false
155+
end
156+
return true
157+
end
144158
Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) =
145159
ca1.sub_ainfos == ca2.sub_ainfos
146160
Base.hash(ca1::CombinedAliasing, h::UInt) =
@@ -161,6 +175,10 @@ function memory_spans(oa::ObjectAliasing)
161175
span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz)
162176
return [span]
163177
end
178+
function equivalent_structure(ainfo1::ObjectAliasing,
179+
ainfo2::ObjectAliasing)
180+
return ainfo1.sz == ainfo2.sz
181+
end
164182

165183
aliasing(x, T) = aliasing(T(x))
166184
function aliasing(x::T) where T
@@ -221,6 +239,10 @@ function aliasing(x::Array{T}) where T
221239
end
222240
aliasing(x::Transpose) = aliasing(parent(x))
223241
aliasing(x::Adjoint) = aliasing(parent(x))
242+
function equivalent_structure(ainfo1::ContiguousAliasing{S},
243+
ainfo2::ContiguousAliasing{S}) where {S}
244+
return ainfo1.span.len == ainfo2.span.len
245+
end
224246

225247
struct StridedAliasing{T,N,S} <: AbstractAliasing
226248
base_ptr::RemotePtr{Cvoid,S}
@@ -279,6 +301,12 @@ function will_alias(x::StridedAliasing{T,N,S}, y::StridedAliasing{T,N,S}) where
279301
return true
280302
end
281303
# FIXME: Upgrade Contiguous/StridedAlising to same number of dims
304+
function equivalent_structure(ainfo1::StridedAliasing{T,N,S},
305+
ainfo2::StridedAliasing{T,N,S}) where {T,N,S}
306+
return ainfo1.base_inds == ainfo2.base_inds &&
307+
ainfo1.lengths == ainfo2.lengths &&
308+
ainfo1.strides == ainfo2.strides
309+
end
282310

283311
struct TriangularAliasing{T,S} <: AbstractAliasing
284312
ptr::RemotePtr{Cvoid,S}
@@ -311,6 +339,12 @@ aliasing(x::UnitUpperTriangular{T}) where T =
311339
TriangularAliasing{T,CPURAMMemorySpace}(pointer(parent(x)), size(parent(x), 1), true, false)
312340
aliasing(x::UnitLowerTriangular{T}) where T =
313341
TriangularAliasing{T,CPURAMMemorySpace}(pointer(parent(x)), size(parent(x), 1), false, false)
342+
function equivalent_structure(ainfo1::TriangularAliasing{T,S},
343+
ainfo2::TriangularAliasing{T,S}) where {T,S}
344+
return ainfo1.stride == ainfo2.stride &&
345+
ainfo1.isupper == ainfo2.isupper &&
346+
ainfo1.diagonal == ainfo2.diagonal
347+
end
314348

315349
struct DiagonalAliasing{T,S} <: AbstractAliasing
316350
ptr::RemotePtr{Cvoid,S}
@@ -331,6 +365,10 @@ function aliasing(x::AbstractMatrix{T}, ::Type{Diagonal}) where T
331365
rptr = RemotePtr{Cvoid}(ptr, S)
332366
return DiagonalAliasing{T,typeof(S)}(rptr, size(parent(x), 1))
333367
end
368+
function equivalent_structure(ainfo1::DiagonalAliasing{T,S},
369+
ainfo2::DiagonalAliasing{T,S}) where {T,S}
370+
return ainfo1.stride == ainfo2.stride
371+
end
334372
# FIXME: Bidiagonal
335373
# FIXME: Tridiagonal
336374

@@ -368,3 +406,11 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan)
368406
y_end = y_span.ptr + y_span.len - 1
369407
return x_span.ptr <= y_end && y_span.ptr <= x_end
370408
end
409+
410+
"""
411+
equivalent_structure(ainfo1::AbstractAliasing, ainfo2::AbstractAliasing) -> Bool
412+
413+
Returns `true` when `ainfo1` and `ainfo2` represent objects with the same
414+
memory structure, ignoring the specific memory addresses; otherwise, `false`.
415+
"""
416+
equivalent_structure(ainfo1::AbstractAliasing, ainfo2::AbstractAliasing) = false

0 commit comments

Comments
 (0)