
Commit c52b7d6

datadeps: AOT scheduling, add JuMP scheduler
1 parent db5eecd

File tree

4 files changed: +271 −234 lines

Project.toml

Lines changed: 3 additions & 0 deletions

@@ -34,6 +34,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

@@ -42,6 +43,7 @@ DistributionsExt = "Distributions"
 GraphVizExt = "GraphViz"
 GraphVizSimpleExt = "Colors"
 JSON3Ext = "JSON3"
+JuMPExt = "JuMP"
 PlotsExt = ["DataFrames", "Plots"]
 PythonExt = "PythonCall"

@@ -55,6 +57,7 @@ Distributions = "0.25"
 GraphViz = "0.2"
 Graphs = "1"
 JSON3 = "1"
+JuMP = "1"
 MacroTools = "0.5"
 MemPool = "0.4.11"
 OnlineStats = "1"
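
Note: JuMP enters as a weak dependency with a matching `JuMPExt` extension entry and a `JuMP = "1"` compat bound, so the extension below compiles only when a user brings JuMP into their environment. A minimal sketch of how that looks from the user's side (assuming Julia ≥ 1.9 package-extension support):

    using Dagger   # Dagger alone does not pull in JuMP
    using JuMP     # once both are loaded, Julia loads Dagger's JuMPExt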

ext/JuMPExt.jl

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
module JuMPExt

if isdefined(Base, :get_extension)
    using JuMP
else
    using ..JuMP
end

using Dagger
using Dagger.Distributed
import MetricsTracker as MT
import Graphs: edges, nv, outdegree

struct JuMPScheduler
    optimizer
    Z::Float64
    JuMPScheduler(optimizer) = new(optimizer, 10)
end
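
# Build an ahead-of-time (AOT) schedule for a datadeps region: estimate
# per-task runtimes from cached metrics, then solve a MILP that assigns
# each task to exactly one processor while minimizing the DAG's overall
# completion time.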
function Dagger.datadeps_create_schedule(sched::JuMPScheduler, state, specs_tasks)
    astate = state.alias_state
    g, task_to_id = astate.g, astate.task_to_id
    id_to_task = Dict(id => task for (task, id) in task_to_id)
    ntasks = length(specs_tasks)
    nprocs = length(state.all_procs)
    id_to_proc = Dict(i => p for (i, p) in enumerate(state.all_procs))

    # Estimate the time each task will take to execute on each processor,
    # and the time it will take to transfer data between processors
    task_times = zeros(UInt64, ntasks, nprocs)
    xfer_times = zeros(Int, nprocs, nprocs)
    lock(MT.GLOBAL_METRICS_CACHE) do cache
        for (spec, task) in specs_tasks
            id = task_to_id[task]
            for p in 1:nprocs
                # When looking up a runtime estimate, fall back to the closest
                # available one if none was recorded for this exact processor:
                # exact match > same proc type, same node > same proc type, any node > any proc type
                sig = Dagger.Sch.signature(spec.f, map(pos_arg->pos_arg[1] => Dagger.unwrap_inout_value(pos_arg[2]), spec.args))
                proc = state.all_procs[p]
                @warn "Use node, not worker id!" maxlog=1
                pid = Dagger.root_worker_id(proc)

                # Try exact match
                match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
                            MT.LookupExact(Dagger.ProcessorMetric(), proc))
                result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
                if result !== nothing
                    task_times[id, p] = result
                    continue
                end

                # Try same proc type, same node
                match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
                            MT.LookupSubtype(Dagger.ProcessorMetric(), typeof(proc)),
                            MT.LookupCustom(Dagger.ProcessorMetric(), other_proc->Dagger.root_worker_id(other_proc)==pid))
                result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
                if result !== nothing
                    task_times[id, p] = result
                    continue
                end

                # Try same proc type, any node
                match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
                            MT.LookupSubtype(Dagger.ProcessorMetric(), typeof(proc)))
                result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
                if result !== nothing
                    task_times[id, p] = result
                    continue
                end

                # Try any signature match
                match_on = MT.LookupExact(Dagger.SignatureMetric(), sig)
                result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
                if result !== nothing
                    task_times[id, p] = result
                    continue
                end

                # If no information is available, use a random guess
                task_times[id, p] = UInt64(rand(1:1_000_000))
            end
        end

        # FIXME: Actually fill this with estimated xfer times
        @warn "Assuming all xfer times are 1" maxlog=1
        for dst in 1:nprocs
            for src in 1:nprocs
                if src == dst # FIXME: Or if space is shared
                    xfer_times[src, dst] = 0
                else
                    # FIXME: sum(currently non-local task arg size) / xfer_speed
                    xfer_times[src, dst] = 1
                end
            end
        end
    end

    @warn "If no edges exist, this will fail" maxlog=1
    γ = Dict{Tuple{Int, Int}, Matrix{Int}}()
    for (i, j) in Tuple.(edges(g))
        γ[(i, j)] = copy(xfer_times)
    end

    a_kls = Tuple.(edges(g))
    m = Model(sched.optimizer)
    JuMP.set_silent(m)

    # Start time of each task
    @variable(m, t[1:ntasks] >= 0)
    # End time of last task
    @variable(m, t_last_end >= 0)

    # 1 if task k is assigned to proc p
    @variable(m, s[1:ntasks, 1:nprocs], Bin)

    # Each task is assigned to exactly one processor
    @constraint(m, [k in 1:ntasks], sum(s[k, :]) == 1)

    # Penalties for moving between procs
    if length(a_kls) > 0
        @variable(m, p[a_kls] >= 0)

        for (k, l) in a_kls
            for p1 in 1:nprocs
                for p2 in 1:nprocs
                    p1 == p2 && continue
                    # If tasks k and l are assigned to different processors,
                    # the penalty p[(k, l)] covers the transfer cost between them
                    @constraint(m, p[(k, l)] >= (s[k, p1] + s[l, p2] - 1) * γ[(k, l)][p1, p2])
                end
            end

            # Task l starts only after task k finishes, plus any transfer penalty
            @constraint(m, t[k] + task_times[k, :]' * s[k, :] + p[(k, l)] <= t[l])
        end
    else
        # No edges, so no transfer penalties; keep `p` defined for the objective
        @variable(m, p >= 0)
    end

    for l in filter(n -> outdegree(g, n) == 0, 1:nv(g))
        # The DAG ends after its last sink task finishes
        @constraint(m, t[l] + task_times[l, :]' * s[l, :] <= t_last_end)
    end

    # Minimize the total runtime of the DAG
    # TODO: Do we need to bias towards earlier start times?
    @objective(m, Min, sched.Z*t_last_end + sum(t) + sum(p))

    # Solve the model
    optimize!(m)

    # Extract the schedule from the model
    task_to_proc = Dict{DTask, Dagger.Processor}()
    for k in 1:ntasks
        # Solvers may return binary values near (not exactly) 1,
        # so threshold at 0.5 rather than testing equality
        proc_id = findfirst(>(0.5), value.(s[k, :]))
        task_to_proc[id_to_task[k]] = id_to_proc[proc_id]
    end

    return task_to_proc
end

end # module JuMPExt
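
For reference, the model this extension assembles is a standard task-assignment MILP. In the code's notation, with T = task_times, γ the per-edge transfer costs, A the DAG's edge set, q ranging over processors, and Z = sched.Z:

    \begin{aligned}
    \min_{t,\, s,\, p} \quad & Z \cdot t_{\mathrm{end}} + \sum_k t_k + \sum_{(k,l) \in A} p_{kl} \\
    \text{s.t.} \quad & \textstyle\sum_{q} s_{kq} = 1 && \forall k \\
    & p_{kl} \ge (s_{kq_1} + s_{lq_2} - 1)\, \gamma^{kl}_{q_1 q_2} && \forall (k,l) \in A,\; q_1 \ne q_2 \\
    & t_k + \textstyle\sum_q T_{kq}\, s_{kq} + p_{kl} \le t_l && \forall (k,l) \in A \\
    & t_l + \textstyle\sum_q T_{lq}\, s_{lq} \le t_{\mathrm{end}} && \forall \text{ sink tasks } l \\
    & s_{kq} \in \{0, 1\}, \quad t_k,\; t_{\mathrm{end}},\; p_{kl} \ge 0
    \end{aligned}

A hedged sketch of constructing the scheduler: HiGHS is only an illustrative choice of MILP-capable solver, and this commit does not show how a constructed scheduler is handed to a `Dagger.spawn_datadeps` region — `Dagger.datadeps_create_schedule` is the hook it implements:

    using Dagger, JuMP, HiGHS  # HiGHS: example solver choice (assumption)

    # JuMPScheduler lives in the extension module, so fetch it via the
    # standard extension API (Julia ≥ 1.9)
    JuMPExt = Base.get_extension(Dagger, :JuMPExt)
    sched = JuMPExt.JuMPScheduler(HiGHS.Optimizer)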
