@@ -220,6 +220,73 @@ function _threadsfor(iter, lbody, schedule)
220
220
end
221
221
end
222
222
223
+ function _threadsfor_comprehension (gen:: Expr , schedule)
224
+ @assert gen. head === :generator
225
+
226
+ body = gen. args[1 ]
227
+ iter_or_filter = gen. args[2 ]
228
+
229
+ # Handle filtered vs non-filtered comprehensions
230
+ if isa (iter_or_filter, Expr) && iter_or_filter. head === :filter
231
+ condition = iter_or_filter. args[1 ]
232
+ iterator = iter_or_filter. args[2 ]
233
+ return _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
234
+ else
235
+ iterator = iter_or_filter
236
+ # Use filtered comprehension with `true` condition for non-filtered case
237
+ return _threadsfor_filtered_comprehension (body, iterator, true , schedule)
238
+ end
239
+ end
240
+
241
+ function _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
242
+ lidx = iterator. args[1 ] # index variable
243
+ range = iterator. args[2 ] # range/iterable
244
+ esc_range = esc (range)
245
+ esc_body = esc (body)
246
+ esc_condition = esc (condition)
247
+
248
+ if schedule === :greedy
249
+ quote
250
+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
251
+ for item in $ esc_range
252
+ put! (ch, item)
253
+ end
254
+ end
255
+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
256
+ function threadsfor_fun (tid)
257
+ local_results = Any[]
258
+ for item in ch
259
+ local $ (esc (lidx)) = item
260
+ if $ esc_condition
261
+ push! (local_results, $ esc_body)
262
+ end
263
+ end
264
+ thread_result_storage[tid] = local_results
265
+ end
266
+ threading_run (threadsfor_fun, false )
267
+ # Collect results after threading_run
268
+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
269
+ vcat (assigned_results... )
270
+ end
271
+ else
272
+ func = default_filtered_comprehension_func (esc_range, lidx, esc_body, esc_condition)
273
+ quote
274
+ local threadsfor_fun
275
+ local result
276
+ $ func
277
+ if $ (schedule === :dynamic || schedule === :default )
278
+ threading_run (threadsfor_fun, false )
279
+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
280
+ error (" `@threads :static` cannot be used concurrently or nested" )
281
+ else # :static
282
+ threading_run (threadsfor_fun, true )
283
+ end
284
+ # Process result after threading_run
285
+ vcat (result... )
286
+ end
287
+ end
288
+ end
289
+
223
290
function greedy_func (itr, lidx, lbody)
224
291
quote
225
292
let c = Channel {eltype($itr)} (0 ,spawn= true ) do ch
@@ -237,39 +304,47 @@ function greedy_func(itr, lidx, lbody)
237
304
end
238
305
end
239
306
307
+ # Helper function to generate work distribution code
308
+ function _work_distribution_code ()
309
+ quote
310
+ r = range # Load into local variable
311
+ lenr = length (r)
312
+ # divide loop iterations among threads
313
+ if onethread
314
+ tid = 1
315
+ len, rem = lenr, 0
316
+ else
317
+ len, rem = divrem (lenr, threadpoolsize ())
318
+ end
319
+ # not enough iterations for all the threads?
320
+ if len == 0
321
+ if tid > rem
322
+ return
323
+ end
324
+ len, rem = 1 , 0
325
+ end
326
+ # compute this thread's iterations
327
+ f = firstindex (r) + ((tid- 1 ) * len)
328
+ l = f + len - 1
329
+ # distribute remaining iterations evenly
330
+ if rem > 0
331
+ if tid <= rem
332
+ f = f + (tid- 1 )
333
+ l = l + tid
334
+ else
335
+ f = f + rem
336
+ l = l + rem
337
+ end
338
+ end
339
+ end
340
+ end
341
+
240
342
function default_func (itr, lidx, lbody)
343
+ work_dist = _work_distribution_code ()
241
344
quote
242
345
let range = $ itr
243
346
function threadsfor_fun (tid = 1 ; onethread = false )
244
- r = range # Load into local variable
245
- lenr = length (r)
246
- # divide loop iterations among threads
247
- if onethread
248
- tid = 1
249
- len, rem = lenr, 0
250
- else
251
- len, rem = divrem (lenr, threadpoolsize ())
252
- end
253
- # not enough iterations for all the threads?
254
- if len == 0
255
- if tid > rem
256
- return
257
- end
258
- len, rem = 1 , 0
259
- end
260
- # compute this thread's iterations
261
- f = firstindex (r) + ((tid- 1 ) * len)
262
- l = f + len - 1
263
- # distribute remaining iterations evenly
264
- if rem > 0
265
- if tid <= rem
266
- f = f + (tid- 1 )
267
- l = l + tid
268
- else
269
- f = f + rem
270
- l = l + rem
271
- end
272
- end
347
+ $ work_dist
273
348
# run this thread's iterations
274
349
for i = f: l
275
350
local $ (esc (lidx)) = @inbounds r[i]
@@ -280,13 +355,46 @@ function default_func(itr, lidx, lbody)
280
355
end
281
356
end
282
357
358
+ function default_filtered_comprehension_func (itr, lidx, body, condition)
359
+ work_dist = _work_distribution_code ()
360
+ quote
361
+ let range = $ itr
362
+ local thread_results = Vector {Vector{Any}} (undef, threadpoolsize ())
363
+ # Initialize all result vectors to empty
364
+ for i in 1 : threadpoolsize ()
365
+ thread_results[i] = Any[]
366
+ end
367
+
368
+ function threadsfor_fun (tid = 1 ; onethread = false )
369
+ $ work_dist
370
+ # run this thread's iterations with filtering
371
+ local_results = Any[]
372
+ for i = f: l
373
+ local $ (esc (lidx)) = @inbounds r[i]
374
+ if $ condition
375
+ push! (local_results, $ body)
376
+ end
377
+ end
378
+ thread_results[tid] = local_results
379
+ end
380
+
381
+ result = thread_results # This will be populated by threading_run
382
+ end
383
+ end
384
+ end
385
+
283
386
"""
284
387
Threads.@threads [schedule] for ... end
388
+ Threads.@threads [schedule] [expr for ... end]
285
389
286
- A macro to execute a `for` loop in parallel. The iteration space is distributed to
390
+ A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287
391
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288
392
execution of the loop waits for the evaluation of all iterations.
289
393
394
+ For `for` loops, the macro executes the loop body in parallel but does not return a value.
395
+ For array comprehensions, the macro executes the comprehension in parallel and returns
396
+ the collected results as an array.
397
+
290
398
See also: [`@spawn`](@ref Threads.@spawn) and
291
399
`pmap` in [`Distributed`](@ref man-distributed).
292
400
@@ -371,6 +479,8 @@ thread other than 1.
371
479
372
480
## Examples
373
481
482
+ ### For loops
483
+
374
484
To illustrate of the different scheduling strategies, consider the following function
375
485
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376
486
@@ -400,6 +510,38 @@ julia> @time begin
400
510
401
511
The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402
512
to run two of the 1-second iterations to complete the for loop.
513
+
514
+ ### Array comprehensions
515
+
516
+ The `@threads` macro also supports array comprehensions, which return the collected results:
517
+
518
+ ```julia-repl
519
+ julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
520
+ 5-element Vector{Int64}:
521
+ 1
522
+ 4
523
+ 9
524
+ 16
525
+ 25
526
+
527
+ julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
528
+ 2-element Vector{Int64}:
529
+ 4
530
+ 16
531
+ ```
532
+
533
+ When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
534
+ option can be used, but note that the order of the results is not guaranteed.
535
+ ```julia-repl
536
+ julia> c = Channel(5, spawn=true) do ch
537
+ foreach(i -> put!(ch, i), 1:5)
538
+ end;
539
+
540
+ julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
541
+ 2-element Vector{Any}:
542
+ 16
543
+ 4
544
+ ```
403
545
"""
404
546
macro threads (args... )
405
547
na = length (args)
@@ -420,13 +562,18 @@ macro threads(args...)
420
562
else
421
563
throw (ArgumentError (" wrong number of arguments in @threads" ))
422
564
end
423
- if ! (isa (ex, Expr) && ex. head === :for )
424
- throw (ArgumentError (" @threads requires a `for` loop expression" ))
425
- end
426
- if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
427
- throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
565
+ if isa (ex, Expr) && ex. head === :comprehension
566
+ # Handle array comprehensions
567
+ return _threadsfor_comprehension (ex. args[1 ], sched)
568
+ elseif isa (ex, Expr) && ex. head === :for
569
+ # Handle for loops
570
+ if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
571
+ throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
572
+ end
573
+ return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
574
+ else
575
+ throw (ArgumentError (" @threads requires a `for` loop or comprehension expression" ))
428
576
end
429
- return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
430
577
end
431
578
432
579
function _spawn_set_thrpool (t:: Task , tp:: Symbol )
0 commit comments