Skip to content

Commit d2bced7

Browse files
Make @threads work on array comprehensions
1 parent 11eeed3 commit d2bced7

File tree

3 files changed

+290
-37
lines changed

3 files changed

+290
-37
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Multi-threading changes
3737
the first time it is called, and then always return the same result value of type `T`
3838
every subsequent time afterwards. There are also `OncePerThread{T}` and `OncePerTask{T}` types for
3939
similar usage with threads or tasks. ([#TBD])
40+
* `Threads.@threads` now supports array comprehensions with syntax like `@threads [f(i) for i in 1:n]`
41+
and filtered comprehensions like `@threads [f(i) for i in 1:n if condition(i)]`. All scheduling
42+
options (`:static`, `:dynamic`, `:greedy`) are supported. Results preserve element order for
43+
`:static` and `:dynamic` scheduling, while `:greedy` may return elements in arbitrary order for
44+
better performance. ([#TBD])
4045

4146
Build system changes
4247
--------------------

base/threadingconstructs.jl

Lines changed: 183 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,73 @@ function _threadsfor(iter, lbody, schedule)
220220
end
221221
end
222222

223+
function _threadsfor_comprehension(gen::Expr, schedule)
224+
@assert gen.head === :generator
225+
226+
body = gen.args[1]
227+
iter_or_filter = gen.args[2]
228+
229+
# Handle filtered vs non-filtered comprehensions
230+
if isa(iter_or_filter, Expr) && iter_or_filter.head === :filter
231+
condition = iter_or_filter.args[1]
232+
iterator = iter_or_filter.args[2]
233+
return _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
234+
else
235+
iterator = iter_or_filter
236+
# Use filtered comprehension with `true` condition for non-filtered case
237+
return _threadsfor_filtered_comprehension(body, iterator, true, schedule)
238+
end
239+
end
240+
241+
function _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
242+
lidx = iterator.args[1] # index variable
243+
range = iterator.args[2] # range/iterable
244+
esc_range = esc(range)
245+
esc_body = esc(body)
246+
esc_condition = esc(condition)
247+
248+
if schedule === :greedy
249+
quote
250+
local ch = Channel{eltype($esc_range)}(0,spawn=true) do ch
251+
for item in $esc_range
252+
put!(ch, item)
253+
end
254+
end
255+
local thread_result_storage = Vector{Vector{Any}}(undef, threadpoolsize())
256+
function threadsfor_fun(tid)
257+
local_results = Any[]
258+
for item in ch
259+
local $(esc(lidx)) = item
260+
if $esc_condition
261+
push!(local_results, $esc_body)
262+
end
263+
end
264+
thread_result_storage[tid] = local_results
265+
end
266+
threading_run(threadsfor_fun, false)
267+
# Collect results after threading_run
268+
assigned_results = [thread_result_storage[i] for i in 1:threadpoolsize() if isassigned(thread_result_storage, i)]
269+
vcat(assigned_results...)
270+
end
271+
else
272+
func = default_filtered_comprehension_func(esc_range, lidx, esc_body, esc_condition)
273+
quote
274+
local threadsfor_fun
275+
local result
276+
$func
277+
if $(schedule === :dynamic || schedule === :default)
278+
threading_run(threadsfor_fun, false)
279+
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
280+
error("`@threads :static` cannot be used concurrently or nested")
281+
else # :static
282+
threading_run(threadsfor_fun, true)
283+
end
284+
# Process result after threading_run
285+
vcat(result...)
286+
end
287+
end
288+
end
289+
223290
function greedy_func(itr, lidx, lbody)
224291
quote
225292
let c = Channel{eltype($itr)}(0,spawn=true) do ch
@@ -237,39 +304,47 @@ function greedy_func(itr, lidx, lbody)
237304
end
238305
end
239306

307+
# Helper function to generate work distribution code
308+
function _work_distribution_code()
309+
quote
310+
r = range # Load into local variable
311+
lenr = length(r)
312+
# divide loop iterations among threads
313+
if onethread
314+
tid = 1
315+
len, rem = lenr, 0
316+
else
317+
len, rem = divrem(lenr, threadpoolsize())
318+
end
319+
# not enough iterations for all the threads?
320+
if len == 0
321+
if tid > rem
322+
return
323+
end
324+
len, rem = 1, 0
325+
end
326+
# compute this thread's iterations
327+
f = firstindex(r) + ((tid-1) * len)
328+
l = f + len - 1
329+
# distribute remaining iterations evenly
330+
if rem > 0
331+
if tid <= rem
332+
f = f + (tid-1)
333+
l = l + tid
334+
else
335+
f = f + rem
336+
l = l + rem
337+
end
338+
end
339+
end
340+
end
341+
240342
function default_func(itr, lidx, lbody)
343+
work_dist = _work_distribution_code()
241344
quote
242345
let range = $itr
243346
function threadsfor_fun(tid = 1; onethread = false)
244-
r = range # Load into local variable
245-
lenr = length(r)
246-
# divide loop iterations among threads
247-
if onethread
248-
tid = 1
249-
len, rem = lenr, 0
250-
else
251-
len, rem = divrem(lenr, threadpoolsize())
252-
end
253-
# not enough iterations for all the threads?
254-
if len == 0
255-
if tid > rem
256-
return
257-
end
258-
len, rem = 1, 0
259-
end
260-
# compute this thread's iterations
261-
f = firstindex(r) + ((tid-1) * len)
262-
l = f + len - 1
263-
# distribute remaining iterations evenly
264-
if rem > 0
265-
if tid <= rem
266-
f = f + (tid-1)
267-
l = l + tid
268-
else
269-
f = f + rem
270-
l = l + rem
271-
end
272-
end
347+
$work_dist
273348
# run this thread's iterations
274349
for i = f:l
275350
local $(esc(lidx)) = @inbounds r[i]
@@ -280,13 +355,46 @@ function default_func(itr, lidx, lbody)
280355
end
281356
end
282357

358+
function default_filtered_comprehension_func(itr, lidx, body, condition)
359+
work_dist = _work_distribution_code()
360+
quote
361+
let range = $itr
362+
local thread_results = Vector{Vector{Any}}(undef, threadpoolsize())
363+
# Initialize all result vectors to empty
364+
for i in 1:threadpoolsize()
365+
thread_results[i] = Any[]
366+
end
367+
368+
function threadsfor_fun(tid = 1; onethread = false)
369+
$work_dist
370+
# run this thread's iterations with filtering
371+
local_results = Any[]
372+
for i = f:l
373+
local $(esc(lidx)) = @inbounds r[i]
374+
if $condition
375+
push!(local_results, $body)
376+
end
377+
end
378+
thread_results[tid] = local_results
379+
end
380+
381+
result = thread_results # This will be populated by threading_run
382+
end
383+
end
384+
end
385+
283386
"""
284387
Threads.@threads [schedule] for ... end
388+
Threads.@threads [schedule] [expr for ... end]
285389
286-
A macro to execute a `for` loop in parallel. The iteration space is distributed to
390+
A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287391
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288392
execution of the loop waits for the evaluation of all iterations.
289393
394+
For `for` loops, the macro executes the loop body in parallel but does not return a value.
395+
For array comprehensions, the macro executes the comprehension in parallel and returns
396+
the collected results as an array.
397+
290398
See also: [`@spawn`](@ref Threads.@spawn) and
291399
`pmap` in [`Distributed`](@ref man-distributed).
292400
@@ -371,6 +479,8 @@ thread other than 1.
371479
372480
## Examples
373481
482+
### For loops
483+
374484
To illustrate of the different scheduling strategies, consider the following function
375485
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376486
@@ -400,6 +510,38 @@ julia> @time begin
400510
401511
The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402512
to run two of the 1-second iterations to complete the for loop.
513+
514+
### Array comprehensions
515+
516+
The `@threads` macro also supports array comprehensions, which return the collected results:
517+
518+
```julia-repl
519+
julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
520+
5-element Vector{Int64}:
521+
1
522+
4
523+
9
524+
16
525+
25
526+
527+
julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
528+
2-element Vector{Int64}:
529+
4
530+
16
531+
```
532+
533+
When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
534+
option can be used, but note that the order of the results is not guaranteed.
535+
```julia-repl
536+
julia> c = Channel(5, spawn=true) do ch
537+
foreach(i -> put!(ch, i), 1:5)
538+
end;
539+
540+
julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
541+
2-element Vector{Any}:
542+
16
543+
4
544+
```
403545
"""
404546
macro threads(args...)
405547
na = length(args)
@@ -420,13 +562,18 @@ macro threads(args...)
420562
else
421563
throw(ArgumentError("wrong number of arguments in @threads"))
422564
end
423-
if !(isa(ex, Expr) && ex.head === :for)
424-
throw(ArgumentError("@threads requires a `for` loop expression"))
425-
end
426-
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
427-
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
565+
if isa(ex, Expr) && ex.head === :comprehension
566+
# Handle array comprehensions
567+
return _threadsfor_comprehension(ex.args[1], sched)
568+
elseif isa(ex, Expr) && ex.head === :for
569+
# Handle for loops
570+
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
571+
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
572+
end
573+
return _threadsfor(ex.args[1], ex.args[2], sched)
574+
else
575+
throw(ArgumentError("@threads requires a `for` loop or comprehension expression"))
428576
end
429-
return _threadsfor(ex.args[1], ex.args[2], sched)
430577
end
431578

432579
function _spawn_set_thrpool(t::Task, tp::Symbol)

test/threads.jl

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,109 @@ end
335335
@test_throws ArgumentError @macroexpand(@threads 1 2) # wrong number of args
336336
@test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr
337337
@test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for
338-
end
338+
# Test bad arguments for comprehensions
339+
@test_throws ArgumentError @macroexpand(@threads [i for i in 1:10] 2) # wrong number of args
340+
@test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr
341+
@test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for or comprehension
342+
end
343+
344+
@testset "@threads comprehensions" begin
345+
# Test simple array comprehensions
346+
@testset "simple comprehensions" begin
347+
n = 1000
348+
# Test default scheduling
349+
result = @threads [i^2 for i in 1:n]
350+
@test length(result) == n
351+
@test all(result[i] == i^2 for i in 1:n)
352+
@test issorted(result) # should be ordered for default scheduling
353+
354+
# Test static scheduling
355+
result_static = @threads :static [i^2 for i in 1:n]
356+
@test length(result_static) == n
357+
@test all(result_static[i] == i^2 for i in 1:n)
358+
@test issorted(result_static) # should be ordered for static scheduling
359+
360+
# Test dynamic scheduling
361+
result_dynamic = @threads :dynamic [i^2 for i in 1:n]
362+
@test length(result_dynamic) == n
363+
@test all(result_dynamic[i] == i^2 for i in 1:n)
364+
@test issorted(result_dynamic) # should be ordered for dynamic scheduling
365+
366+
# Test greedy scheduling (may not preserve order)
367+
result_greedy = @threads :greedy [i^2 for i in 1:n]
368+
@test length(result_greedy) == n
369+
@test sort(result_greedy) == [i^2 for i in 1:n] # same elements but potentially different order
370+
end
371+
372+
# Test filtered comprehensions
373+
@testset "filtered comprehensions" begin
374+
n = 100
375+
376+
# Test default scheduling with filter
377+
result = @threads [i^2 for i in 1:n if iseven(i)]
378+
expected = [i^2 for i in 1:n if iseven(i)]
379+
@test length(result) == length(expected)
380+
@test result == expected # should preserve order
381+
382+
# Test static scheduling with filter
383+
result_static = @threads :static [i^2 for i in 1:n if iseven(i)]
384+
@test length(result_static) == length(expected)
385+
@test result_static == expected # should preserve order
386+
387+
# Test dynamic scheduling with filter
388+
result_dynamic = @threads :dynamic [i^2 for i in 1:n if iseven(i)]
389+
@test length(result_dynamic) == length(expected)
390+
@test result_dynamic == expected # should preserve order
391+
392+
# Test greedy scheduling with filter
393+
result_greedy = @threads :greedy [i^2 for i in 1:n if iseven(i)]
394+
@test length(result_greedy) == length(expected)
395+
@test sort(result_greedy) == sort(expected) # same elements but potentially different order
396+
397+
# Test with more complex filter
398+
result_complex = @threads [i for i in 1:100 if i % 3 == 0 && i > 20]
399+
expected_complex = [i for i in 1:100 if i % 3 == 0 && i > 20]
400+
@test result_complex == expected_complex
401+
end
339402

403+
# Test edge cases
404+
@testset "edge cases" begin
405+
# Empty range
406+
result_empty = @threads [i for i in 1:0]
407+
@test result_empty == []
408+
409+
# Single element
410+
result_single = @threads [i^2 for i in 1:1]
411+
@test result_single == [1]
412+
413+
# Filter that excludes all elements
414+
result_none = @threads [i for i in 1:10 if i > 20]
415+
@test result_none == []
416+
417+
# Large range to test thread distribution
418+
n = 10000
419+
result_large = @threads [i for i in 1:n]
420+
@test length(result_large) == n
421+
@test result_large == collect(1:n)
422+
end
423+
424+
# Test with side effects (should work but order may vary with greedy)
425+
@testset "side effects" begin
426+
# Test with atomic operations
427+
counter = Threads.Atomic{Int}(0)
428+
result = @threads [begin
429+
Threads.atomic_add!(counter, 1)
430+
i
431+
end for i in 1:100]
432+
@test counter[] == 100
433+
@test sort(result) == collect(1:100)
434+
435+
# Test with thread-local computation
436+
result_tid = @threads [Threads.threadid() for i in 1:100]
437+
@test length(result_tid) == 100
438+
@test all(1 <= tid <= Threads.nthreads() for tid in result_tid)
439+
end
440+
end
340441
@testset "rand_ptls underflow" begin
341442
@test Base.Partr.cong(UInt32(0)) == 0
342443
end

0 commit comments

Comments
 (0)