|
88 | 88 | end
|
89 | 89 | end
|
90 | 90 |
|
91 |
"""
    process_batches!(ex::ThreadedExecution, fg, filt, batches, inbufs, duopt)

Call `apply_comp!` for the components of the batches accepted by `filt`, using one
spawned task per non-empty chunk.

The chunks are obtained from `_chunk_batches` and cached in `ex.chunk_cache`, keyed by a
hash of `(objectid(batches), filt, fg, Threads.nthreads())`. Each chunk is a named tuple
`(; batch_i, comp_i, N)`: start at component `comp_i` of `batches[batch_i]` and process
`N` components, continuing with component 1 of the following batches as needed.

`duopt` is destructured as `(du, u, o, p, t)` and forwarded to `apply_comp!`.
Returns `nothing`.
"""
@inline function process_batches!(
    ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt
) where {F}
    Nchunks = Threads.nthreads()
    # Chunking is comparatively expensive, so the result is cached on the executor.
    # NOTE(review): the key relies on `objectid(batches)`; confirm that `batches` outlives
    # the cache so that a stale entry can never be picked up for a different object.
    key = hash((Base.objectid(batches), filt, fg, Nchunks))
    chunks = get!(ex.chunk_cache, key) do
        _chunk_batches(batches, filt, fg, Nchunks)
    end

    # Process components `ci, ci+1, ...` of a single batch until either the batch is
    # exhausted or the chunk quota `N` is reached. Returns the updated counters.
    _progress_in_batch = function (batch, ci, processed, N)
        (du, u, o, p, t) = duopt
        _type = dispatchT(batch)
        while ci ≤ length(batch) && processed < N
            apply_comp!(_type, fg, batch, ci, du, u, o, inbufs, p, t)
            ci += 1
            processed += 1
        end
        return processed, ci
    end

    Threads.@sync for chunk in chunks
        chunk.N == 0 && continue
        Threads.@spawn begin
            local N = chunk.N
            local bi = chunk.batch_i
            local ci = chunk.comp_i
            local processed = 0
            while processed < N
                batch = batches[bi]
                if filt(batch)
                    # The call is kept behind `@noinline` so it acts as a function barrier
                    # for the batch, which is selected by a runtime index.
                    processed, ci = @noinline _progress_in_batch(batch, ci, processed, N)
                end
                # Always move on to the next batch, also when `filt` rejected this one;
                # skipping the increment for rejected batches would spin forever.
                bi += 1
                ci = 1
            end
        end
    end
    return nothing
end
"""
    _chunk_batches(batches, filt, fg, workers)

Split the components of all batches accepted by `filt` into `workers` contiguous chunks
whose equation counts are as close as possible to an equal share.

The weight of a component is `_N_eqs(fg, batch)` of the batch it belongs to. After each
worker the target share is re-estimated from the equations that are still unassigned.

Returns a `Vector` of length `workers` holding named tuples `(; batch_i, comp_i, N)`:
the chunk starts at component `comp_i` of `batches[batch_i]` and spans `N` components,
continuing at component 1 of the following accepted batches. For a chunk with `N == 0`
the start indices carry no meaning.
"""
function _chunk_batches(batches, filt, fg, workers)
    Ncomp = 0
    total_eqs = 0
    unrolled_foreach(filt, batches) do batch
        Ncomp += length(batch)::Int
        total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
    end
    chunks = Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}(undef, workers)

    eqs_per_worker = total_eqs / workers
    bi = 1
    ci = 1
    assigned = 0
    eqs_assigned = 0
    for w in 1:workers
        bi_start = bi
        ci_start = ci
        eqs_in_worker = 0
        assigned_in_worker = 0
        while assigned < Ncomp
            batch = batches[bi]
            if !filt(batch)
                # Rejected batches contribute nothing; step over them. A bare `continue`
                # here would revisit the same batch forever.
                bi += 1
                ci = 1
                continue
            end

            Neqs = _N_eqs(fg, batch)
            stop_collecting = false
            while ci ≤ length(batch)
                # Check whether adding the next component brings this worker closer to
                # `eqs_per_worker`. Two cases never stop the collection:
                #  - the last worker has to take everything that is left,
                #  - components without equations do not change the balance, and
                #    stopping on them would stall every remaining worker.
                diff_now = abs(eqs_in_worker - eqs_per_worker)
                diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
                stop_collecting = w < workers && Neqs > 0 && diff_now ≤ diff_next
                stop_collecting && break

                # Record the chunk start at the first component actually assigned, so
                # it never points at a rejected or exhausted batch.
                if assigned_in_worker == 0
                    bi_start = bi
                    ci_start = ci
                end

                # add component to worker
                eqs_assigned += Neqs
                eqs_in_worker += Neqs
                assigned_in_worker += 1
                assigned += 1
                ci += 1
            end
            # If the worker is full, hand over to the next worker; otherwise the batch
            # is exhausted and collection continues in the next batch.
            stop_collecting && break

            bi += 1
            ci = 1
        end
        chunks[w] = (; batch_i=bi_start, comp_i=ci_start, N=assigned_in_worker)

        # Update the per-worker estimate for the remaining workers (none are left after
        # the last one, which would otherwise divide by zero).
        if w < workers
            eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w)
        end
    end
    @assert assigned == Ncomp
    return chunks
end
|
# Number of equations per component of `batch`, selected by the `Val` tag:
# `:f` counts `dim(batch)`, `:g` counts `outdim(batch)`, `:fg` counts both.
function _N_eqs(::Val{:f}, batch)
    return Int(dim(batch))
end

function _N_eqs(::Val{:g}, batch)
    return Int(outdim(batch))
end

function _N_eqs(::Val{:fg}, batch)
    nstate = Int(dim(batch))
    nout = Int(outdim(batch))
    return nstate + nout
end
| 200 | + |
100 | 201 |
|
101 | 202 | @inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F}
|
102 | 203 | unrolled_foreach(filt, batches) do batch
|
|
0 commit comments