@@ -89,12 +89,8 @@ function transform_cpu!(def, constargs)
89
89
end
90
90
91
91
body = MacroTools. flatten (def[:body ])
92
- loops = split (body)
93
-
94
92
push! (new_stmts, Expr (:aliasscope ))
95
- for loop in loops
96
- push! (new_stmts, emit (loop))
97
- end
93
+ append! (new_stmts, split (body. args))
98
94
push! (new_stmts, Expr (:popaliasscope ))
99
95
push! (new_stmts, :(return nothing ))
100
96
def[:body ] = Expr (:block , new_stmts... )
@@ -107,46 +103,68 @@ struct WorkgroupLoop
107
103
private :: Vector{Any}
108
104
end
109
105
106
+ is_sync (expr) = @capture (expr, @synchronize )
107
+ function find_sync (stmt)
108
+ result = false
109
+ postwalk (stmt) do expr
110
+ result |= is_sync (expr)
111
+ expr
112
+ end
113
+ result
114
+ end
110
115
111
- function split (stmts)
116
+ # TODO proper handling of LineInfo
117
+ function split (stmts,
118
+ indicies = Any[], private= Any[])
112
119
# 1. Split the code into blocks separated by `@synchronize`
113
120
# 2. Aggregate `@index` expressions
114
121
# 3. Hoist allocations
115
122
# 4. Hoist uniforms
116
123
117
124
current = Any[]
118
- indicies = Any[]
119
125
allocations = Any[]
120
- private = Any[]
121
-
122
- loops = WorkgroupLoop[]
123
- for stmt in stmts. args
124
- if isexpr (stmt, :macrocall )
125
- if stmt. args[1 ] === Symbol (" @synchronize" )
126
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private))
127
- push! (loops, loop)
128
- allocations = Any[]
129
- current = Any[]
126
+ new_stmts = Any[]
127
+ for stmt in stmts
128
+ has_sync = find_sync (stmt)
129
+ if has_sync
130
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private))
131
+ push! (new_stmts, emit (loop))
132
+ allocations = Any[]
133
+ current = Any[]
134
+ @capture (stmt, @synchronize ) && continue
135
+
136
+ # Recurse into scope constructs
137
+ # TODO : This currently implements hard scoping
138
+ # probably need to implemet soft scoping
139
+ # by not deepcopying the environment.
140
+ recurse (x) = x
141
+ function recurse (expr:: Expr )
142
+ expr = unblock (expr)
143
+ if any (is_sync, expr. args)
144
+ new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
145
+ return Expr (expr. head, new_args... )
146
+ else
147
+ return Expr (expr. head, map (recurse, expr. args)... )
148
+ end
149
+ end
150
+ push! (new_stmts, recurse (stmt))
151
+ continue
152
+ end
153
+
154
+ if @capture (stmt, @uniform x_)
155
+ push! (allocations, stmt)
156
+ continue
157
+ elseif @capture (stmt, lhs_ = rhs_)
158
+ if @capture (rhs, @index (args__))
159
+ push! (indicies, stmt)
130
160
continue
131
- elseif stmt . args[ 1 ] === Symbol ( " @uniform" )
161
+ elseif @capture (rhs, @localmem (args__) | @uniform (args__) )
132
162
push! (allocations, stmt)
133
- end
134
- elseif isexpr (stmt, :(= ))
135
- rhs = stmt. args[2 ]
136
- if isexpr (rhs, :macrocall )
137
- callee = rhs. args[1 ]
138
- if callee === Symbol (" @index" )
139
- push! (indicies, stmt)
140
- continue
141
- elseif callee === Symbol (" @localmem" ) ||
142
- callee === Symbol (" @uniform" )
143
- push! (allocations, stmt)
144
- continue
145
- elseif callee === Symbol (" @private" )
146
- push! (allocations, stmt)
147
- push! (private, stmt. args[1 ])
148
- continue
149
- end
163
+ continue
164
+ elseif @capture (rhs, @private (args__))
165
+ push! (allocations, stmt)
166
+ push! (private, lhs)
167
+ continue
150
168
end
151
169
end
152
170
@@ -155,9 +173,10 @@ function split(stmts)
155
173
156
174
# everything since the last `@synchronize`
157
175
if ! isempty (current)
158
- push! (loops, WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private)))
176
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private))
177
+ push! (new_stmts, emit (loop))
159
178
end
160
- return loops
179
+ return new_stmts
161
180
end
162
181
163
182
function emit (loop)
@@ -168,21 +187,28 @@ function emit(loop)
168
187
rhs = stmt. args[2 ]
169
188
push! (rhs. args, idx)
170
189
end
171
- body = Expr (:block , loop. stmts... )
172
- body = postwalk (body) do expr
173
- if @capture (expr, A_[i__])
174
- if A in loop. private
175
- return :($ A[$ (i... ), $ (idx). I... ])
190
+ stmts = Any[]
191
+ append! (stmts, loop. allocations)
192
+ # don't emit empty loops
193
+ if ! (isempty (loop. stmts) || all (s-> s isa LineNumberNode, loop. stmts))
194
+ body = Expr (:block , loop. stmts... )
195
+ body = postwalk (body) do expr
196
+ if @capture (expr, A_[i__])
197
+ if A in loop. private
198
+ return :($ A[$ (i... ), $ (idx). I... ])
199
+ end
176
200
end
201
+ return expr
177
202
end
178
- return expr
179
- end
180
- quote
181
- $ (loop. allocations... )
182
- for $ idx in $ __workitems_iterspace ()
183
- $ __validindex ($ idx) || continue
184
- $ (loop. indicies... )
185
- $ (body)
203
+ loopexpr = quote
204
+ for $ idx in $ __workitems_iterspace ()
205
+ $ __validindex ($ idx) || continue
206
+ $ (loop. indicies... )
207
+ $ (unblock (body))
208
+ end
186
209
end
210
+ push! (stmts, loopexpr)
187
211
end
212
+
213
+ return unblock (Expr (:block , stmts... ))
188
214
end
0 commit comments