Skip to content

Commit d294367

Browse files
committed
handle at_synchronize in blocks
1 parent a881a13 commit d294367

File tree

3 files changed

+111
-56
lines changed

3 files changed

+111
-56
lines changed

src/macros.jl

Lines changed: 76 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,8 @@ function transform_cpu!(def, constargs)
8989
end
9090

9191
body = MacroTools.flatten(def[:body])
92-
loops = split(body)
93-
9492
push!(new_stmts, Expr(:aliasscope))
95-
for loop in loops
96-
push!(new_stmts, emit(loop))
97-
end
93+
append!(new_stmts, split(body.args))
9894
push!(new_stmts, Expr(:popaliasscope))
9995
push!(new_stmts, :(return nothing))
10096
def[:body] = Expr(:block, new_stmts...)
@@ -107,46 +103,68 @@ struct WorkgroupLoop
107103
private :: Vector{Any}
108104
end
109105

106+
is_sync(expr) = @capture(expr, @synchronize)
107+
function find_sync(stmt)
108+
result = false
109+
postwalk(stmt) do expr
110+
result |= is_sync(expr)
111+
expr
112+
end
113+
result
114+
end
110115

111-
function split(stmts)
116+
# TODO proper handling of LineInfo
117+
function split(stmts,
118+
indicies = Any[], private=Any[])
112119
# 1. Split the code into blocks separated by `@synchronize`
113120
# 2. Aggregate `@index` expressions
114121
# 3. Hoist allocations
115122
# 4. Hoist uniforms
116123

117124
current = Any[]
118-
indicies = Any[]
119125
allocations = Any[]
120-
private = Any[]
121-
122-
loops = WorkgroupLoop[]
123-
for stmt in stmts.args
124-
if isexpr(stmt, :macrocall)
125-
if stmt.args[1] === Symbol("@synchronize")
126-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
127-
push!(loops, loop)
128-
allocations = Any[]
129-
current = Any[]
126+
new_stmts = Any[]
127+
for stmt in stmts
128+
has_sync = find_sync(stmt)
129+
if has_sync
130+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
131+
push!(new_stmts, emit(loop))
132+
allocations = Any[]
133+
current = Any[]
134+
@capture(stmt, @synchronize) && continue
135+
136+
# Recurse into scope constructs
137+
# TODO: This currently implements hard scoping
138+
# probably need to implemet soft scoping
139+
# by not deepcopying the environment.
140+
recurse(x) = x
141+
function recurse(expr::Expr)
142+
expr = unblock(expr)
143+
if any(is_sync, expr.args)
144+
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
145+
return Expr(expr.head, new_args...)
146+
else
147+
return Expr(expr.head, map(recurse, expr.args)...)
148+
end
149+
end
150+
push!(new_stmts, recurse(stmt))
151+
continue
152+
end
153+
154+
if @capture(stmt, @uniform x_)
155+
push!(allocations, stmt)
156+
continue
157+
elseif @capture(stmt, lhs_ = rhs_)
158+
if @capture(rhs, @index(args__))
159+
push!(indicies, stmt)
130160
continue
131-
elseif stmt.args[1] === Symbol("@uniform")
161+
elseif @capture(rhs, @localmem(args__) | @uniform(args__) )
132162
push!(allocations, stmt)
133-
end
134-
elseif isexpr(stmt, :(=))
135-
rhs = stmt.args[2]
136-
if isexpr(rhs, :macrocall)
137-
callee = rhs.args[1]
138-
if callee === Symbol("@index")
139-
push!(indicies, stmt)
140-
continue
141-
elseif callee === Symbol("@localmem") ||
142-
callee === Symbol("@uniform")
143-
push!(allocations, stmt)
144-
continue
145-
elseif callee === Symbol("@private")
146-
push!(allocations, stmt)
147-
push!(private, stmt.args[1])
148-
continue
149-
end
163+
continue
164+
elseif @capture(rhs, @private(args__))
165+
push!(allocations, stmt)
166+
push!(private, lhs)
167+
continue
150168
end
151169
end
152170

