@@ -220,6 +220,117 @@ function _threadsfor(iter, lbody, schedule)
220
220
end
221
221
end
222
222
223
+ function _threadsfor_comprehension (gen:: Expr , schedule)
224
+ @assert gen. head === :generator
225
+
226
+ body = gen. args[1 ]
227
+ iter_or_filter = gen. args[2 ]
228
+
229
+ # Handle filtered vs non-filtered comprehensions
230
+ if isa (iter_or_filter, Expr) && iter_or_filter. head === :filter
231
+ condition = iter_or_filter. args[1 ]
232
+ iterator = iter_or_filter. args[2 ]
233
+ return _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
234
+ else
235
+ iterator = iter_or_filter
236
+ return _threadsfor_simple_comprehension (body, iterator, schedule)
237
+ end
238
+ end
239
+
240
+ function _threadsfor_simple_comprehension (body, iterator, schedule)
241
+ lidx = iterator. args[1 ] # index variable
242
+ range = iterator. args[2 ] # range/iterable
243
+ esc_range = esc (range)
244
+ esc_body = esc (body)
245
+
246
+ if schedule === :greedy
247
+ quote
248
+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
249
+ for item in $ esc_range
250
+ put! (ch, item)
251
+ end
252
+ end
253
+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
254
+ function threadsfor_fun (tid)
255
+ local_results = Any[]
256
+ for item in ch
257
+ local $ (esc (lidx)) = item
258
+ push! (local_results, $ esc_body)
259
+ end
260
+ thread_result_storage[tid] = local_results
261
+ end
262
+ threading_run (threadsfor_fun, false )
263
+ # Collect results after threading_run
264
+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
265
+ vcat (assigned_results... )
266
+ end
267
+ else
268
+ func = default_comprehension_func (esc_range, lidx, esc_body)
269
+ quote
270
+ local threadsfor_fun
271
+ local result
272
+ $ func
273
+ if $ (schedule === :dynamic || schedule === :default )
274
+ threading_run (threadsfor_fun, false )
275
+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
276
+ error (" `@threads :static` cannot be used concurrently or nested" )
277
+ else # :static
278
+ threading_run (threadsfor_fun, true )
279
+ end
280
+ result
281
+ end
282
+ end
283
+ end
284
+
285
+ function _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
286
+ lidx = iterator. args[1 ] # index variable
287
+ range = iterator. args[2 ] # range/iterable
288
+ esc_range = esc (range)
289
+ esc_body = esc (body)
290
+ esc_condition = esc (condition)
291
+
292
+ if schedule === :greedy
293
+ quote
294
+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
295
+ for item in $ esc_range
296
+ put! (ch, item)
297
+ end
298
+ end
299
+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
300
+ function threadsfor_fun (tid)
301
+ local_results = Any[]
302
+ for item in ch
303
+ local $ (esc (lidx)) = item
304
+ if $ esc_condition
305
+ push! (local_results, $ esc_body)
306
+ end
307
+ end
308
+ thread_result_storage[tid] = local_results
309
+ end
310
+ threading_run (threadsfor_fun, false )
311
+ # Collect results after threading_run
312
+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
313
+ vcat (assigned_results... )
314
+ end
315
+ else
316
+ func = default_filtered_comprehension_func (esc_range, lidx, esc_body, esc_condition)
317
+ quote
318
+ local threadsfor_fun
319
+ local result
320
+ $ func
321
+ if $ (schedule === :dynamic || schedule === :default )
322
+ threading_run (threadsfor_fun, false )
323
+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
324
+ error (" `@threads :static` cannot be used concurrently or nested" )
325
+ else # :static
326
+ threading_run (threadsfor_fun, true )
327
+ end
328
+ # Process result after threading_run
329
+ vcat (result... )
330
+ end
331
+ end
332
+ end
333
+
223
334
function greedy_func (itr, lidx, lbody)
224
335
quote
225
336
let c = Channel {eltype($itr)} (0 ,spawn= true ) do ch
@@ -237,39 +348,47 @@ function greedy_func(itr, lidx, lbody)
237
348
end
238
349
end
239
350
351
+ # Helper function to generate work distribution code
352
+ function _work_distribution_code ()
353
+ quote
354
+ r = range # Load into local variable
355
+ lenr = length (r)
356
+ # divide loop iterations among threads
357
+ if onethread
358
+ tid = 1
359
+ len, rem = lenr, 0
360
+ else
361
+ len, rem = divrem (lenr, threadpoolsize ())
362
+ end
363
+ # not enough iterations for all the threads?
364
+ if len == 0
365
+ if tid > rem
366
+ return
367
+ end
368
+ len, rem = 1 , 0
369
+ end
370
+ # compute this thread's iterations
371
+ f = firstindex (r) + ((tid- 1 ) * len)
372
+ l = f + len - 1
373
+ # distribute remaining iterations evenly
374
+ if rem > 0
375
+ if tid <= rem
376
+ f = f + (tid- 1 )
377
+ l = l + tid
378
+ else
379
+ f = f + rem
380
+ l = l + rem
381
+ end
382
+ end
383
+ end
384
+ end
385
+
240
386
function default_func (itr, lidx, lbody)
387
+ work_dist = _work_distribution_code ()
241
388
quote
242
389
let range = $ itr
243
390
function threadsfor_fun (tid = 1 ; onethread = false )
244
- r = range # Load into local variable
245
- lenr = length (r)
246
- # divide loop iterations among threads
247
- if onethread
248
- tid = 1
249
- len, rem = lenr, 0
250
- else
251
- len, rem = divrem (lenr, threadpoolsize ())
252
- end
253
- # not enough iterations for all the threads?
254
- if len == 0
255
- if tid > rem
256
- return
257
- end
258
- len, rem = 1 , 0
259
- end
260
- # compute this thread's iterations
261
- f = firstindex (r) + ((tid- 1 ) * len)
262
- l = f + len - 1
263
- # distribute remaining iterations evenly
264
- if rem > 0
265
- if tid <= rem
266
- f = f + (tid- 1 )
267
- l = l + tid
268
- else
269
- f = f + rem
270
- l = l + rem
271
- end
272
- end
391
+ $ work_dist
273
392
# run this thread's iterations
274
393
for i = f: l
275
394
local $ (esc (lidx)) = @inbounds r[i]
@@ -280,13 +399,68 @@ function default_func(itr, lidx, lbody)
280
399
end
281
400
end
282
401
402
+ function default_comprehension_func (itr, lidx, body)
403
+ work_dist = _work_distribution_code ()
404
+ quote
405
+ result = let range = $ itr
406
+ lenr = length (range)
407
+ # Pre-allocate result array with the correct size
408
+ local result_array = Vector {Any} (undef, lenr)
409
+
410
+ function threadsfor_fun (tid = 1 ; onethread = false )
411
+ $ work_dist
412
+ # run this thread's iterations and store directly in result_array
413
+ for i = f: l
414
+ local $ (esc (lidx)) = @inbounds r[i]
415
+ result_array[i] = $ body
416
+ end
417
+ end
418
+
419
+ result_array
420
+ end
421
+ end
422
+ end
423
+
424
+ function default_filtered_comprehension_func (itr, lidx, body, condition)
425
+ work_dist = _work_distribution_code ()
426
+ quote
427
+ let range = $ itr
428
+ local thread_results = Vector {Vector{Any}} (undef, threadpoolsize ())
429
+ # Initialize all result vectors to empty
430
+ for i in 1 : threadpoolsize ()
431
+ thread_results[i] = Any[]
432
+ end
433
+
434
+ function threadsfor_fun (tid = 1 ; onethread = false )
435
+ $ work_dist
436
+ # run this thread's iterations with filtering
437
+ local_results = Any[]
438
+ for i = f: l
439
+ local $ (esc (lidx)) = @inbounds r[i]
440
+ if $ condition
441
+ push! (local_results, $ body)
442
+ end
443
+ end
444
+ thread_results[tid] = local_results
445
+ end
446
+
447
+ result = thread_results # This will be populated by threading_run
448
+ end
449
+ end
450
+ end
451
+
283
452
"""
284
453
Threads.@threads [schedule] for ... end
454
+ Threads.@threads [schedule] [expr for ... end]
285
455
286
- A macro to execute a `for` loop in parallel. The iteration space is distributed to
456
+ A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287
457
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288
458
execution of the loop waits for the evaluation of all iterations.
289
459
460
+ For `for` loops, the macro executes the loop body in parallel but does not return a value.
461
+ For array comprehensions, the macro executes the comprehension in parallel and returns
462
+ the collected results as an array.
463
+
290
464
See also: [`@spawn`](@ref Threads.@spawn) and
291
465
`pmap` in [`Distributed`](@ref man-distributed).
292
466
@@ -371,6 +545,8 @@ thread other than 1.
371
545
372
546
## Examples
373
547
548
+ ### For loops
549
+
374
550
To illustrate of the different scheduling strategies, consider the following function
375
551
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376
552
@@ -400,6 +576,38 @@ julia> @time begin
400
576
401
577
The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402
578
to run two of the 1-second iterations to complete the for loop.
579
+
580
+ ### Array comprehensions
581
+
582
+ The `@threads` macro also supports array comprehensions, which return the collected results:
583
+
584
+ ```julia-repl
585
+ julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
586
+ 5-element Vector{Int64}:
587
+ 1
588
+ 4
589
+ 9
590
+ 16
591
+ 25
592
+
593
+ julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
594
+ 2-element Vector{Int64}:
595
+ 4
596
+ 16
597
+ ```
598
+
599
+ When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
600
+ option can be used, but note that the order of the results is not guaranteed.
601
+ ```julia-repl
602
+ julia> c = Channel(5, spawn=true) do ch
603
+ foreach(i -> put!(ch, i), 1:5)
604
+ end;
605
+
606
+ julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
607
+ 2-element Vector{Any}:
608
+ 16
609
+ 4
610
+ ```
403
611
"""
404
612
macro threads (args... )
405
613
na = length (args)
@@ -420,13 +628,18 @@ macro threads(args...)
420
628
else
421
629
throw (ArgumentError (" wrong number of arguments in @threads" ))
422
630
end
423
- if ! (isa (ex, Expr) && ex. head === :for )
424
- throw (ArgumentError (" @threads requires a `for` loop expression" ))
425
- end
426
- if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
427
- throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
631
+ if isa (ex, Expr) && ex. head === :comprehension
632
+ # Handle array comprehensions
633
+ return _threadsfor_comprehension (ex. args[1 ], sched)
634
+ elseif isa (ex, Expr) && ex. head === :for
635
+ # Handle for loops
636
+ if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
637
+ throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
638
+ end
639
+ return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
640
+ else
641
+ throw (ArgumentError (" @threads requires a `for` loop or comprehension expression" ))
428
642
end
429
- return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
430
643
end
431
644
432
645
function _spawn_set_thrpool (t:: Task , tp:: Symbol )
0 commit comments