batch-spanning chunking in threaded execution #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 5 commits into base: main
6 changes: 2 additions & 4 deletions ext/CUDAExt.jl
@@ -34,12 +34,10 @@ function Adapt.adapt_structure(to, n::Network)
     caches = (;output = _adapt_diffcache(to, n.caches.output),
               aggregation = _adapt_diffcache(to, n.caches.aggregation),
               external = _adapt_diffcache(to, n.caches.external))
-    exT = typeof(executionstyle(n))
-    gT = typeof(n.im.g)
+    ex = executionstyle(n)
     extmap = adapt(to, n.extmap)

-    Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}(
-        vb, layer, n.im, caches, mm, gbp, extmap)
+    Network(vb, layer, n.im, caches, mm, gbp, extmap, ex)
 end

 Adapt.@adapt_structure NetworkLayer
4 changes: 2 additions & 2 deletions src/construction.jl
@@ -183,14 +183,14 @@ function Network(g::AbstractGraph,
     # create map for external inputs
     extmap = has_external_input(im) ? ExtMap(im) : nothing

-    nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches),
-                 typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}(
+    nw = Network(
         vertexbatches,
         nl, im,
         caches,
         mass_matrix,
         gbufprovider,
         extmap,
+        execution,
     )

 end
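Both constructor call sites above now let Julia infer the `Network` type parameters from the argument values instead of spelling them out, which is what makes it possible to carry the execution style as a plain field. A minimal, generic sketch of that pattern; the `Box` struct and its fields are purely illustrative and not part of NetworkDynamics:

```julia
# Illustrative only: type parameters can be inferred from argument values,
# so call sites don't have to repeat them.
struct Box{S,T}
    style::S
    data::T
end

# explicit parameters at the call site (old style):
b1 = Box{Symbol,Vector{Int}}(:threaded, [1, 2, 3])

# parameters inferred from the values (new style):
b2 = Box(:threaded, [1, 2, 3])

b1 isa Box{Symbol,Vector{Int}} && b2 isa Box{Symbol,Vector{Int}}  # true
```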
107 changes: 102 additions & 5 deletions src/coreloop.jl
@@ -88,15 +88,112 @@ end
     end
 end

-@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
-    unrolled_foreach(filt, batches) do batch
-        (du, u, o, p, t) = duopt
-        Threads.@threads for i in 1:length(batch)
+@inline function process_batches!(ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
+    Nchunks = Threads.nthreads()
+
+    # chunking is kinda expensive, so we cache it
+    key = hash((Base.objectid(batches), filt, fg, Nchunks))
+    chunks = get!(ex.chunk_cache, key) do
+        _chunk_batches(batches, filt, fg, Nchunks)
+    end
+
+    # each chunk is an array or tuple of (batch, idxs) pairs
+    _eval_chunk = function(chunk)
+        unrolled_foreach(chunk) do ch
+            (; batch, idxs) = ch
+            (du, u, o, p, t) = duopt
             _type = dispatchT(batch)
-            apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
+            for i in idxs
+                apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
+            end
         end
     end
+    Threads.@sync for chunk in chunks
+        Threads.@spawn begin
+            @noinline _eval_chunk(chunk)
+        end
+    end
 end
+function _chunk_batches(batches, filt, fg, workers)
+    Ncomp = 0
+    total_eqs = 0
+    unrolled_foreach(filt, batches) do batch
+        Ncomp += length(batch)::Int
+        total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
+    end
+    chunks = Vector{Any}(undef, workers)
+
+    eqs_per_worker = total_eqs / workers
+    # println("\nTotal eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)")
+    bi = 1
+    ci = 1
+    assigned = 0
+    eqs_assigned = 0
+    for w in 1:workers
+        # println("Assign worker $w: goal: $eqs_per_worker")
+        chunk = Vector{Any}()
+        eqs_in_worker = 0
+        assigned_in_worker = 0
+        while assigned < Ncomp
+            batch = batches[bi]
+
+            if filt(batch) # only process if the batch is not filtered out
+                ci_start = ci
+                Neqs = _N_eqs(fg, batch)
+                stop_collecting = false
+                while true
+                    if ci > length(batch)
+                        break
+                    end
+
+                    # check whether adding the next component brings us closer to eqs_per_worker
+                    diff_now = abs(eqs_in_worker - eqs_per_worker)
+                    diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
+                    stop_collecting = assigned == Ncomp || diff_now < diff_next
+                    if stop_collecting
+                        break
+                    end
+
+                    # add component to worker
+                    # println(" - Assign component $ci ($Neqs eqs)")
+                    eqs_assigned += Neqs
+                    eqs_in_worker += Neqs
+                    assigned_in_worker += 1
+                    assigned += 1
+                    ci += 1
+                end
+                if ci > ci_start # don't push empty chunks
+                    # println(" - Assign batch $(bi) -> $(ci_start:(ci-1)) $(length(ci_start:(ci-1))*Neqs) eqs)")
+                    push!(chunk, (; batch, idxs=ci_start:(ci-1)))
+                else
+                    # println(" - Skip empty batch $(bi) -> $(ci_start:(ci-1))")
+                end
+                stop_collecting && break
+            else
+                # println(" - Skip batch $(bi)")
+            end
+
+            bi += 1
+            ci = 1
+        end
+
+        # narrow down type / make tuple
+        chunks[w] = if length(chunk) < 10
+            Tuple(chunk)
+        else
+            [c for c in chunk] # narrow down type
+        end
+
+        # update eqs per worker estimate for the other workers
+        eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w)
+    end
+    @assert assigned == Ncomp
+    return chunks
+end
+_N_eqs(::Val{:f}, batch) = Int(dim(batch))
+_N_eqs(::Val{:g}, batch) = Int(outdim(batch))
+_N_eqs(::Val{:fg}, batch) = Int(dim(batch)) + Int(outdim(batch))
+

 @inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F}
     unrolled_foreach(filt, batches) do batch
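For readers skimming the diff: `_chunk_batches` walks the filtered batches once, tallies equations per component via `_N_eqs`, and then greedily hands contiguous index ranges to each worker until that worker is as close as possible to the per-worker equation target, re-estimating the target after every worker. Because one worker may receive ranges from several batches, a single spawned task can sweep across batch boundaries instead of synchronizing after every batch as the old `Threads.@threads` loop did. The standalone sketch below reproduces the balancing idea on plain integer workloads; `greedy_chunks` and the example data are illustrative only and not part of the package:

```julia
# Illustrative sketch of equation-balanced chunking across batches.
# Each batch is a vector of per-component equation counts; a chunk is a
# vector of (batch index, component range) pairs.
function greedy_chunks(batches::Vector{Vector{Int}}, workers::Int)
    total_eqs = sum(sum, batches)
    target = total_eqs / workers
    chunks = [Vector{Tuple{Int,UnitRange{Int}}}() for _ in 1:workers]
    bi, ci = 1, 1
    assigned_eqs = 0
    for w in 1:workers
        eqs_in_worker = 0
        while bi <= length(batches)
            batch = batches[bi]
            ci_start = ci
            # take components while they bring us closer to the target
            while ci <= length(batch) &&
                  abs(eqs_in_worker + batch[ci] - target) <= abs(eqs_in_worker - target)
                eqs_in_worker += batch[ci]
                assigned_eqs += batch[ci]
                ci += 1
            end
            ci > ci_start && push!(chunks[w], (bi, ci_start:ci-1))
            ci <= length(batch) && break  # worker is "full"; keep the position in this batch
            bi += 1; ci = 1               # batch exhausted; continue with the next one
        end
        # re-estimate the target for the remaining workers
        w < workers && (target = (total_eqs - assigned_eqs) / (workers - w))
    end
    return chunks
end

# Example: 3 batches with different per-component sizes, split for 2 workers.
greedy_chunks([[2, 2, 2], [5, 5], [1, 1, 1, 1]], 2)
```

In the actual PR the resulting chunks are cached in `ex.chunk_cache`, keyed on a hash of the batch object id, the filter, the f/g selector, and the thread count, so the partitioning cost is paid only on the first right-hand-side evaluation.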
4 changes: 3 additions & 1 deletion src/executionstyles.jl
@@ -41,7 +41,9 @@ struct PolyesterExecution{buffered} <: ExecutionStyle{buffered} end
 Parallel execution using Julia threads.
 For `buffered` see [`ExecutionStyle`](@ref).
 """
-struct ThreadedExecution{buffered} <: ExecutionStyle{buffered} end
+@kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered}
+    chunk_cache::Dict{UInt, Vector} = Dict{UInt, Vector}()
+end

 usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered
 usebuffer(::Type{<:ExecutionStyle{buffered}}) where {buffered} = buffered
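Since the execution style now owns the chunk cache, each `ThreadedExecution` instance memoizes its own partitioning. A small usage sketch, assuming `ThreadedExecution` remains exported and the `execution` keyword of the `Network` constructor is unchanged:

```julia
using NetworkDynamics

# A buffered threaded execution style; @kwdef supplies the empty chunk cache.
ex = ThreadedExecution{true}()
@assert isempty(ex.chunk_cache)

# Passed to the Network constructor as before, e.g.
#   nw = Network(g, vertex_models, edge_models; execution=ex)
# the first RHS evaluation fills ex.chunk_cache; subsequent calls reuse it.
```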
6 changes: 5 additions & 1 deletion src/network_structure.jl
@@ -83,8 +83,10 @@ struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT,EM}
     gbufprovider::GBT
     "map to gather external inputs"
     extmap::EM
+    "execution style"
+    executionstyle::EX
 end
-executionstyle(::Network{ex}) where {ex} = ex()
+executionstyle(nw::Network) = nw.executionstyle
 nvbatches(::Network) = length(vertexbatches)

 """
@@ -164,6 +166,8 @@ end
 @inline compf(b::ComponentBatch) = b.compf
 @inline compg(b::ComponentBatch) = b.compg
 @inline fftype(b::ComponentBatch) = b.ff
+@inline dim(b::ComponentBatch) = sum(b.statestride.strides)
+@inline outdim(b::ComponentBatch) = sum(b.outbufstride.strides)
 @inline pdim(b::ComponentBatch) = b.pstride.strides
 @inline extdim(b::ComponentBatch) = b.extbufstride.strides