Skip to content

Commit 82df8f5

Browse files
committed
datadeps: Support within-DAG deps for AOT scheduler
1 parent 2888fa4 commit 82df8f5

File tree

1 file changed

+66
-24
lines changed

1 file changed

+66
-24
lines changed

src/datadeps.jl

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,40 +96,61 @@ struct DAGSpec
9696
Dict{Int, Type}(),
9797
Dict{Int, Vector{DatadepsArgSpec}}())
9898
end
99-
function Base.push!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
99+
function dag_add_task!(dspec::DAGSpec, astate, tspec::DTaskSpec, task::DTask)
100+
# Check if this task depends on any other tasks within the DAG,
101+
# which we are not yet ready to handle
102+
for (idx, (kwpos, arg)) in enumerate(tspec.args)
103+
arg, deps = unwrap_inout(arg)
104+
pos = kwpos isa Symbol ? kwpos : idx
105+
for (dep_mod, readdep, writedep) in deps
106+
if arg isa DTask
107+
if arg.uid in keys(dspec.uid_to_id)
108+
# Within-DAG dependency, bail out
109+
return false
110+
end
111+
end
112+
end
113+
end
114+
100115
add_vertex!(dspec.g)
101116
id = nv(dspec.g)
102117

118+
# Record function signature
103119
dspec.id_to_functype[id] = typeof(tspec.f)
104-
105-
dspec.id_to_argtypes[id] = DatadepsArgSpec[]
120+
argtypes = DatadepsArgSpec[]
106121
for (idx, (kwpos, arg)) in enumerate(tspec.args)
107122
arg, deps = unwrap_inout(arg)
108123
pos = kwpos isa Symbol ? kwpos : idx
109124
for (dep_mod, readdep, writedep) in deps
110125
if arg isa DTask
126+
#= TODO: Re-enable this when we can handle within-DAG dependencies
111127
if arg.uid in keys(dspec.uid_to_id)
112128
# Within-DAG dependency
113129
arg_id = dspec.uid_to_id[arg.uid]
114130
push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
115131
add_edge!(dspec.g, arg_id, id)
116132
continue
117133
end
134+
=#
118135

