Skip to content

Commit bfbd494

Browse files
committed
batch-spanning chunking in threaded execution
1 parent f0bd1e2 commit bfbd494

File tree

4 files changed

+117
-9
lines changed

4 files changed

+117
-9
lines changed

src/construction.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,14 @@ function Network(g::AbstractGraph,
183183
# create map for extenral inputs
184184
extmap = has_external_input(im) ? ExtMap(im) : nothing
185185

186-
nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches),
187-
typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}(
186+
nw = Network(
188187
vertexbatches,
189188
nl, im,
190189
caches,
191190
mass_matrix,
192191
gbufprovider,
193192
extmap,
193+
execution,
194194
)
195195

196196
end

src/coreloop.jl

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,116 @@ end
8888
end
8989
end
9090

91-
@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
92-
unrolled_foreach(filt, batches) do batch
91+
@inline function process_batches!(ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
92+
# unrolled_foreach(filt, batches) do batch
93+
# (du, u, o, p, t) = duopt
94+
# Threads.@threads for i in 1:length(batch)
95+
# _type = dispatchT(batch)
96+
# apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
97+
# end
98+
# end
99+
# return
100+
101+
Nchunks = Threads.nthreads()
102+
# Nchunks = 4
103+
# chunking is kinda expensive, so we cache it
104+
key = hash((Base.objectid(batches), filt, fg, Nchunks))
105+
chunks = get!(ex.chunk_cache, key) do
106+
_chunk_batches(batches, filt, fg, Nchunks)
107+
end
108+
109+
_progress_in_batch = function(batch, ci, processed, N)
93110
(du, u, o, p, t) = duopt
94-
Threads.@threads for i in 1:length(batch)
95-
_type = dispatchT(batch)
96-
apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
111+
_type = dispatchT(batch)
112+
while ci length(batch) && processed < N
113+
apply_comp!(_type, fg, batch, ci, du, u, o, inbufs, p, t)
114+
ci += 1
115+
processed += 1
116+
end
117+
processed, ci
118+
end
119+
120+
Threads.@sync for chunk in chunks
121+
chunk.N == 0 && continue
122+
Threads.@spawn begin
123+
local N = chunk.N
124+
local bi = chunk.batch_i
125+
local ci = chunk.comp_i
126+
local processed = 0
127+
while processed < N
128+
batch = batches[bi]
129+
filt(batch) || continue
130+
processed, ci = @noinline _progress_in_batch(batch, ci, processed, N)
131+
bi += 1
132+
ci = 1
133+
end
134+
end
135+
end
136+
end
137+
function _chunk_batches(batches, filt, fg, workers)
138+
Ncomp = 0
139+
total_eqs = 0
140+
unrolled_foreach(filt, batches) do batch
141+
Ncomp += length(batch)::Int
142+
total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
143+
end
144+
chunks = Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}(undef, workers)
145+
146+
eqs_per_worker = total_eqs / workers
147+
bi = 1
148+
ci = 1
149+
assigned = 0
150+
eqs_assigned = 0
151+
for w in 1:workers
152+
ci_start = ci
153+
bi_start = bi
154+
eqs_in_worker = 0
155+
assigned_in_worker = 0
156+
while assigned < Ncomp
157+
batch = batches[bi]
158+
filt(batch) || continue
159+
160+
Neqs = _N_eqs(fg, batch)
161+
stop_collecting = false
162+
while true
163+
if ci > length(batch)
164+
break
165+
end
166+
167+
# compare, whether adding the new component helps to come closer to eqs_per_worker
168+
diff_now = abs(eqs_in_worker - eqs_per_worker)
169+
diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
170+
stop_collecting = assigned == Ncomp || diff_now diff_next
171+
if stop_collecting
172+
break
173+
end
174+
175+
# add component to worker
176+
eqs_assigned += Neqs
177+
eqs_in_worker += Neqs
178+
assigned_in_worker += 1
179+
assigned += 1
180+
ci += 1
181+
end
182+
# if the hard stop collection is reached, break, otherwise jump to next batch and continue
183+
stop_collecting && break
184+
185+
bi += 1
186+
ci = 1
97187
end
188+
chunk = (; batch_i=bi_start, comp_i=ci_start, N=assigned_in_worker)
189+
chunks[w] = chunk
190+
191+
# update eqs per worker estimate for the other workders
192+
eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w)
98193
end
194+
@assert assigned == Ncomp
195+
return chunks
99196
end
197+
_N_eqs(::Val{:f}, batch) = Int(dim(batch))
198+
_N_eqs(::Val{:g}, batch) = Int(outdim(batch))
199+
_N_eqs(::Val{:fg}, batch) = Int(dim(batch)) + Int(outdim(batch))
200+
100201

101202
@inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F}
102203
unrolled_foreach(filt, batches) do batch

src/executionstyles.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ struct PolyesterExecution{buffered} <: ExecutionStyle{buffered} end
4141
Parallel execution using Julia threads.
4242
For `buffered` see [`ExecutionStyle`](@ref).
4343
"""
44-
struct ThreadedExecution{buffered} <: ExecutionStyle{buffered} end
44+
@kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered}
45+
chunk_cache::Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}=
46+
Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}()
47+
end
4548

4649
usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered
4750
usebuffer(::Type{<:ExecutionStyle{buffered}}) where {buffered} = buffered

src/network_structure.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT,EM}
8383
gbufprovider::GBT
8484
"map to gather external inputs"
8585
extmap::EM
86+
"execution style"
87+
executionstyle::EX
8688
end
87-
executionstyle(::Network{ex}) where {ex} = ex()
89+
executionstyle(nw::Network) = nw.executionstyle
8890
nvbatches(::Network) = length(vertexbatches)
8991

9092
"""
@@ -164,6 +166,8 @@ end
164166
@inline compf(b::ComponentBatch) = b.compf
165167
@inline compg(b::ComponentBatch) = b.compg
166168
@inline fftype(b::ComponentBatch) = b.ff
169+
@inline dim(b::ComponentBatch) = sum(b.statestride.strides)
170+
@inline outdim(b::ComponentBatch) = sum(b.outbufstride.strides)
167171
@inline pdim(b::ComponentBatch) = b.pstride.strides
168172
@inline extdim(b::ComponentBatch) = b.extbufstride.strides
169173

0 commit comments

Comments
 (0)