Skip to content

Commit 98d8f75

Browse files
committed
add conditional synchronize
1 parent 4aff96c commit 98d8f75

File tree

3 files changed

+42
-12
lines changed

3 files changed

+42
-12
lines changed

src/KernelAbstractions.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,34 @@ end
115115

116116
"""
117117
@synchronize()
118+
119+
After a `@synchronize` statement all reads and writes to global and local memory
120+
from each thread in the workgroup are visible from all other threads in the
121+
workgroup.
118122
"""
119123
macro synchronize()
120124
quote
121125
$__synchronize()
122126
end
123127
end
124128

129+
"""
130+
@synchronize(cond)
131+
132+
After a `@synchronize` statement all reads and writes to global and local memory
133+
from each thread in the workgroup are visible from all other threads in the
134+
workgroup. `cond` is not allowed to have any visible side effects.
135+
136+
# Platform differences
137+
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
138+
- `CPU`: This synchronization will always occur.
139+
"""
140+
macro synchronize(cond)
141+
quote
142+
$(esc(cond)) && $__synchronize()
143+
end
144+
end
145+
125146
"""
126147
@index
127148

src/macros.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ struct WorkgroupLoop
103103
private :: Vector{Any}
104104
end
105105

106-
is_sync(expr) = @capture(expr, @synchronize)
106+
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
107+
108+
function is_scope_construct(expr::Expr)
109+
expr.head === :block # ||
110+
# expr.head === :let
111+
end
112+
107113
function find_sync(stmt)
108114
result = false
109115
postwalk(stmt) do expr
@@ -131,7 +137,7 @@ function split(stmts,
131137
push!(new_stmts, emit(loop))
132138
allocations = Any[]
133139
current = Any[]
134-
@capture(stmt, @synchronize) && continue
140+
is_sync(stmt) && continue
135141

136142
# Recurse into scope constructs
137143
# TODO: This currently implements hard scoping
@@ -140,7 +146,7 @@ function split(stmts,
140146
recurse(x) = x
141147
function recurse(expr::Expr)
142148
expr = unblock(expr)
143-
if any(is_sync, expr.args)
149+
if is_scope_construct(expr) && any(is_sync, expr.args)
144150
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
145151
return Expr(expr.head, new_args...)
146152
else

test/unroll.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ using StaticArrays
99
end
1010

1111
@kernel function kernel_unroll!(a, ::Val{N}) where N
12-
@unroll for i in 1:N
13-
@inbounds a[i] = i
12+
let M = N+5
13+
@unroll for i in 6:M
14+
@inbounds a[i-5] = i
15+
end
16+
@synchronize
1417
end
1518
end
1619

1720
# Check that nested `@unroll` doesn't throw a syntax error
18-
@kernel function kernel_unroll!(a, ::Val{N}) where N
21+
@kernel function kernel_unroll2!(A)
1922
@uniform begin
2023
a = MVector{3, Float64}(1, 2, 3)
2124
b = MVector{3, Float64}(3, 2, 1)
@@ -28,16 +31,16 @@ end
2831
c[1, j] = m * a[1] * b[j]
2932
end
3033
end
31-
a[I] = c[1, 1]
32-
m % 2 == 0 && @synchronize
34+
A[I] = c[1, 1]
35+
@synchronize(m % 2 == 0)
3336
end
3437
end
3538

3639
let
3740
a = zeros(5)
3841
kernel! = kernel_unroll!(CPU(), 1, 1)
39-
event = kernel!(a)
40-
wait(event)
41-
event = kernel!(a, Val(5))
42-
wait(event)
42+
wait(kernel!(a))
43+
wait(kernel!(a, Val(5)))
44+
kernel2! = kernel_unroll2!(CPU(), 1, 1)
45+
wait(kernel2!(a))
4346
end

0 commit comments

Comments
 (0)