Skip to content

Commit a72a0d3

Browse files
bors[bot]vchuravy
andauthored
Merge #39
39: fix nested unroll macros r=vchuravy a=vchuravy bors r+ Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
2 parents ba2bdbf + 98d8f75 commit a72a0d3

File tree

5 files changed

+70
-15
lines changed

5 files changed

+70
-15
lines changed

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
44
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
55
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
66
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
7+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/KernelAbstractions.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,34 @@ end
115115

116116
"""
117117
@synchronize()
118+
119+
After a `@synchronize` statement all reads and writes to global and local memory
120+
from each thread in the workgroup are visible from all other threads in the
121+
workgroup.
118122
"""
119123
macro synchronize()
120124
quote
121125
$__synchronize()
122126
end
123127
end
124128

129+
"""
130+
@synchronize(cond)
131+
132+
After a `@synchronize` statement all reads and writes to global and local memory
133+
from each thread in the workgroup are visible from all other threads in the
134+
workgroup. `cond` is not allowed to have any visible side effects.
135+
136+
# Platform differences
137+
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
138+
- `CPU`: This synchronization will always occur.
139+
"""
140+
macro synchronize(cond)
141+
quote
142+
$(esc(cond)) && $__synchronize()
143+
end
144+
end
145+
125146
"""
126147
@index
127148

src/extras/loopinfo.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module LoopInfo
22

3-
const HAS_LOOPINFO_EXPR = VERSION >= v"1.2.0-DEV.462"
3+
using MacroTools
44
export @unroll
55

66
##
@@ -20,13 +20,16 @@ module MD
2020
end
2121

2222
function loopinfo(expr, nodes...)
23-
if expr.head != :for
23+
if @capture(expr, for i_ in iter_ body__ end)
24+
return quote
25+
for $i in $iter
26+
$(body...)
27+
$(Expr(:loopinfo, nodes...))
28+
end
29+
end
30+
else
2431
error("Syntax error: loopinfo needs a for loop")
2532
end
26-
if HAS_LOOPINFO_EXPR
27-
push!(expr.args[2].args, Expr(:loopinfo, nodes...))
28-
end
29-
return expr
3033
end
3134

3235
"""
@@ -48,6 +51,7 @@ if it is safe to do so.
4851
"""
4952
macro unroll(N, expr)
5053
if !(N isa Integer)
54+
@debug "@unroll macro inputs" N expr
5155
error("Syntax error: `@unroll N expr` needs a constant integer N")
5256
end
5357
expr = loopinfo(expr, MD.unroll_count(N))

src/macros.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ struct WorkgroupLoop
103103
private :: Vector{Any}
104104
end
105105

106-
is_sync(expr) = @capture(expr, @synchronize)
106+
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
107+
108+
function is_scope_construct(expr::Expr)
109+
expr.head === :block # ||
110+
# expr.head === :let
111+
end
112+
107113
function find_sync(stmt)
108114
result = false
109115
postwalk(stmt) do expr
@@ -131,7 +137,7 @@ function split(stmts,
131137
push!(new_stmts, emit(loop))
132138
allocations = Any[]
133139
current = Any[]
134-
@capture(stmt, @synchronize) && continue
140+
is_sync(stmt) && continue
135141

136142
# Recurse into scope constructs
137143
# TODO: This currently implements hard scoping
@@ -140,7 +146,7 @@ function split(stmts,
140146
recurse(x) = x
141147
function recurse(expr::Expr)
142148
expr = unblock(expr)
143-
if any(is_sync, expr.args)
149+
if is_scope_construct(expr) && any(is_sync, expr.args)
144150
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
145151
return Expr(expr.head, new_args...)
146152
else

test/unroll.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using KernelAbstractions
22
using KernelAbstractions.Extras
3+
using StaticArrays
34

45
@kernel function kernel_unroll!(a)
56
@unroll for i in 1:5
@@ -8,16 +9,38 @@ using KernelAbstractions.Extras
89
end
910

1011
@kernel function kernel_unroll!(a, ::Val{N}) where N
11-
@unroll for i in 1:N
12-
@inbounds a[i] = i
12+
let M = N+5
13+
@unroll for i in 6:M
14+
@inbounds a[i-5] = i
15+
end
16+
@synchronize
17+
end
18+
end
19+
20+
# Check that nested `@unroll` doesn't throw a syntax error
21+
@kernel function kernel_unroll2!(A)
22+
@uniform begin
23+
a = MVector{3, Float64}(1, 2, 3)
24+
b = MVector{3, Float64}(3, 2, 1)
25+
c = MMatrix{3, 3, Float64}(undef)
26+
end
27+
I = @index(Global)
28+
@inbounds for m in 1:3
29+
@unroll for j = 1:3
30+
@unroll for i = 1:3
31+
c[1, j] = m * a[1] * b[j]
32+
end
33+
end
34+
A[I] = c[1, 1]
35+
@synchronize(m % 2 == 0)
1336
end
1437
end
1538

1639
let
1740
a = zeros(5)
1841
kernel! = kernel_unroll!(CPU(), 1, 1)
19-
event = kernel!(a)
20-
wait(event)
21-
event = kernel!(a, Val(5))
22-
wait(event)
42+
wait(kernel!(a))
43+
wait(kernel!(a, Val(5)))
44+
kernel2! = kernel_unroll2!(CPU(), 1, 1)
45+
wait(kernel2!(a))
2346
end

0 commit comments

Comments
 (0)