    # return

    Nchunks = Threads.nthreads()
-   # Nchunks = 4
+   # Nchunks = 8
    # chunking is kinda expensive, so we cache it
    key = hash((Base.objectid(batches), filt, fg, Nchunks))
    chunks = get!(ex.chunk_cache, key) do
        _chunk_batches(batches, filt, fg, Nchunks)
    end

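The cache key hashes `Base.objectid(batches)` together with `filt`, `fg` and the chunk count, so re-chunking only happens when one of those changes; mutating the contents of the same `batches` vector in place does not change its `objectid` and therefore keeps the cached entry. A minimal sketch of the lookup pattern, assuming a plain `Dict`-backed cache (the concrete type of `ex.chunk_cache` is not shown in this diff, and `chunks_for` is a hypothetical helper for illustration):

    chunk_cache = Dict{UInt,Any}()                  # hypothetical stand-in for ex.chunk_cache
    function chunks_for(batches, filt, fg, Nchunks)
        key = hash((Base.objectid(batches), filt, fg, Nchunks))
        get!(chunk_cache, key) do
            "chunks for $Nchunks workers"           # placeholder for the expensive _chunk_batches call
        end
    end
    batches = [1:10, 11:20]
    chunks_for(batches, identity, nothing, 4)       # computes and stores
    chunks_for(batches, identity, nothing, 4)       # same key -> cache hit, closure not run
    chunks_for(batches, identity, nothing, 8)       # different Nchunks -> new key, recomputed
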
-   _progress_in_batch = function (batch, ci, processed, N)
+   _eval_batchportion = function (batch, idxs)
        (du, u, o, p, t) = duopt
        _type = dispatchT(batch)
-       while ci ≤ length(batch) && processed < N
-           apply_comp!(_type, fg, batch, ci, du, u, o, inbufs, p, t)
-           ci += 1
-           processed += 1
+       for i in idxs
+           apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
        end
-       processed, ci
    end

    Threads.@sync for chunk in chunks
-       chunk.N == 0 && continue
+       isempty(chunk) && continue
        Threads.@spawn begin
-           local N = chunk.N
-           local bi = chunk.batch_i
-           local ci = chunk.comp_i
-           local processed = 0
-           while processed < N
-               batch = batches[bi]
-               filt(batch) || continue
-               processed, ci = @noinline _progress_in_batch(batch, ci, processed, N)
-               bi += 1
-               ci = 1
+           for (; bi, idxs) in chunk
+               batch = batches[bi] # filtering is already done during chunking
+               @noinline _eval_batchportion(batch, idxs)
            end
        end
    end
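With this change a chunk is no longer a (start batch, start component, count) cursor but a vector of `(bi, idxs)` entries: one contiguous component range per batch that the worker touches. A small sketch of what one chunk might look like and how the spawned task walks it (the batch indices and range bounds are made up for illustration):

    ChunkEntry = @NamedTuple{bi::Int, idxs::UnitRange{Int64}}
    chunk = ChunkEntry[(; bi=1, idxs=37:120),    # tail end of batch 1
                       (; bi=2, idxs=1:15)]      # first part of batch 2
    for (; bi, idxs) in chunk
        # the spawned task evaluates exactly these ranges, nothing else
        println("batch ", bi, ": components ", idxs)
    end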
@@ -141,51 +131,53 @@ function _chunk_batches(batches, filt, fg, workers)
        Ncomp += length(batch)::Int
        total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
    end
-   chunks = Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}(undef, workers)
+   chunks = Vector{Vector{@NamedTuple{bi::Int, idxs::UnitRange{Int64}}}}(undef, workers)

    eqs_per_worker = total_eqs / workers
    bi = 1
    ci = 1
    assigned = 0
    eqs_assigned = 0
    for w in 1:workers
+       chunk = @NamedTuple{bi::Int, idxs::UnitRange{Int64}}[]
        ci_start = ci
-       bi_start = bi
        eqs_in_worker = 0
        assigned_in_worker = 0
        while assigned < Ncomp
            batch = batches[bi]
-           filt(batch) || continue

-           Neqs = _N_eqs(fg, batch)
-           stop_collecting = false
-           while true
-               if ci > length(batch)
-                   break
+           if filt(batch) # only process if the batch is not filtered out
+               Neqs = _N_eqs(fg, batch)
+               stop_collecting = false
+               while true
+                   if ci > length(batch)
+                       break
+                   end
+
+                   # compare whether adding the new component helps to come closer to eqs_per_worker
+                   diff_now = abs(eqs_in_worker - eqs_per_worker)
+                   diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
+                   stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
+                   if stop_collecting
+                       break
+                   end
+
+                   # add component to worker
+                   eqs_assigned += Neqs
+                   eqs_in_worker += Neqs
+                   assigned_in_worker += 1
+                   assigned += 1
+                   ci += 1
                end
-
-               # compare, whether adding the new component helps to come closer to eqs_per_worker
-               diff_now = abs(eqs_in_worker - eqs_per_worker)
-               diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
-               stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
-               if stop_collecting
-                   break
+               if ci > ci_start # don't push empty chunks
+                   push!(chunk, (; bi, idxs=ci_start:ci-1))
                end
-
-               # add component to worker
-               eqs_assigned += Neqs
-               eqs_in_worker += Neqs
-               assigned_in_worker += 1
-               assigned += 1
-               ci += 1
+               stop_collecting && break
            end
-           # if the hard stop collection is reached, break, otherwise jump to next batch and continue
-           stop_collecting && break

            bi += 1
            ci = 1
        end
-       chunk = (; batch_i=bi_start, comp_i=ci_start, N=assigned_in_worker)
        chunks[w] = chunk

        # update eqs per worker estimate for the other workers
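The stopping rule is greedy: a worker keeps taking components from the current position as long as doing so moves its equation count closer to the `eqs_per_worker` target. A small numeric illustration of the same comparison (the numbers are made up):

    eqs_per_worker = 100.0
    Neqs = 30                                        # equations added by one more component
    for eqs_in_worker in (60, 90)
        diff_now  = abs(eqs_in_worker - eqs_per_worker)          # distance to target if we stop now
        diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)   # distance if we take one more
        println(eqs_in_worker, " eqs -> stop_collecting = ", diff_now ≤ diff_next)
    end
    # 60 eqs -> stop_collecting = false   (90 is closer to the target than 60)
    # 90 eqs -> stop_collecting = true    (120 would overshoot further than 90 undershoots)

After a worker is filled, the target is re-estimated for the remaining workers (the "update eqs per worker estimate" step above), presumably so that over- or under-filling early workers does not all accumulate on the last one.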