Skip to content

Commit 94fd312

Browse files
authored
Add :greedy scheduler to @threads (#52096)
1 parent 353884c commit 94fd312

File tree

3 files changed

+133
-10
lines changed

3 files changed

+133
-10
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ difference between defining a `main` function and executing the code directly at
7171
Multi-threading changes
7272
-----------------------
7373

74+
* `Threads.@threads` now supports the `:greedy` scheduler, intended for non-uniform workloads ([#52096]).
75+
7476
Build system changes
7577
--------------------
7678

base/threadingconstructs.jl

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,46 @@ end
176176
function _threadsfor(iter, lbody, schedule)
177177
lidx = iter.args[1] # index
178178
range = iter.args[2]
179+
esc_range = esc(range)
180+
func = if schedule === :greedy
181+
greedy_func(esc_range, lidx, lbody)
182+
else
183+
default_func(esc_range, lidx, lbody)
184+
end
179185
quote
180186
local threadsfor_fun
181-
let range = $(esc(range))
187+
$func
188+
if $(schedule === :greedy || schedule === :dynamic || schedule === :default)
189+
threading_run(threadsfor_fun, false)
190+
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
191+
error("`@threads :static` cannot be used concurrently or nested")
192+
else # :static
193+
threading_run(threadsfor_fun, true)
194+
end
195+
nothing
196+
end
197+
end
198+
199+
function greedy_func(itr, lidx, lbody)
200+
quote
201+
let c = Channel{eltype($itr)}(0,spawn=true) do ch
202+
for item in $itr
203+
put!(ch, item)
204+
end
205+
end
206+
function threadsfor_fun(tid)
207+
for item in c
208+
local $(esc(lidx)) = item
209+
$(esc(lbody))
210+
end
211+
end
212+
end
213+
end
214+
end
215+
216+
function default_func(itr, lidx, lbody)
217+
quote
218+
let range = $itr
182219
function threadsfor_fun(tid = 1; onethread = false)
183220
r = range # Load into local variable
184221
lenr = length(r)
@@ -216,14 +253,6 @@ function _threadsfor(iter, lbody, schedule)
216253
end
217254
end
218255
end
219-
if $(schedule === :dynamic || schedule === :default)
220-
threading_run(threadsfor_fun, false)
221-
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
222-
error("`@threads :static` cannot be used concurrently or nested")
223-
else # :static
224-
threading_run(threadsfor_fun, true)
225-
end
226-
nothing
227256
end
228257
end
229258

@@ -289,6 +318,20 @@ microseconds).
289318
!!! compat "Julia 1.8"
290319
The `:dynamic` option for the `schedule` argument is available and the default as of Julia 1.8.
291320
321+
### `:greedy`
322+
323+
`:greedy` scheduler spawns up to [`Threads.threadpoolsize()`](@ref) tasks, each greedily working on
324+
the given iterated values as they are produced. As soon as one task finishes its work, it takes
325+
the next value from the iterator. Work done by any individual task is not necessarily on
326+
contiguous values from the iterator. The given iterator may produce values forever, only the
327+
iterator interface is required (no indexing).
328+
329+
This scheduling option is generally a good choice if the workload of individual iterations
330+
is not uniform/has a large spread.
331+
332+
!!! compat "Julia 1.11"
333+
The `:greedy` option for the `schedule` argument is available as of Julia 1.11.
334+
292335
### `:static`
293336
294337
`:static` scheduler creates one task per thread and divides the iterations equally among
@@ -344,7 +387,7 @@ macro threads(args...)
344387
# for now only allow quoted symbols
345388
sched = nothing
346389
end
347-
if sched !== :static && sched !== :dynamic
390+
if sched !== :static && sched !== :dynamic && sched !== :greedy
348391
throw(ArgumentError("unsupported schedule argument in @threads"))
349392
end
350393
elseif na == 1

test/threads_exec.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,84 @@ function _atthreads_dynamic_with_error(a)
803803
end
804804
@test_throws "user error in the loop body" _atthreads_dynamic_with_error(zeros(threadpoolsize()))
805805

806+
####
807+
# :greedy
808+
###
809+
810+
function _atthreads_greedy_schedule(n)
811+
inc = Threads.Atomic{Int}(0)
812+
flags = zeros(Int, n)
813+
Threads.@threads :greedy for i = 1:n
814+
Threads.atomic_add!(inc, 1)
815+
flags[i] = 1
816+
end
817+
return inc[], flags
818+
end
819+
@test _atthreads_greedy_schedule(threadpoolsize()) == (threadpoolsize(), ones(threadpoolsize()))
820+
@test _atthreads_greedy_schedule(1) == (1, ones(1))
821+
@test _atthreads_greedy_schedule(10) == (10, ones(10))
822+
@test _atthreads_greedy_schedule(threadpoolsize() * 2) == (threadpoolsize() * 2, ones(threadpoolsize() * 2))
823+
824+
# nested greedy schedule
825+
function _atthreads_greedy_greedy_schedule()
826+
inc = Threads.Atomic{Int}(0)
827+
Threads.@threads :greedy for _ = 1:threadpoolsize()
828+
Threads.@threads :greedy for _ = 1:threadpoolsize()
829+
Threads.atomic_add!(inc, 1)
830+
end
831+
end
832+
return inc[]
833+
end
834+
@test _atthreads_greedy_greedy_schedule() == threadpoolsize() * threadpoolsize()
835+
836+
function _atthreads_greedy_dynamic_schedule()
837+
inc = Threads.Atomic{Int}(0)
838+
Threads.@threads :greedy for _ = 1:threadpoolsize()
839+
Threads.@threads :dynamic for _ = 1:threadpoolsize()
840+
Threads.atomic_add!(inc, 1)
841+
end
842+
end
843+
return inc[]
844+
end
845+
@test _atthreads_greedy_dynamic_schedule() == threadpoolsize() * threadpoolsize()
846+
847+
function _atthreads_dymamic_greedy_schedule()
848+
inc = Threads.Atomic{Int}(0)
849+
Threads.@threads :dynamic for _ = 1:threadpoolsize()
850+
Threads.@threads :greedy for _ = 1:threadpoolsize()
851+
Threads.atomic_add!(inc, 1)
852+
end
853+
end
854+
return inc[]
855+
end
856+
@test _atthreads_dymamic_greedy_schedule() == threadpoolsize() * threadpoolsize()
857+
858+
function _atthreads_static_greedy_schedule()
859+
ids = zeros(Int, threadpoolsize())
860+
inc = Threads.Atomic{Int}(0)
861+
Threads.@threads :static for i = 1:threadpoolsize()
862+
ids[i] = Threads.threadid()
863+
Threads.@threads :greedy for _ = 1:threadpoolsize()
864+
Threads.atomic_add!(inc, 1)
865+
end
866+
end
867+
return ids, inc[]
868+
end
869+
@test _atthreads_static_greedy_schedule() == (1:threadpoolsize(), threadpoolsize() * threadpoolsize())
870+
871+
# errors inside @threads :greedy
872+
function _atthreads_greedy_with_error(a)
873+
Threads.@threads :greedy for i in eachindex(a)
874+
error("user error in the loop body")
875+
end
876+
a
877+
end
878+
@test_throws "user error in the loop body" _atthreads_greedy_with_error(zeros(threadpoolsize()))
879+
880+
####
881+
# multi-argument loop
882+
####
883+
806884
try
807885
@macroexpand @threads(for i = 1:10, j = 1:10; end)
808886
catch ex

0 commit comments

Comments
 (0)