Skip to content

Commit fea28a0

Browse files
Make @threads work on array comprehensions
1 parent 11eeed3 commit fea28a0

File tree

3 files changed

+356
-37
lines changed

3 files changed

+356
-37
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Multi-threading changes
3737
the first time it is called, and then always return the same result value of type `T`
3838
every subsequent time afterwards. There are also `OncePerThread{T}` and `OncePerTask{T}` types for
3939
similar usage with threads or tasks. ([#TBD])
40+
* `Threads.@threads` now supports array comprehensions with syntax like `@threads [f(i) for i in 1:n]`
41+
and filtered comprehensions like `@threads [f(i) for i in 1:n if condition(i)]`. All scheduling
42+
options (`:static`, `:dynamic`, `:greedy`) are supported. Results preserve element order for
43+
`:static` and `:dynamic` scheduling, while `:greedy` may return elements in arbitrary order for
44+
better performance. ([#TBD])
4045

4146
Build system changes
4247
--------------------

base/threadingconstructs.jl

Lines changed: 249 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,117 @@ function _threadsfor(iter, lbody, schedule)
220220
end
221221
end
222222

223+
# Lower the generator wrapped by `@threads [body for x in itr]` (optionally with a
# trailing `if cond`) to the matching threaded implementation.
#
# `gen` is the first argument of the `:comprehension` expression and `schedule` is
# the scheduling symbol, forwarded unchanged. Only a single `x in itr` clause is
# handled by the lowering below, so every other shape is rejected with an
# `ArgumentError` instead of having its extra iterators silently dropped.
function _threadsfor_comprehension(gen::Expr, schedule)
    # `[... for i in a for j in b]` is wrapped in `:flatten`, not a plain `:generator`
    if gen.head !== :generator
        throw(ArgumentError("nested comprehensions are not currently supported by @threads"))
    end
    # `[... for i in a, j in b]` carries one extra argument per additional iterator
    if length(gen.args) != 2
        throw(ArgumentError("comprehensions with multiple iterators are not currently supported by @threads"))
    end

    body = gen.args[1]
    iter_or_filter = gen.args[2]

    # Handle filtered vs non-filtered comprehensions
    if isa(iter_or_filter, Expr) && iter_or_filter.head === :filter
        # a `:filter` node is `(condition, iterators...)`; accept exactly one iterator
        if length(iter_or_filter.args) != 2
            throw(ArgumentError("comprehensions with multiple iterators are not currently supported by @threads"))
        end
        condition = iter_or_filter.args[1]
        iterator = iter_or_filter.args[2]
        if !(iterator isa Expr && iterator.head === :(=))
            throw(ArgumentError("@threads requires a comprehension of the form `[expr for x in iter]`"))
        end
        return _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
    else
        iterator = iter_or_filter
        if !(iterator isa Expr && iterator.head === :(=))
            throw(ArgumentError("@threads requires a comprehension of the form `[expr for x in iter]`"))
        end
        return _threadsfor_simple_comprehension(body, iterator, schedule)
    end
end
239+
240+
# Build the expansion for an unfiltered `@threads [body for x in itr]`.
#
# `iterator` is the `x = itr` expression taken from the generator, `body` is the
# element expression and `schedule` selects the strategy:
#   * `:greedy` feeds items through a `Channel`; each task gathers its own results
#     and the per-task vectors are concatenated, so element order is unspecified.
#   * anything else pre-allocates one slot per element (see
#     `default_comprehension_func`) and therefore keeps the iteration order.
function _threadsfor_simple_comprehension(body, iterator, schedule)
    lidx = iterator.args[1] # index variable
    range = iterator.args[2] # range/iterable
    esc_range = esc(range)
    esc_body = esc(body)

    if schedule === :greedy
        quote
            # Splice the user's iterable expression once only, so that an expression
            # with side effects (e.g. one constructing a channel) runs a single time.
            local itr = $esc_range
            local ch = Channel{eltype(itr)}(0, spawn=true) do ch
                for item in itr
                    put!(ch, item)
                end
            end
            local thread_result_storage = Vector{Vector{Any}}(undef, threadpoolsize())
            function threadsfor_fun(tid)
                local_results = Any[]
                for item in ch
                    local $(esc(lidx)) = item
                    push!(local_results, $esc_body)
                end
                thread_result_storage[tid] = local_results
            end
            threading_run(threadsfor_fun, false)
            # Collect results after threading_run
            assigned_results = [thread_result_storage[i] for i in 1:threadpoolsize() if isassigned(thread_result_storage, i)]
            vcat(assigned_results...)
        end
    else
        func = default_comprehension_func(esc_range, lidx, esc_body)
        quote
            local threadsfor_fun
            local result
            $func
            if $(schedule === :dynamic || schedule === :default)
                threading_run(threadsfor_fun, false)
            elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
                error("`@threads :static` cannot be used concurrently or nested")
            else # :static
                threading_run(threadsfor_fun, true)
            end
            # The slots are stored in a `Vector{Any}`; narrow the element type so the
            # result matches the documented output (e.g. `Vector{Int64}`).
            map(identity, result)
        end
    end
end
284+
285+
# Build the expansion for a filtered `@threads [body for x in itr if condition]`.
#
# `iterator` is the `x = itr` expression taken from the generator, `condition` is
# the filter expression, `body` is the element expression and `schedule` selects
# the strategy:
#   * `:greedy` feeds items through a `Channel`; each task keeps the elements that
#     pass the filter and the per-task vectors are concatenated afterwards.
#   * anything else collects one result vector per task (see
#     `default_filtered_comprehension_func`) and concatenates them in task order.
function _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
    lidx = iterator.args[1] # index variable
    range = iterator.args[2] # range/iterable
    esc_range = esc(range)
    esc_body = esc(body)
    esc_condition = esc(condition)

    if schedule === :greedy
        quote
            # Splice the user's iterable expression once only, so that an expression
            # with side effects (e.g. one constructing a channel) runs a single time.
            local itr = $esc_range
            local ch = Channel{eltype(itr)}(0, spawn=true) do ch
                for item in itr
                    put!(ch, item)
                end
            end
            local thread_result_storage = Vector{Vector{Any}}(undef, threadpoolsize())
            function threadsfor_fun(tid)
                local_results = Any[]
                for item in ch
                    local $(esc(lidx)) = item
                    if $esc_condition
                        push!(local_results, $esc_body)
                    end
                end
                thread_result_storage[tid] = local_results
            end
            threading_run(threadsfor_fun, false)
            # Collect results after threading_run
            assigned_results = [thread_result_storage[i] for i in 1:threadpoolsize() if isassigned(thread_result_storage, i)]
            vcat(assigned_results...)
        end
    else
        func = default_filtered_comprehension_func(esc_range, lidx, esc_body, esc_condition)
        quote
            local threadsfor_fun
            local result
            $func
            if $(schedule === :dynamic || schedule === :default)
                threading_run(threadsfor_fun, false)
            elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
                error("`@threads :static` cannot be used concurrently or nested")
            else # :static
                threading_run(threadsfor_fun, true)
            end
            # Concatenate the per-task vectors, then narrow the `Any` element type so
            # the result matches the documented output (e.g. `Vector{Int64}`).
            map(identity, vcat(result...))
        end
    end
end
333+
223334
function greedy_func(itr, lidx, lbody)
224335
quote
225336
let c = Channel{eltype($itr)}(0,spawn=true) do ch
@@ -237,39 +348,47 @@ function greedy_func(itr, lidx, lbody)
237348
end
238349
end
239350

351+
# Return the statements that compute the index span `f:l` a task has to process.
#
# The returned block is meant to be spliced into a `threadsfor_fun` body. It reads
# `range`, `tid` and `onethread` from the splice site and leaves `r`, `lenr`, `f`
# and `l` behind for the code that follows it. When a task has no work to do, the
# block `return`s from the enclosing function.
function _work_distribution_code()
    return quote
        r = range # Load into local variable
        lenr = length(r)
        # split the iterations across the tasks (a single task takes everything)
        if onethread
            tid = 1
        end
        len, rem = onethread ? (lenr, 0) : divrem(lenr, threadpoolsize())
        # fewer iterations than tasks: the first `rem` tasks get one each
        if len == 0
            tid > rem && return
            len, rem = 1, 0
        end
        # span of this task before the leftovers are handed out
        f = firstindex(r) + (tid - 1) * len
        l = f + len - 1
        # each of the first `rem` tasks takes one extra iteration, which shifts
        # the spans of all later tasks accordingly
        if rem > 0
            f = f + min(tid - 1, rem)
            l = l + min(tid, rem)
        end
    end
end
385+
240386
function default_func(itr, lidx, lbody)
387+
work_dist = _work_distribution_code()
241388
quote
242389
let range = $itr
243390
function threadsfor_fun(tid = 1; onethread = false)
244-
r = range # Load into local variable
245-
lenr = length(r)
246-
# divide loop iterations among threads
247-
if onethread
248-
tid = 1
249-
len, rem = lenr, 0
250-
else
251-
len, rem = divrem(lenr, threadpoolsize())
252-
end
253-
# not enough iterations for all the threads?
254-
if len == 0
255-
if tid > rem
256-
return
257-
end
258-
len, rem = 1, 0
259-
end
260-
# compute this thread's iterations
261-
f = firstindex(r) + ((tid-1) * len)
262-
l = f + len - 1
263-
# distribute remaining iterations evenly
264-
if rem > 0
265-
if tid <= rem
266-
f = f + (tid-1)
267-
l = l + tid
268-
else
269-
f = f + rem
270-
l = l + rem
271-
end
272-
end
391+
$work_dist
273392
# run this thread's iterations
274393
for i = f:l
275394
local $(esc(lidx)) = @inbounds r[i]
@@ -280,13 +399,68 @@ function default_func(itr, lidx, lbody)
280399
end
281400
end
282401

402+
# Build the `threadsfor_fun` definition used for unfiltered comprehensions with the
# non-greedy schedules.
#
# `itr` and `body` are spliced as given (the callers in this file pass them already
# escaped); `lidx` is escaped here. The generated code assigns the pre-allocated
# output vector to `result` and defines `threadsfor_fun`, which fills the slots
# belonging to the calling task.
function default_comprehension_func(itr, lidx, body)
    work_dist = _work_distribution_code()
    quote
        result = let range = $itr
            # Pre-allocate one slot per element. The length is not kept in a variable
            # named `lenr`: the work-distribution code assigns `lenr` inside the
            # closure, which would turn it into a captured variable written by
            # every task.
            local result_array = Vector{Any}(undef, length(range))
            # `f:l` below are indices into `r`, which start at `firstindex(r)`;
            # shift them so that they address the 1-based `result_array`.
            local slot_offset = firstindex(range) - 1

            function threadsfor_fun(tid = 1; onethread = false)
                $work_dist
                # run this thread's iterations and store directly in result_array
                for i = f:l
                    local $(esc(lidx)) = @inbounds r[i]
                    result_array[i - slot_offset] = $body
                end
            end

            result_array
        end
    end
end
423+
424+
# Build the `threadsfor_fun` definition used for filtered comprehensions with the
# non-greedy schedules.
#
# `itr`, `body` and `condition` are spliced as given; `lidx` is escaped here. The
# generated code assigns a vector holding one result vector per task to `result`.
# Each task replaces its own entry with the elements that passed `condition`;
# entries of tasks without work stay empty.
function default_filtered_comprehension_func(itr, lidx, body, condition)
    chunk_setup = _work_distribution_code()
    return quote
        let range = $itr
            # one (initially empty) result vector per task
            local thread_results = Vector{Any}[Any[] for slot in 1:threadpoolsize()]

            function threadsfor_fun(tid = 1; onethread = false)
                $chunk_setup
                # walk this task's span and keep what passes the filter
                kept = Any[]
                for i = f:l
                    local $(esc(lidx)) = @inbounds r[i]
                    $condition || continue
                    push!(kept, $body)
                end
                thread_results[tid] = kept
            end

            # handed out before the tasks run; threading_run fills it in afterwards
            result = thread_results
        end
    end
end
451+
283452
"""
284453
Threads.@threads [schedule] for ... end
454+
Threads.@threads [schedule] [expr for ...]
285455
286-
A macro to execute a `for` loop in parallel. The iteration space is distributed to
456+
A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287457
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288458
execution of the loop waits for the evaluation of all iterations.
289459
460+
For `for` loops, the macro executes the loop body in parallel but does not return a value.
461+
For array comprehensions, the macro executes the comprehension in parallel and returns
462+
the collected results as an array.
463+
290464
See also: [`@spawn`](@ref Threads.@spawn) and
291465
`pmap` in [`Distributed`](@ref man-distributed).
292466
@@ -371,6 +545,8 @@ thread other than 1.
371545
372546
## Examples
373547
548+
### For loops
549+
374550
To illustrate the different scheduling strategies, consider the following function
375551
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376552
@@ -400,6 +576,38 @@ julia> @time begin
400576
401577
The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402578
to run two of the 1-second iterations to complete the for loop.
579+
580+
### Array comprehensions
581+
582+
The `@threads` macro also supports array comprehensions, which return the collected results:
583+
584+
```julia-repl
585+
julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
586+
5-element Vector{Int64}:
587+
1
588+
4
589+
9
590+
16
591+
25
592+
593+
julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
594+
2-element Vector{Int64}:
595+
4
596+
16
597+
```
598+
599+
When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
600+
option can be used, but note that the order of the results is not guaranteed.
601+
```julia-repl
602+
julia> c = Channel(5, spawn=true) do ch
603+
foreach(i -> put!(ch, i), 1:5)
604+
end;
605+
606+
julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
607+
2-element Vector{Any}:
608+
16
609+
4
610+
```
403611
"""
404612
macro threads(args...)
405613
na = length(args)
@@ -420,13 +628,18 @@ macro threads(args...)
420628
else
421629
throw(ArgumentError("wrong number of arguments in @threads"))
422630
end
423-
if !(isa(ex, Expr) && ex.head === :for)
424-
throw(ArgumentError("@threads requires a `for` loop expression"))
425-
end
426-
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
427-
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
631+
if isa(ex, Expr) && ex.head === :comprehension
632+
# Handle array comprehensions
633+
return _threadsfor_comprehension(ex.args[1], sched)
634+
elseif isa(ex, Expr) && ex.head === :for
635+
# Handle for loops
636+
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
637+
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
638+
end
639+
return _threadsfor(ex.args[1], ex.args[2], sched)
640+
else
641+
throw(ArgumentError("@threads requires a `for` loop or comprehension expression"))
428642
end
429-
return _threadsfor(ex.args[1], ex.args[2], sched)
430643
end
431644

432645
function _spawn_set_thrpool(t::Task, tp::Symbol)

0 commit comments

Comments
 (0)