Skip to content

Commit a881a13

Browse files
bors[bot]mwaruszvchuravy
authored
Merge #36
36: Fix private memory on the CPU r=vchuravy a=mwarusz Co-authored-by: Valentin Churavy <v.churavy@gmail.com> Co-authored-by: Maciej Waruszewski <mwarusz@igf.fuw.edu.pl> Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
2 parents c4d2487 + 2b832c2 commit a881a13

File tree

5 files changed

+63
-26
lines changed

5 files changed

+63
-26
lines changed

src/backends/cpu.jl

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -132,28 +132,13 @@ struct ScratchArray{N, D}
132132
end
133133

134134
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
135-
return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata))), T}(undef))
135+
return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata)...)), T}(undef))
136136
end
137137

138-
Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.getindex), A::ScratchArray{N}, I...) where N
139-
nI = ntuple(Val(N+1)) do i
140-
if i == N+1
141-
__groupindex(ctx.metadata)
142-
else
143-
I[i]
144-
end
145-
end
146-
147-
return A.data[nI...]
138+
Base.@propagate_inbounds function Base.getindex(A::ScratchArray, I...)
139+
return A.data[I...]
148140
end
149141

150-
Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.setindex!), A::ScratchArray{N}, val, I...) where N
151-
nI = ntuple(Val(N+1)) do i
152-
if i == N+1
153-
__groupindex(ctx.metadata)
154-
else
155-
I[i]
156-
end
157-
end
158-
A.data[nI...] = val
142+
Base.@propagate_inbounds function Base.setindex!(A::ScratchArray, val, I...)
143+
A.data[I...] = val
159144
end

src/compiler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ end
2222

2323
@inline __iterspace(cm::CompilerMetadata) = cm.iterspace
2424
@inline __groupindex(cm::CompilerMetadata) = cm.groupindex
25+
@inline __groupsize(cm::CompilerMetadata) = size(workitems(__iterspace(cm)))
2526
@inline __dynamic_checkbounds(::CompilerMetadata{NDRange, CB}) where {NDRange, CB} = CB
2627
@inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:StaticSize} = CartesianIndices(get(NDRange))
2728
@inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:DynamicSize} = cm.ndrange
@@ -31,7 +32,7 @@ include("compiler/pass.jl")
3132

3233
function generate_overdubs(Ctx)
3334
@eval begin
34-
@inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = size(workitems(__iterspace(ctx.metadata)))
35+
@inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = __groupsize(ctx.metadata)
3536
@inline Cassette.overdub(ctx::$Ctx, ::typeof(__workitems_iterspace)) = workitems(__iterspace(ctx.metadata))
3637

3738
###

src/macros.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import MacroTools: splitdef, combinedef, isexpr
1+
import MacroTools: splitdef, combinedef, isexpr, postwalk
22

33
# XXX: Proper errors
44
function __kernel(expr)
@@ -104,6 +104,7 @@ struct WorkgroupLoop
104104
indicies :: Vector{Any}
105105
stmts :: Vector{Any}
106106
allocations :: Vector{Any}
107+
private :: Vector{Any}
107108
end
108109

109110

@@ -116,12 +117,13 @@ function split(stmts)
116117
current = Any[]
117118
indicies = Any[]
118119
allocations = Any[]
120+
private = Any[]
119121

120122
loops = WorkgroupLoop[]
121123
for stmt in stmts.args
122124
if isexpr(stmt, :macrocall)
123125
if stmt.args[1] === Symbol("@synchronize")
124-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations)
126+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
125127
push!(loops, loop)
126128
allocations = Any[]
127129
current = Any[]
@@ -137,10 +139,13 @@ function split(stmts)
137139
push!(indicies, stmt)
138140
continue
139141
elseif callee === Symbol("@localmem") ||
140-
callee === Symbol("@private") ||
141142
callee === Symbol("@uniform")
142143
push!(allocations, stmt)
143144
continue
145+
elseif callee === Symbol("@private")
146+
push!(allocations, stmt)
147+
push!(private, stmt.args[1])
148+
continue
144149
end
145150
end
146151
end
@@ -150,7 +155,7 @@ function split(stmts)
150155

151156
# everything since the last `@synchronize`
152157
if !isempty(current)
153-
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations))
158+
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private)))
154159
end
155160
return loops
156161
end
@@ -163,12 +168,21 @@ function emit(loop)
163168
rhs = stmt.args[2]
164169
push!(rhs.args, idx)
165170
end
171+
body = Expr(:block, loop.stmts...)
172+
body = postwalk(body) do expr
173+
if @capture(expr, A_[i__])
174+
if A in loop.private
175+
return :($A[$(i...), $(idx).I...])
176+
end
177+
end
178+
return expr
179+
end
166180
quote
167181
$(loop.allocations...)
168182
for $idx in $__workitems_iterspace()
169183
$__validindex($idx) || continue
170184
$(loop.indicies...)
171-
$(loop.stmts...)
185+
$(body)
172186
end
173187
end
174188
end

test/private.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using KernelAbstractions
2+
using Test
3+
using CUDAapi
4+
if has_cuda_gpu()
5+
using CuArrays
6+
CuArrays.allowscalar(false)
7+
end
8+
9+
@kernel function private(A)
10+
N = prod(groupsize())
11+
I = @index(Global, Linear)
12+
i = @index(Local, Linear)
13+
priv = @private Int (1,)
14+
priv[1] = N - i + 1
15+
@synchronize
16+
A[I] = priv[1]
17+
end
18+
19+
function harness(backend, ArrayT)
20+
A = ArrayT{Int}(undef, 64)
21+
wait(private(backend, 16)(A, ndrange=size(A)))
22+
@test all(A[1:16] .== 16:-1:1)
23+
@test all(A[17:32] .== 16:-1:1)
24+
@test all(A[33:48] .== 16:-1:1)
25+
@test all(A[49:64] .== 16:-1:1)
26+
end
27+
28+
@testset "kernels" begin
29+
harness(CPU(), Array)
30+
if has_cuda_gpu()
31+
harness(CUDA(), CuArray)
32+
end
33+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ end
99
include("localmem.jl")
1010
end
1111

12+
@testset "Private" begin
13+
include("private.jl")
14+
end
15+
1216
@testset "Unroll" begin
1317
include("unroll.jl")
1418
end

0 commit comments

Comments
 (0)