File tree Expand file tree Collapse file tree 5 files changed +70
-15
lines changed Expand file tree Collapse file tree 5 files changed +70
-15
lines changed Original file line number Diff line number Diff line change @@ -4,4 +4,5 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
4
4
CUDAnative = " be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
5
5
CuArrays = " 3a865a2d-5b23-5a0f-bc46-62713ec82fae"
6
6
KernelAbstractions = " 63c18a36-062a-441e-b654-da1e3ab1ce7c"
7
+ StaticArrays = " 90137ffa-7385-5640-81b9-e52037218182"
7
8
Test = " 8dfed614-e22c-5e08-85e1-65c5234f0b40"
Original file line number Diff line number Diff line change @@ -115,13 +115,34 @@ end
115
115
116
116
"""
117
117
@synchronize()
118
+
119
+ After a `@synchronize` statement all reads and writes to global and local memory
120
+ from each thread in the workgroup are visible from all other threads in the
121
+ workgroup.
118
122
"""
119
123
macro synchronize ()
120
124
quote
121
125
$ __synchronize ()
122
126
end
123
127
end
124
128
129
+ """
130
+ @synchronize(cond)
131
+
132
+ After a `@synchronize` statement all reads and writes to global and local memory
133
+ from each thread in the workgroup are visible from all other threads in the
134
+ workgroup. `cond` is not allowed to have any visible side effects.
135
+
136
+ # Platform differences
137
+ - `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
138
+ - `CPU`: This synchronization will always occur.
139
+ """
140
+ macro synchronize (cond)
141
+ quote
142
+ $ (esc (cond)) && $ __synchronize ()
143
+ end
144
+ end
145
+
125
146
"""
126
147
@index
127
148
Original file line number Diff line number Diff line change 1
1
module LoopInfo
2
2
3
- const HAS_LOOPINFO_EXPR = VERSION >= v " 1.2.0-DEV.462 "
3
+ using MacroTools
4
4
export @unroll
5
5
6
6
# #
@@ -20,13 +20,16 @@ module MD
20
20
end
21
21
22
22
function loopinfo (expr, nodes... )
23
- if expr. head != :for
23
+ if @capture (expr, for i_ in iter_ body__ end )
24
+ return quote
25
+ for $ i in $ iter
26
+ $ (body... )
27
+ $ (Expr (:loopinfo , nodes... ))
28
+ end
29
+ end
30
+ else
24
31
error (" Syntax error: loopinfo needs a for loop" )
25
32
end
26
- if HAS_LOOPINFO_EXPR
27
- push! (expr. args[2 ]. args, Expr (:loopinfo , nodes... ))
28
- end
29
- return expr
30
33
end
31
34
32
35
"""
@@ -48,6 +51,7 @@ if it is safe to do so.
48
51
"""
49
52
macro unroll (N, expr)
50
53
if ! (N isa Integer)
54
+ @debug " @unroll macro inputs" N expr
51
55
error (" Syntax error: `@unroll N expr` needs a constant integer N" )
52
56
end
53
57
expr = loopinfo (expr, MD. unroll_count (N))
Original file line number Diff line number Diff line change @@ -103,7 +103,13 @@ struct WorkgroupLoop
103
103
private :: Vector{Any}
104
104
end
105
105
106
- is_sync (expr) = @capture (expr, @synchronize )
106
+ is_sync (expr) = @capture (expr, @synchronize () | @synchronize (a_))
107
+
108
+ function is_scope_construct (expr:: Expr )
109
+ expr. head === :block # ||
110
+ # expr.head === :let
111
+ end
112
+
107
113
function find_sync (stmt)
108
114
result = false
109
115
postwalk (stmt) do expr
@@ -131,7 +137,7 @@ function split(stmts,
131
137
push! (new_stmts, emit (loop))
132
138
allocations = Any[]
133
139
current = Any[]
134
- @capture (stmt, @synchronize ) && continue
140
+ is_sync (stmt) && continue
135
141
136
142
# Recurse into scope constructs
137
143
# TODO : This currently implements hard scoping
@@ -140,7 +146,7 @@ function split(stmts,
140
146
recurse (x) = x
141
147
function recurse (expr:: Expr )
142
148
expr = unblock (expr)
143
- if any (is_sync, expr. args)
149
+ if is_scope_construct (expr) && any (is_sync, expr. args)
144
150
new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
145
151
return Expr (expr. head, new_args... )
146
152
else
Original file line number Diff line number Diff line change 1
1
using KernelAbstractions
2
2
using KernelAbstractions. Extras
3
+ using StaticArrays
3
4
4
5
@kernel function kernel_unroll! (a)
5
6
@unroll for i in 1 : 5
@@ -8,16 +9,38 @@ using KernelAbstractions.Extras
8
9
end
9
10
10
11
@kernel function kernel_unroll! (a, :: Val{N} ) where N
11
- @unroll for i in 1 : N
12
- @inbounds a[i] = i
12
+ let M = N+ 5
13
+ @unroll for i in 6 : M
14
+ @inbounds a[i- 5 ] = i
15
+ end
16
+ @synchronize
17
+ end
18
+ end
19
+
20
+ # Check that nested `@unroll` doesn't throw a syntax error
21
+ @kernel function kernel_unroll2! (A)
22
+ @uniform begin
23
+ a = MVector {3, Float64} (1 , 2 , 3 )
24
+ b = MVector {3, Float64} (3 , 2 , 1 )
25
+ c = MMatrix {3, 3, Float64} (undef)
26
+ end
27
+ I = @index (Global)
28
+ @inbounds for m in 1 : 3
29
+ @unroll for j = 1 : 3
30
+ @unroll for i = 1 : 3
31
+ c[1 , j] = m * a[1 ] * b[j]
32
+ end
33
+ end
34
+ A[I] = c[1 , 1 ]
35
+ @synchronize (m % 2 == 0 )
13
36
end
14
37
end
15
38
16
39
let
17
40
a = zeros (5 )
18
41
kernel! = kernel_unroll! (CPU (), 1 , 1 )
19
- event = kernel! (a)
20
- wait (event )
21
- event = kernel! (a, Val ( 5 ) )
22
- wait (event )
42
+ wait ( kernel! (a) )
43
+ wait (kernel! (a, Val ( 5 )) )
44
+ kernel2! = kernel_unroll2! ( CPU (), 1 , 1 )
45
+ wait (kernel2! (a) )
23
46
end
You can’t perform that action at this time.
0 commit comments