@@ -155,9 +173,10 @@ function split(stmts)
155173

156174
# everything since the last `@synchronize`
157175
if !isempty(current)
158-
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private)))
176+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
177+
push!(new_stmts, emit(loop))
159178
end
160-
return loops
179+
return new_stmts
161180
end
162181

163182
function emit(loop)
@@ -168,21 +187,28 @@ function emit(loop)
168187
rhs = stmt.args[2]
169188
push!(rhs.args, idx)
170189
end
171-
body = Expr(:block, loop.stmts...)
172-
body = postwalk(body) do expr
173-
if @capture(expr, A_[i__])
174-
if A in loop.private
175-
return :($A[$(i...), $(idx).I...])
190+
stmts = Any[]
191+
append!(stmts, loop.allocations)
192+
# don't emit empty loops
193+
if !(isempty(loop.stmts) || all(s->s isa LineNumberNode, loop.stmts))
194+
body = Expr(:block, loop.stmts...)
195+
body = postwalk(body) do expr
196+
if @capture(expr, A_[i__])
197+
if A in loop.private
198+
return :($A[$(i...), $(idx).I...])
199+
end
176200
end
201+
return expr
177202
end
178-
return expr
179-
end
180-
quote
181-
$(loop.allocations...)
182-
for $idx in $__workitems_iterspace()
183-
$__validindex($idx) || continue
184-
$(loop.indicies...)
185-
$(body)
203+
loopexpr = quote
204+
for $idx in $__workitems_iterspace()
205+
$__validindex($idx) || continue
206+
$(loop.indicies...)
207+
$(unblock(body))
208+
end
186209
end
210+
push!(stmts, loopexpr)
187211
end
212+
213+
return unblock(Expr(:block, stmts...))
188214
end

test/localmem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ end
1414
I = @index(Global, Linear)
1515
i = @index(Local, Linear)
1616
lmem = @localmem Int (N,) # Ok iff groupsize is static
17-
lmem[i] = i
18-
@synchronize
19-
A[I] = lmem[N2 - i + 1]
17+
@inbounds begin
18+
lmem[i] = i
19+
@synchronize
20+
A[I] = lmem[N2 - i + 1]
21+
end
2022
end
2123

2224
function harness(backend, ArrayT)

test/private.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,30 @@ if has_cuda_gpu()
77
end
88

99
@kernel function private(A)
10-
N = prod(groupsize())
10+
@uniform N = prod(groupsize())
1111
I = @index(Global, Linear)
1212
i = @index(Local, Linear)
1313
priv = @private Int (1,)
14-
priv[1] = N - i + 1
14+
@inbounds begin
15+
priv[1] = N - i + 1
16+
@synchronize
17+
A[I] = priv[1]
18+
end
19+
end
20+
21+
# This is horrible don't write code like this
22+
@kernel function forloop(A, ::Val{N}) where N
23+
i = @index(Global, Linear)
24+
priv = @private Int (N,)
25+
for j in 1:N
26+
priv[j] = A[i, j]
27+
end
28+
A[i, 1] = 0
1529
@synchronize
16-
A[I] = priv[1]
30+
for j in 1:N
31+
A[j, 1] += priv[j]
32+
@synchronize
33+
end
1734
end
1835

1936
function harness(backend, ArrayT)
@@ -23,6 +40,16 @@ function harness(backend, ArrayT)
2340
@test all(A[17:32] .== 16:-1:1)
2441
@test all(A[33:48] .== 16:-1:1)
2542
@test all(A[49:64] .== 16:-1:1)
43+
44+
A = ArrayT{Int}(undef, 64, 64)
45+
A .= 1
46+
wait(forloop(backend)(A, Val(size(A, 2)), ndrange=size(A,1), workgroupsize=size(A,1)))
47+
if ArrayT <: Array
48+
@test all(A[:, 1] .== 64)
49+
else
50+
@test_broken all(A[:, 1] .== 64)
51+
end
52+
@test all(A[:, 2:end] .== 1)
2653
end
2754

2855
@testset "kernels" begin

0 commit comments

Comments
 (0)