Skip to content

Commit 8c6f9cf

Browse files
committed
fix adapt and make chunking slighly greedy
1 parent ffaa103 commit 8c6f9cf

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

ext/CUDAExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ function Adapt.adapt_structure(to, n::Network)
3434
caches = (;output = _adapt_diffcache(to, n.caches.output),
3535
aggregation = _adapt_diffcache(to, n.caches.aggregation),
3636
external = _adapt_diffcache(to, n.caches.external))
37-
exT = typeof(executionstyle(n))
38-
gT = typeof(n.im.g)
37+
ex = executionstyle(n)
3938
extmap = adapt(to, n.extmap)
4039

41-
Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}(
42-
vb, layer, n.im, caches, mm, gbp, extmap)
40+
Network(vb, layer, n.im, caches, mm, gbp, extmap, ex)
4341
end
4442

4543
Adapt.@adapt_structure NetworkLayer

src/coreloop.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ function _chunk_batches(batches, filt, fg, workers)
124124
chunks = Vector{Any}(undef, workers)
125125

126126
eqs_per_worker = total_eqs / workers
127+
# println("Total eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)")
127128
bi = 1
128129
ci = 1
129130
assigned = 0
130131
eqs_assigned = 0
131132
for w in 1:workers
133+
# println("Assign worker $w: goal: $eqs_per_worker")
132134
chunk = Vector{Any}()
133135
ci_start = ci
134136
eqs_in_worker = 0
@@ -147,22 +149,28 @@ function _chunk_batches(batches, filt, fg, workers)
147149
# compare, whether adding the new component helps to come closer to eqs_per_worker
148150
diff_now = abs(eqs_in_worker - eqs_per_worker)
149151
diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
150-
stop_collecting = assigned == Ncomp || diff_now diff_next
152+
stop_collecting = assigned == Ncomp || diff_now < diff_next
151153
if stop_collecting
152154
break
153155
end
154156

155157
# add component to worker
158+
# println(" - Assign component $ci ($Neqs eqs)")
156159
eqs_assigned += Neqs
157160
eqs_in_worker += Neqs
158161
assigned_in_worker += 1
159162
assigned += 1
160163
ci += 1
161164
end
162165
if ci > ci_start # don't push empty chunks
163-
push!(chunk, (; batch, idxs=ci_start:ci-1))
166+
# println(" - Assign batch $(bi) -> $(ci_start:(ci-1)) $(length(ci_start:(ci-1))*Neqs) eqs)")
167+
push!(chunk, (; batch, idxs=ci_start:(ci-1)))
168+
else
169+
# println(" - Skip empty batch $(bi) -> $(ci_start:(ci-1))")
164170
end
165171
stop_collecting && break
172+
else
173+
# println(" - Skip batch $(bi)")
166174
end
167175

168176
bi += 1

0 commit comments

Comments
 (0)