119136
# External DTask, so fetch this and track it as a raw value
120137
arg = fetch(arg; raw=true)
121138
end
122-
ainfo = aliasing(arg, dep_mod)
123-
push!(dspec.id_to_argtypes[id], DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
139+
ainfo = aliasing(astate, arg, dep_mod)
140+
push!(argtypes, DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
124141
end
125142
end
143+
dspec.id_to_argtypes[id] = argtypes
126144

127145
# FIXME: Also record some portion of options
128146
# FIXME: Record syncdeps
129147
dspec.id_to_uid[id] = task.uid
130148
dspec.uid_to_id[task.uid] = id
131149

132-
return
150+
return true
151+
end
152+
function dag_has_task(dspec::DAGSpec, task::DTask)
153+
return task.uid in keys(dspec.uid_to_id)
133154
end
134155
function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
135156
# Are the graphs the same size?
@@ -159,10 +180,7 @@ function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
159180
# Are the arguments the same?
160181
argspec1.value_type === argspec2.value_type || return false
161182
argspec1.dep_mod === argspec2.dep_mod || return false
162-
if !equivalent_structure(argspec1.ainfo, argspec2.ainfo)
163-
@show argspec1.ainfo argspec2.ainfo
164-
return false
165-
end
183+
equivalent_structure(argspec1.ainfo, argspec2.ainfo) || return false
166184
end
167185
end
168186

@@ -454,7 +472,7 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
454472
astate.data_locality[task] = space
455473
astate.data_origin[task] = space
456474
end
457-
function clear_ainfo_owner_readers!(astate::DataDepsAliasingState)
475+
function reset_ainfo_owner_readers!(astate::DataDepsAliasingState)
458476
for ainfo in keys(astate.ainfos_owner)
459477
astate.ainfos_owner[ainfo] = nothing
460478
empty!(astate.ainfos_readers[ainfo])
@@ -621,24 +639,26 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
621639
@warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
622640
end
623641

624-
# Round-robin assign tasks to processors
625642
upper_queue = get_options(:task_queue)
626643

627644
state = DataDepsState(queue.aliasing, all_procs)
628645
astate = state.alias_state
629646

630647
schedule = Dict{DTask, Processor}()
631648

632-
if DATADEPS_SCHEDULE_REUSABLE[]
633-
# Compute DAG spec
634-
for (spec, task) in queue.seen_tasks
635-
push!(state.dag_spec, spec, task)
649+
# Compute DAG spec
650+
for (spec, task) in queue.seen_tasks
651+
if !dag_add_task!(state.dag_spec, astate, spec, task)
652+
# This task needs to be deferred
653+
break
636654
end
655+
end
637656

657+
if DATADEPS_SCHEDULE_REUSABLE[]
638658
# Find any matching DAG specs and reuse their schedule
639659
for (other_spec, spec_schedule) in DAG_SPECS
640660
if other_spec == state.dag_spec
641-
@info "Found matching DAG spec!"
661+
@dagdebug nothing :spawn_datadeps "Found matching DAG spec!"
642662
#spec_schedule = DAG_SCHEDULE_CACHE[other_spec]
643663
schedule = Dict{DTask, Processor}()
644664
for (id, proc) in spec_schedule.id_to_proc
@@ -654,13 +674,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
654674

655675
# Populate all task dependencies
656676
write_num = 1
677+
task_num = 0
657678
for (spec, task) in queue.seen_tasks
679+
if !dag_has_task(state.dag_spec, task)
680+
# This task needs to be deferred
681+
break
682+
end
658683
write_num = populate_task_info!(state, spec, task, write_num)
684+
task_num += 1
659685
end
686+
@assert task_num > 0
660687

661688
if isempty(schedule)
662689
# Run AOT scheduling
663-
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
690+
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks[1:task_num])::Dict{DTask, Processor}
664691

665692
if DATADEPS_SCHEDULE_REUSABLE[]
666693
# Cache the schedule
@@ -674,12 +701,17 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
674701
end
675702
end
676703

677-
# Clear out ainfo database (will be repopulated during task execution)
678-
clear_ainfo_owner_readers!(astate)
704+
# Reset ainfo database (will be repopulated during task execution)
705+
reset_ainfo_owner_readers!(astate)
679706

680707
# Launch tasks and necessary copies
681708
write_num = 1
682709
for (spec, task) in queue.seen_tasks
710+
if !dag_has_task(state.dag_spec, task)
711+
# This task needs to be deferred
712+
break
713+
end
714+
683715
our_proc = schedule[task]
684716
@assert our_proc in all_procs
685717
our_space = only(memory_spaces(our_proc))
@@ -829,6 +861,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
829861
write_num += 1
830862
end
831863

864+
# Remove processed tasks
865+
deleteat!(queue.seen_tasks, 1:task_num)
866+
832867
# Copy args from remote to local
833868
if queue.aliasing
834869
# We need to replay the writes from all tasks in-order (skipping any
@@ -961,18 +996,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
961996
wait_all(; check_errors=true) do
962997
scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler())
963998
launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
999+
local result
9641000
if launch_wait
965-
result = spawn_bulk() do
1001+
spawn_bulk() do
9661002
queue = DataDepsTaskQueue(get_options(:task_queue);
9671003
scheduler, aliasing)
968-
with_options(f; task_queue=queue)
969-
distribute_tasks!(queue)
1004+
result = with_options(f; task_queue=queue)
1005+
while !isempty(queue.seen_tasks)
1006+
@dagdebug nothing :spawn_datadeps "Entering Datadeps region"
1007+
distribute_tasks!(queue)
1008+
end
9701009
end
9711010
else
9721011
queue = DataDepsTaskQueue(get_options(:task_queue);
9731012
scheduler, aliasing)
9741013
result = with_options(f; task_queue=queue)
975-
distribute_tasks!(queue)
1014+
while !isempty(queue.seen_tasks)
1015+
@dagdebug nothing :spawn_datadeps "Entering Datadeps region"
1016+
distribute_tasks!(queue)
1017+
end
9761018
end
9771019
return result
9781020
end

0 commit comments

Comments
 (0)