Make @threads work on array comprehensions #59019

Open · wants to merge 1 commit into base: master
5 changes: 5 additions & 0 deletions NEWS.md
@@ -37,6 +37,11 @@ Multi-threading changes
the first time it is called, and then always return the same result value of type `T`
every subsequent time afterwards. There are also `OncePerThread{T}` and `OncePerTask{T}` types for
similar usage with threads or tasks. ([#TBD])
* `Threads.@threads` now supports array comprehensions with syntax like `@threads [f(i) for i in 1:n]`
and filtered comprehensions like `@threads [f(i) for i in 1:n if condition(i)]`. All scheduling
options (`:static`, `:dynamic`, `:greedy`) are supported. Results preserve element order for
`:static` and `:dynamic` scheduling, while `:greedy` may return elements in arbitrary order for
better performance. ([#TBD])

Build system changes
--------------------
219 changes: 183 additions & 36 deletions base/threadingconstructs.jl
@@ -220,6 +220,73 @@ function _threadsfor(iter, lbody, schedule)
end
end

function _threadsfor_comprehension(gen::Expr, schedule)
@assert gen.head === :generator
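# AST shapes being matched here (illustrative; can be checked with `Meta.@dump`):
#   [f(i) for i in 1:n]            -> Expr(:comprehension,
#                                        Expr(:generator, :(f(i)), :(i = 1:n)))
#   [f(i) for i in 1:n if cond(i)] -> Expr(:comprehension,
#                                        Expr(:generator, :(f(i)),
#                                            Expr(:filter, :(cond(i)), :(i = 1:n))))
# `gen` is the inner `:generator` expression; its second argument is either the
# iteration spec or a `:filter` wrapper around it.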

body = gen.args[1]
iter_or_filter = gen.args[2]

# Handle filtered vs non-filtered comprehensions
if isa(iter_or_filter, Expr) && iter_or_filter.head === :filter
condition = iter_or_filter.args[1]
iterator = iter_or_filter.args[2]
return _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
else
iterator = iter_or_filter
# Use filtered comprehension with `true` condition for non-filtered case
return _threadsfor_filtered_comprehension(body, iterator, true, schedule)
end
end

function _threadsfor_filtered_comprehension(body, iterator, condition, schedule)
lidx = iterator.args[1] # index variable
range = iterator.args[2] # range/iterable
esc_range = esc(range)
esc_body = esc(body)
esc_condition = esc(condition)

if schedule === :greedy
quote
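# Unbuffered producer: each item is handed to whichever worker task takes from
# `ch` next; `spawn=true` runs the producer concurrently with the workers.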
local ch = Channel{eltype($esc_range)}(0,spawn=true) do ch
for item in $esc_range
put!(ch, item)
end
end
local thread_result_storage = Vector{Vector{Any}}(undef, threadpoolsize())
function threadsfor_fun(tid)
local_results = Any[]
for item in ch
local $(esc(lidx)) = item
if $esc_condition
push!(local_results, $esc_body)
end
end
thread_result_storage[tid] = local_results
end
threading_run(threadsfor_fun, false)
# Collect results after threading_run
assigned_results = [thread_result_storage[i] for i in 1:threadpoolsize() if isassigned(thread_result_storage, i)]
vcat(assigned_results...)
end
else
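# Rough shape of the expansion built below for :static/:dynamic (a sketch, not
# the exact generated code):
#
#   local threadsfor_fun, result
#   let range = <iterable>
#       thread_results = Vector{Vector{Any}}(undef, threadpoolsize())  # one chunk per thread
#       threadsfor_fun(tid = 1; onethread = false) = ...               # fills thread_results[tid]
#       result = thread_results
#   end
#   threading_run(threadsfor_fun, schedule === :static)
#   vcat(result...)  # concatenate per-thread chunks in tid order, preserving iteration order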
func = default_filtered_comprehension_func(esc_range, lidx, esc_body, esc_condition)
quote
local threadsfor_fun
local result
$func
if $(schedule === :dynamic || schedule === :default)
threading_run(threadsfor_fun, false)
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
error("`@threads :static` cannot be used concurrently or nested")
else # :static
threading_run(threadsfor_fun, true)
end
# Process result after threading_run
vcat(result...)
end
end
end

function greedy_func(itr, lidx, lbody)
quote
let c = Channel{eltype($itr)}(0,spawn=true) do ch
@@ -237,39 +304,47 @@ function greedy_func(itr, lidx, lbody)
end
end

# Helper function to generate work distribution code
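# Illustration (not part of the generated code): with `length(r) == 10` and
# `threadpoolsize() == 4`, `divrem(10, 4)` gives `len = 2, rem = 2`, and the
# threads end up covering 1:3, 4:6, 7:8 and 9:10; the two leftover iterations
# go to the first two threads.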
function _work_distribution_code()
quote
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
if onethread
tid = 1
len, rem = lenr, 0
else
len, rem = divrem(lenr, threadpoolsize())
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
return
end
len, rem = 1, 0
end
# compute this thread's iterations
f = firstindex(r) + ((tid-1) * len)
l = f + len - 1
# distribute remaining iterations evenly
if rem > 0
if tid <= rem
f = f + (tid-1)
l = l + tid
else
f = f + rem
l = l + rem
end
end
end
end

function default_func(itr, lidx, lbody)
work_dist = _work_distribution_code()
quote
let range = $itr
function threadsfor_fun(tid = 1; onethread = false)
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
if onethread
tid = 1
len, rem = lenr, 0
else
len, rem = divrem(lenr, threadpoolsize())
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
return
end
len, rem = 1, 0
end
# compute this thread's iterations
f = firstindex(r) + ((tid-1) * len)
l = f + len - 1
# distribute remaining iterations evenly
if rem > 0
if tid <= rem
f = f + (tid-1)
l = l + tid
else
f = f + rem
l = l + rem
end
end
$work_dist
# run this thread's iterations
for i = f:l
local $(esc(lidx)) = @inbounds r[i]
@@ -280,13 +355,46 @@ function default_func(itr, lidx, lbody)
end
end

function default_filtered_comprehension_func(itr, lidx, body, condition)
work_dist = _work_distribution_code()
quote
let range = $itr
local thread_results = Vector{Vector{Any}}(undef, threadpoolsize())
# Initialize all result vectors to empty
for i in 1:threadpoolsize()
thread_results[i] = Any[]
end

function threadsfor_fun(tid = 1; onethread = false)
$work_dist
# run this thread's iterations with filtering
local_results = Any[]
for i = f:l
local $(esc(lidx)) = @inbounds r[i]
if $condition
push!(local_results, $body)
end
end
thread_results[tid] = local_results
end

result = thread_results # This will be populated by threading_run
end
end
end

"""
Threads.@threads [schedule] for ... end
Threads.@threads [schedule] [expr for ... end]

A macro to execute a `for` loop in parallel. The iteration space is distributed to
A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
execution of the loop waits for the evaluation of all iterations.

For `for` loops, the macro executes the loop body in parallel but does not return a value.
For array comprehensions, the macro executes the comprehension in parallel and returns
the collected results as an array.

See also: [`@spawn`](@ref Threads.@spawn) and
`pmap` in [`Distributed`](@ref man-distributed).

@@ -371,6 +479,8 @@ thread other than 1.

## Examples

### For loops

To illustrate the effect of the different scheduling strategies, consider the following function
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.

@@ -400,6 +510,38 @@ julia> @time begin

The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
to run two of the 1-second iterations to complete the for loop.

### Array comprehensions

The `@threads` macro also supports array comprehensions, which return the collected results:

```julia-repl
julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
5-element Vector{Int64}:
1
4
9
16
25

julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
2-element Vector{Int64}:
4
16
```

When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
option can be used, but note that the order of the results is not guaranteed.
```julia-repl
julia> c = Channel(5, spawn=true) do ch
foreach(i -> put!(ch, i), 1:5)
end;

julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
2-element Vector{Any}:
16
4
```
"""
macro threads(args...)
na = length(args)
@@ -420,13 +562,18 @@ macro threads(args...)
else
throw(ArgumentError("wrong number of arguments in @threads"))
end
if !(isa(ex, Expr) && ex.head === :for)
throw(ArgumentError("@threads requires a `for` loop expression"))
end
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
if isa(ex, Expr) && ex.head === :comprehension
# Handle array comprehensions
return _threadsfor_comprehension(ex.args[1], sched)
elseif isa(ex, Expr) && ex.head === :for
# Handle for loops
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
end
return _threadsfor(ex.args[1], ex.args[2], sched)
else
throw(ArgumentError("@threads requires a `for` loop or comprehension expression"))
end
return _threadsfor(ex.args[1], ex.args[2], sched)
end

function _spawn_set_thrpool(t::Task, tp::Symbol)
Expand Down
103 changes: 102 additions & 1 deletion test/threads.jl
@@ -335,8 +335,109 @@ end
@test_throws ArgumentError @macroexpand(@threads 1 2) # wrong number of args
@test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr
@test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for
end
# Test bad arguments for comprehensions
@test_throws ArgumentError @macroexpand(@threads [i for i in 1:10] 2) # wrong number of args
@test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr
@test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for or comprehension
end

@testset "@threads comprehensions" begin
# Test simple array comprehensions
@testset "simple comprehensions" begin
n = 1000
# Test default scheduling
result = @threads [i^2 for i in 1:n]
@test length(result) == n
@test all(result[i] == i^2 for i in 1:n)
@test issorted(result) # should be ordered for default scheduling

# Test static scheduling
result_static = @threads :static [i^2 for i in 1:n]
@test length(result_static) == n
@test all(result_static[i] == i^2 for i in 1:n)
@test issorted(result_static) # should be ordered for static scheduling

# Test dynamic scheduling
result_dynamic = @threads :dynamic [i^2 for i in 1:n]
@test length(result_dynamic) == n
@test all(result_dynamic[i] == i^2 for i in 1:n)
@test issorted(result_dynamic) # should be ordered for dynamic scheduling

# Test greedy scheduling (may not preserve order)
result_greedy = @threads :greedy [i^2 for i in 1:n]
@test length(result_greedy) == n
@test sort(result_greedy) == [i^2 for i in 1:n] # same elements but potentially different order
end

# Test filtered comprehensions
@testset "filtered comprehensions" begin
n = 100

# Test default scheduling with filter
result = @threads [i^2 for i in 1:n if iseven(i)]
expected = [i^2 for i in 1:n if iseven(i)]
@test length(result) == length(expected)
@test result == expected # should preserve order

# Test static scheduling with filter
result_static = @threads :static [i^2 for i in 1:n if iseven(i)]
@test length(result_static) == length(expected)
@test result_static == expected # should preserve order

# Test dynamic scheduling with filter
result_dynamic = @threads :dynamic [i^2 for i in 1:n if iseven(i)]
@test length(result_dynamic) == length(expected)
@test result_dynamic == expected # should preserve order

# Test greedy scheduling with filter
result_greedy = @threads :greedy [i^2 for i in 1:n if iseven(i)]
@test length(result_greedy) == length(expected)
@test sort(result_greedy) == sort(expected) # same elements but potentially different order

# Test with more complex filter
result_complex = @threads [i for i in 1:100 if i % 3 == 0 && i > 20]
expected_complex = [i for i in 1:100 if i % 3 == 0 && i > 20]
@test result_complex == expected_complex
end

# Test edge cases
@testset "edge cases" begin
# Empty range
result_empty = @threads [i for i in 1:0]
@test result_empty == []

# Single element
result_single = @threads [i^2 for i in 1:1]
@test result_single == [1]

# Filter that excludes all elements
result_none = @threads [i for i in 1:10 if i > 20]
@test result_none == []

# Large range to test thread distribution
n = 10000
result_large = @threads [i for i in 1:n]
@test length(result_large) == n
@test result_large == collect(1:n)
end

# Test with side effects (should work but order may vary with greedy)
@testset "side effects" begin
# Test with atomic operations
counter = Threads.Atomic{Int}(0)
result = @threads [begin
Threads.atomic_add!(counter, 1)
i
end for i in 1:100]
@test counter[] == 100
@test sort(result) == collect(1:100)

# Test with thread-local computation
result_tid = @threads [Threads.threadid() for i in 1:100]
@test length(result_tid) == 100
@test all(1 <= tid <= Threads.nthreads() for tid in result_tid)
end
end
@testset "rand_ptls underflow" begin
@test Base.Partr.cong(UInt32(0)) == 0
end