Skip to content

Commit 4aff96c

Browse files
committed
fix nested unroll macros
1 parent ba2bdbf commit 4aff96c

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
44
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
55
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
66
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
7+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/extras/loopinfo.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module LoopInfo
22

3-
const HAS_LOOPINFO_EXPR = VERSION >= v"1.2.0-DEV.462"
3+
using MacroTools
44
export @unroll
55

66
##
@@ -20,13 +20,16 @@ module MD
2020
end
2121

2222
function loopinfo(expr, nodes...)
23-
if expr.head != :for
23+
if @capture(expr, for i_ in iter_ body__ end)
24+
return quote
25+
for $i in $iter
26+
$(body...)
27+
$(Expr(:loopinfo, nodes...))
28+
end
29+
end
30+
else
2431
error("Syntax error: loopinfo needs a for loop")
2532
end
26-
if HAS_LOOPINFO_EXPR
27-
push!(expr.args[2].args, Expr(:loopinfo, nodes...))
28-
end
29-
return expr
3033
end
3134

3235
"""
@@ -48,6 +51,7 @@ if it is safe to do so.
4851
"""
4952
macro unroll(N, expr)
5053
if !(N isa Integer)
54+
@debug "@unroll macro inputs" N expr
5155
error("Syntax error: `@unroll N expr` needs a constant integer N")
5256
end
5357
expr = loopinfo(expr, MD.unroll_count(N))

test/unroll.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using KernelAbstractions
22
using KernelAbstractions.Extras
3+
using StaticArrays
34

45
@kernel function kernel_unroll!(a)
56
@unroll for i in 1:5
@@ -13,6 +14,25 @@ end
1314
end
1415
end
1516

17+
# Check that nested `@unroll` doesn't throw a syntax error
18+
@kernel function kernel_unroll!(a, ::Val{N}) where N
19+
@uniform begin
20+
a = MVector{3, Float64}(1, 2, 3)
21+
b = MVector{3, Float64}(3, 2, 1)
22+
c = MMatrix{3, 3, Float64}(undef)
23+
end
24+
I = @index(Global)
25+
@inbounds for m in 1:3
26+
@unroll for j = 1:3
27+
@unroll for i = 1:3
28+
c[1, j] = m * a[1] * b[j]
29+
end
30+
end
31+
a[I] = c[1, 1]
32+
m % 2 == 0 && @synchronize
33+
end
34+
end
35+
1636
let
1737
a = zeros(5)
1838
kernel! = kernel_unroll!(CPU(), 1, 1)

0 commit comments

Comments
 (0)