From d2bced7b94018ff034b969829ff8a0526891a7fa Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Tue, 15 Jul 2025 19:37:47 -0400 Subject: [PATCH] Make `@threads` work on array comprehensions --- NEWS.md | 5 + base/threadingconstructs.jl | 219 ++++++++++++++++++++++++++++++------ test/threads.jl | 103 ++++++++++++++++- 3 files changed, 290 insertions(+), 37 deletions(-) diff --git a/NEWS.md b/NEWS.md index 5d9bf83467b77..51f979510d19a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -37,6 +37,11 @@ Multi-threading changes the first time it is called, and then always return the same result value of type `T` every subsequent time afterwards. There are also `OncePerThread{T}` and `OncePerTask{T}` types for similar usage with threads or tasks. ([#TBD]) +* `Threads.@threads` now supports array comprehensions with syntax like `@threads [f(i) for i in 1:n]` + and filtered comprehensions like `@threads [f(i) for i in 1:n if condition(i)]`. All scheduling + options (`:static`, `:dynamic`, `:greedy`) are supported. Results preserve element order for + `:static` and `:dynamic` scheduling, while `:greedy` may return elements in arbitrary order for + better performance. ([#TBD]) Build system changes -------------------- diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 9d175bf0d8e81..8f8690034f0ce 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -220,6 +220,73 @@ function _threadsfor(iter, lbody, schedule) end end +function _threadsfor_comprehension(gen::Expr, schedule) + @assert gen.head === :generator + + body = gen.args[1] + iter_or_filter = gen.args[2] + + # Handle filtered vs non-filtered comprehensions + if isa(iter_or_filter, Expr) && iter_or_filter.head === :filter + condition = iter_or_filter.args[1] + iterator = iter_or_filter.args[2] + return _threadsfor_filtered_comprehension(body, iterator, condition, schedule) + else + iterator = iter_or_filter + # Use filtered comprehension with `true` condition for non-filtered case + return _threadsfor_filtered_comprehension(body, iterator, true, schedule) + end +end + +function _threadsfor_filtered_comprehension(body, iterator, condition, schedule) + lidx = iterator.args[1] # index variable + range = iterator.args[2] # range/iterable + esc_range = esc(range) + esc_body = esc(body) + esc_condition = esc(condition) + + if schedule === :greedy + quote + local ch = Channel{eltype($esc_range)}(0,spawn=true) do ch + for item in $esc_range + put!(ch, item) + end + end + local thread_result_storage = Vector{Vector{Any}}(undef, threadpoolsize()) + function threadsfor_fun(tid) + local_results = Any[] + for item in ch + local $(esc(lidx)) = item + if $esc_condition + push!(local_results, $esc_body) + end + end + thread_result_storage[tid] = local_results + end + threading_run(threadsfor_fun, false) + # Collect results after threading_run + assigned_results = [thread_result_storage[i] for i in 1:threadpoolsize() if isassigned(thread_result_storage, i)] + vcat(assigned_results...) + end + else + func = default_filtered_comprehension_func(esc_range, lidx, esc_body, esc_condition) + quote + local threadsfor_fun + local result + $func + if $(schedule === :dynamic || schedule === :default) + threading_run(threadsfor_fun, false) + elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static + error("`@threads :static` cannot be used concurrently or nested") + else # :static + threading_run(threadsfor_fun, true) + end + # Process result after threading_run + vcat(result...) + end + end +end + function greedy_func(itr, lidx, lbody) quote let c = Channel{eltype($itr)}(0,spawn=true) do ch @@ -237,39 +304,47 @@ function greedy_func(itr, lidx, lbody) end end +# Helper function to generate work distribution code +function _work_distribution_code() + quote + r = range # Load into local variable + lenr = length(r) + # divide loop iterations among threads + if onethread + tid = 1 + len, rem = lenr, 0 + else + len, rem = divrem(lenr, threadpoolsize()) + end + # not enough iterations for all the threads? + if len == 0 + if tid > rem + return + end + len, rem = 1, 0 + end + # compute this thread's iterations + f = firstindex(r) + ((tid-1) * len) + l = f + len - 1 + # distribute remaining iterations evenly + if rem > 0 + if tid <= rem + f = f + (tid-1) + l = l + tid + else + f = f + rem + l = l + rem + end + end + end +end + function default_func(itr, lidx, lbody) + work_dist = _work_distribution_code() quote let range = $itr function threadsfor_fun(tid = 1; onethread = false) - r = range # Load into local variable - lenr = length(r) - # divide loop iterations among threads - if onethread - tid = 1 - len, rem = lenr, 0 - else - len, rem = divrem(lenr, threadpoolsize()) - end - # not enough iterations for all the threads? - if len == 0 - if tid > rem - return - end - len, rem = 1, 0 - end - # compute this thread's iterations - f = firstindex(r) + ((tid-1) * len) - l = f + len - 1 - # distribute remaining iterations evenly - if rem > 0 - if tid <= rem - f = f + (tid-1) - l = l + tid - else - f = f + rem - l = l + rem - end - end + $work_dist # run this thread's iterations for i = f:l local $(esc(lidx)) = @inbounds r[i] @@ -280,13 +355,46 @@ function default_func(itr, lidx, lbody) end end +function default_filtered_comprehension_func(itr, lidx, body, condition) + work_dist = _work_distribution_code() + quote + let range = $itr + local thread_results = Vector{Vector{Any}}(undef, threadpoolsize()) + # Initialize all result vectors to empty + for i in 1:threadpoolsize() + thread_results[i] = Any[] + end + + function threadsfor_fun(tid = 1; onethread = false) + $work_dist + # run this thread's iterations with filtering + local_results = Any[] + for i = f:l + local $(esc(lidx)) = @inbounds r[i] + if $condition + push!(local_results, $body) + end + end + thread_results[tid] = local_results + end + + result = thread_results # This will be populated by threading_run + end + end +end + """ Threads.@threads [schedule] for ... end + Threads.@threads [schedule] [expr for ... end] -A macro to execute a `for` loop in parallel. The iteration space is distributed to +A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to coarse-grained tasks. This policy can be specified by the `schedule` argument. The execution of the loop waits for the evaluation of all iterations. +For `for` loops, the macro executes the loop body in parallel but does not return a value. +For array comprehensions, the macro executes the comprehension in parallel and returns +the collected results as an array. + See also: [`@spawn`](@ref Threads.@spawn) and `pmap` in [`Distributed`](@ref man-distributed). @@ -371,6 +479,8 @@ thread other than 1. ## Examples +### For loops + To illustrate of the different scheduling strategies, consider the following function `busywait` containing a non-yielding timed loop that runs for a given number of seconds. @@ -400,6 +510,38 @@ julia> @time begin The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able to run two of the 1-second iterations to complete the for loop. + +### Array comprehensions + +The `@threads` macro also supports array comprehensions, which return the collected results: + +```julia-repl +julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension +5-element Vector{Int64}: + 1 + 4 + 9 + 16 + 25 + +julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension +2-element Vector{Int64}: + 4 + 16 +``` + +When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling +option can be used, but note that the order of the results is not guaranteed. +```julia-repl +julia> c = Channel(5, spawn=true) do ch + foreach(i -> put!(ch, i), 1:5) + end; + +julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)] +2-element Vector{Any}: + 16 + 4 +``` """ macro threads(args...) na = length(args) @@ -420,13 +562,18 @@ macro threads(args...) else throw(ArgumentError("wrong number of arguments in @threads")) end - if !(isa(ex, Expr) && ex.head === :for) - throw(ArgumentError("@threads requires a `for` loop expression")) - end - if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) - throw(ArgumentError("nested outer loops are not currently supported by @threads")) + if isa(ex, Expr) && ex.head === :comprehension + # Handle array comprehensions + return _threadsfor_comprehension(ex.args[1], sched) + elseif isa(ex, Expr) && ex.head === :for + # Handle for loops + if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) + throw(ArgumentError("nested outer loops are not currently supported by @threads")) + end + return _threadsfor(ex.args[1], ex.args[2], sched) + else + throw(ArgumentError("@threads requires a `for` loop or comprehension expression")) end - return _threadsfor(ex.args[1], ex.args[2], sched) end function _spawn_set_thrpool(t::Task, tp::Symbol) diff --git a/test/threads.jl b/test/threads.jl index fa0b33a6352f3..a2e9b92c32b64 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -335,8 +335,109 @@ end @test_throws ArgumentError @macroexpand(@threads 1 2) # wrong number of args @test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr @test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for -end + # Test bad arguments for comprehensions + @test_throws ArgumentError @macroexpand(@threads [i for i in 1:10] 2) # wrong number of args + @test_throws ArgumentError @macroexpand(@threads 1) # arg isn't an Expr + @test_throws ArgumentError @macroexpand(@threads if true 1 end) # arg doesn't start with for or comprehension +end + +@testset "@threads comprehensions" begin + # Test simple array comprehensions + @testset "simple comprehensions" begin + n = 1000 + # Test default scheduling + result = @threads [i^2 for i in 1:n] + @test length(result) == n + @test all(result[i] == i^2 for i in 1:n) + @test issorted(result) # should be ordered for default scheduling + + # Test static scheduling + result_static = @threads :static [i^2 for i in 1:n] + @test length(result_static) == n + @test all(result_static[i] == i^2 for i in 1:n) + @test issorted(result_static) # should be ordered for static scheduling + + # Test dynamic scheduling + result_dynamic = @threads :dynamic [i^2 for i in 1:n] + @test length(result_dynamic) == n + @test all(result_dynamic[i] == i^2 for i in 1:n) + @test issorted(result_dynamic) # should be ordered for dynamic scheduling + + # Test greedy scheduling (may not preserve order) + result_greedy = @threads :greedy [i^2 for i in 1:n] + @test length(result_greedy) == n + @test sort(result_greedy) == [i^2 for i in 1:n] # same elements but potentially different order + end + + # Test filtered comprehensions + @testset "filtered comprehensions" begin + n = 100 + + # Test default scheduling with filter + result = @threads [i^2 for i in 1:n if iseven(i)] + expected = [i^2 for i in 1:n if iseven(i)] + @test length(result) == length(expected) + @test result == expected # should preserve order + + # Test static scheduling with filter + result_static = @threads :static [i^2 for i in 1:n if iseven(i)] + @test length(result_static) == length(expected) + @test result_static == expected # should preserve order + + # Test dynamic scheduling with filter + result_dynamic = @threads :dynamic [i^2 for i in 1:n if iseven(i)] + @test length(result_dynamic) == length(expected) + @test result_dynamic == expected # should preserve order + + # Test greedy scheduling with filter + result_greedy = @threads :greedy [i^2 for i in 1:n if iseven(i)] + @test length(result_greedy) == length(expected) + @test sort(result_greedy) == sort(expected) # same elements but potentially different order + + # Test with more complex filter + result_complex = @threads [i for i in 1:100 if i % 3 == 0 && i > 20] + expected_complex = [i for i in 1:100 if i % 3 == 0 && i > 20] + @test result_complex == expected_complex + end + # Test edge cases + @testset "edge cases" begin + # Empty range + result_empty = @threads [i for i in 1:0] + @test result_empty == [] + + # Single element + result_single = @threads [i^2 for i in 1:1] + @test result_single == [1] + + # Filter that excludes all elements + result_none = @threads [i for i in 1:10 if i > 20] + @test result_none == [] + + # Large range to test thread distribution + n = 10000 + result_large = @threads [i for i in 1:n] + @test length(result_large) == n + @test result_large == collect(1:n) + end + + # Test with side effects (should work but order may vary with greedy) + @testset "side effects" begin + # Test with atomic operations + counter = Threads.Atomic{Int}(0) + result = @threads [begin + Threads.atomic_add!(counter, 1) + i + end for i in 1:100] + @test counter[] == 100 + @test sort(result) == collect(1:100) + + # Test with thread-local computation + result_tid = @threads [Threads.threadid() for i in 1:100] + @test length(result_tid) == 100 + @test all(1 <= tid <= Threads.nthreads() for tid in result_tid) + end +end @testset "rand_ptls underflow" begin @test Base.Partr.cong(UInt32(0)) == 0 end