@@ -96,40 +96,61 @@ struct DAGSpec
96
96
Dict {Int, Type} (),
97
97
Dict {Int, Vector{DatadepsArgSpec}} ())
98
98
end
99
- function Base. push! (dspec:: DAGSpec , tspec:: DTaskSpec , task:: DTask )
99
+ function dag_add_task! (dspec:: DAGSpec , astate, tspec:: DTaskSpec , task:: DTask )
100
+ # Check if this task depends on any other tasks within the DAG,
101
+ # which we are not yet ready to handle
102
+ for (idx, (kwpos, arg)) in enumerate (tspec. args)
103
+ arg, deps = unwrap_inout (arg)
104
+ pos = kwpos isa Symbol ? kwpos : idx
105
+ for (dep_mod, readdep, writedep) in deps
106
+ if arg isa DTask
107
+ if arg. uid in keys (dspec. uid_to_id)
108
+ # Within-DAG dependency, bail out
109
+ return false
110
+ end
111
+ end
112
+ end
113
+ end
114
+
100
115
add_vertex! (dspec. g)
101
116
id = nv (dspec. g)
102
117
118
+ # Record function signature
103
119
dspec. id_to_functype[id] = typeof (tspec. f)
104
-
105
- dspec. id_to_argtypes[id] = DatadepsArgSpec[]
120
+ argtypes = DatadepsArgSpec[]
106
121
for (idx, (kwpos, arg)) in enumerate (tspec. args)
107
122
arg, deps = unwrap_inout (arg)
108
123
pos = kwpos isa Symbol ? kwpos : idx
109
124
for (dep_mod, readdep, writedep) in deps
110
125
if arg isa DTask
126
+ #= TODO : Re-enable this when we can handle within-DAG dependencies
111
127
if arg.uid in keys(dspec.uid_to_id)
112
128
# Within-DAG dependency
113
129
arg_id = dspec.uid_to_id[arg.uid]
114
130
push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
115
131
add_edge!(dspec.g, arg_id, id)
116
132
continue
117
133
end
134
+ =#
118
135
119
136
# External DTask, so fetch this and track it as a raw value
120
137
arg = fetch (arg; raw= true )
121
138
end
122
- ainfo = aliasing (arg, dep_mod)
123
- push! (dspec . id_to_argtypes[id] , DatadepsArgSpec (pos, typeof (arg), dep_mod, ainfo))
139
+ ainfo = aliasing (astate, arg, dep_mod)
140
+ push! (argtypes , DatadepsArgSpec (pos, typeof (arg), dep_mod, ainfo))
124
141
end
125
142
end
143
+ dspec. id_to_argtypes[id] = argtypes
126
144
127
145
# FIXME : Also record some portion of options
128
146
# FIXME : Record syncdeps
129
147
dspec. id_to_uid[id] = task. uid
130
148
dspec. uid_to_id[task. uid] = id
131
149
132
- return
150
+ return true
151
+ end
152
+ function dag_has_task (dspec:: DAGSpec , task:: DTask )
153
+ return task. uid in keys (dspec. uid_to_id)
133
154
end
134
155
function Base.:(== )(dspec1:: DAGSpec , dspec2:: DAGSpec )
135
156
# Are the graphs the same size?
@@ -159,10 +180,7 @@ function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
159
180
# Are the arguments the same?
160
181
argspec1. value_type === argspec2. value_type || return false
161
182
argspec1. dep_mod === argspec2. dep_mod || return false
162
- if ! equivalent_structure (argspec1. ainfo, argspec2. ainfo)
163
- @show argspec1. ainfo argspec2. ainfo
164
- return false
165
- end
183
+ equivalent_structure (argspec1. ainfo, argspec2. ainfo) || return false
166
184
end
167
185
end
168
186
@@ -454,7 +472,7 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
454
472
astate. data_locality[task] = space
455
473
astate. data_origin[task] = space
456
474
end
457
- function clear_ainfo_owner_readers ! (astate:: DataDepsAliasingState )
475
+ function reset_ainfo_owner_readers ! (astate:: DataDepsAliasingState )
458
476
for ainfo in keys (astate. ainfos_owner)
459
477
astate. ainfos_owner[ainfo] = nothing
460
478
empty! (astate. ainfos_readers[ainfo])
@@ -621,24 +639,26 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
621
639
@warn " Datadeps support for multi-GPU, multi-worker is currently broken\n Please be prepared for incorrect results or errors" maxlog= 1
622
640
end
623
641
624
- # Round-robin assign tasks to processors
625
642
upper_queue = get_options (:task_queue )
626
643
627
644
state = DataDepsState (queue. aliasing, all_procs)
628
645
astate = state. alias_state
629
646
630
647
schedule = Dict {DTask, Processor} ()
631
648
632
- if DATADEPS_SCHEDULE_REUSABLE[]
633
- # Compute DAG spec
634
- for (spec, task) in queue. seen_tasks
635
- push! (state. dag_spec, spec, task)
649
+ # Compute DAG spec
650
+ for (spec, task) in queue. seen_tasks
651
+ if ! dag_add_task! (state. dag_spec, astate, spec, task)
652
+ # This task needs to be deferred
653
+ break
636
654
end
655
+ end
637
656
657
+ if DATADEPS_SCHEDULE_REUSABLE[]
638
658
# Find any matching DAG specs and reuse their schedule
639
659
for (other_spec, spec_schedule) in DAG_SPECS
640
660
if other_spec == state. dag_spec
641
- @info " Found matching DAG spec!"
661
+ @dagdebug nothing :spawn_datadeps " Found matching DAG spec!"
642
662
# spec_schedule = DAG_SCHEDULE_CACHE[other_spec]
643
663
schedule = Dict {DTask, Processor} ()
644
664
for (id, proc) in spec_schedule. id_to_proc
@@ -654,13 +674,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
654
674
655
675
# Populate all task dependencies
656
676
write_num = 1
677
+ task_num = 0
657
678
for (spec, task) in queue. seen_tasks
679
+ if ! dag_has_task (state. dag_spec, task)
680
+ # This task needs to be deferred
681
+ break
682
+ end
658
683
write_num = populate_task_info! (state, spec, task, write_num)
684
+ task_num += 1
659
685
end
686
+ @assert task_num > 0
660
687
661
688
if isempty (schedule)
662
689
# Run AOT scheduling
663
- schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks):: Dict{DTask, Processor}
690
+ schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks[ 1 : task_num] ):: Dict{DTask, Processor}
664
691
665
692
if DATADEPS_SCHEDULE_REUSABLE[]
666
693
# Cache the schedule
@@ -674,12 +701,17 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
674
701
end
675
702
end
676
703
677
- # Clear out ainfo database (will be repopulated during task execution)
678
- clear_ainfo_owner_readers ! (astate)
704
+ # Reset ainfo database (will be repopulated during task execution)
705
+ reset_ainfo_owner_readers ! (astate)
679
706
680
707
# Launch tasks and necessary copies
681
708
write_num = 1
682
709
for (spec, task) in queue. seen_tasks
710
+ if ! dag_has_task (state. dag_spec, task)
711
+ # This task needs to be deferred
712
+ break
713
+ end
714
+
683
715
our_proc = schedule[task]
684
716
@assert our_proc in all_procs
685
717
our_space = only (memory_spaces (our_proc))
@@ -829,6 +861,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
829
861
write_num += 1
830
862
end
831
863
864
+ # Remove processed tasks
865
+ deleteat! (queue. seen_tasks, 1 : task_num)
866
+
832
867
# Copy args from remote to local
833
868
if queue. aliasing
834
869
# We need to replay the writes from all tasks in-order (skipping any
@@ -961,18 +996,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
961
996
wait_all (; check_errors= true ) do
962
997
scheduler = something (scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler ())
963
998
launch_wait = something (launch_wait, DATADEPS_LAUNCH_WAIT[], false ):: Bool
999
+ local result
964
1000
if launch_wait
965
- result = spawn_bulk () do
1001
+ spawn_bulk () do
966
1002
queue = DataDepsTaskQueue (get_options (:task_queue );
967
1003
scheduler, aliasing)
968
- with_options (f; task_queue= queue)
969
- distribute_tasks! (queue)
1004
+ result = with_options (f; task_queue= queue)
1005
+ while ! isempty (queue. seen_tasks)
1006
+ @dagdebug nothing :spawn_datadeps " Entering Datadeps region"
1007
+ distribute_tasks! (queue)
1008
+ end
970
1009
end
971
1010
else
972
1011
queue = DataDepsTaskQueue (get_options (:task_queue );
973
1012
scheduler, aliasing)
974
1013
result = with_options (f; task_queue= queue)
975
- distribute_tasks! (queue)
1014
+ while ! isempty (queue. seen_tasks)
1015
+ @dagdebug nothing :spawn_datadeps " Entering Datadeps region"
1016
+ distribute_tasks! (queue)
1017
+ end
976
1018
end
977
1019
return result
978
1020
end
0 commit comments