Skip to content

Commit 7649ff5

Browse files
bors[bot]vchuravy
andauthored
Merge #34
34: Use macrotools r=vchuravy a=vchuravy - Use macrotools to make things more robost - Introduce `@uniform` to have values that are uniform across a workgroup - don't define constructors more than once Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
2 parents dc5cc19 + 2896ef7 commit 7649ff5

File tree

5 files changed

+123
-106
lines changed

5 files changed

+123
-106
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
99
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
1010
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
1111
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
12+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1213
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1415

@@ -20,6 +21,7 @@ CUDAnative = "2.10"
2021
Cassette = "0.3"
2122
Requires = "1.0"
2223
StaticArrays = "0.12"
24+
MacroTools = "0.5"
2325
julia = "1.3"
2426

2527
[extras]

src/KernelAbstractions.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module KernelAbstractions
22

33
export @kernel
4-
export @Const, @localmem, @private, @synchronize, @index, groupsize
4+
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize
55
export Device, GPU, CPU, CUDA
66

7+
using MacroTools
78
using StaticArrays
89
using Cassette
910
using Requires
@@ -23,6 +24,7 @@ and then invoked on the arguments.
2324
- [`@index`](@ref)
2425
- [`@localmem`](@ref)
2526
- [`@private`](@ref)
27+
- [`@uniform`](@ref)
2628
- [`@synchronize`](@ref)
2729
2830
# Example:
@@ -68,6 +70,7 @@ function async_copy! end
6870
# Kernel language
6971
# - @localmem
7072
# - @private
73+
# - @uniform
7174
# - @synchronize
7275
# - @index
7376
# - groupsize
@@ -83,7 +86,7 @@ the total size you can use `prod(groupsize())`.
8386
function groupsize end
8487

8588
"""
86-
@localmem T dims
89+
@localmem T dims
8790
"""
8891
macro localmem(T, dims)
8992
# Stay in sync with CUDAnative
@@ -95,14 +98,21 @@ macro localmem(T, dims)
9598
end
9699

97100
"""
98-
@private T dims
101+
@private T dims
99102
"""
100103
macro private(T, dims)
101104
quote
102105
$Scratchpad($(esc(T)), Val($(esc(dims))))
103106
end
104107
end
105108

109+
"""
110+
@uniform value
111+
"""
112+
macro uniform(value)
113+
esc(value)
114+
end
115+
106116
"""
107117
@synchronize()
108118
"""

src/macros.jl

Lines changed: 97 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,127 @@
1-
import Base.Meta: isexpr
1+
import MacroTools: splitdef, combinedef, isexpr
22

33
# XXX: Proper errors
44
function __kernel(expr)
5-
@assert isexpr(expr, :function)
6-
decl = expr.args[1]
7-
body = expr.args[2]
8-
9-
# parse decl
10-
# `@kernel fname(::T) where {T}`
11-
if isexpr(decl, :where)
12-
iswhere = true
13-
whereargs = decl.args[2:end]
14-
decl = decl.args[1]
15-
else
16-
iswhere = false
17-
end
18-
@assert isexpr(decl, :call)
19-
name = decl.args[1]
20-
21-
# List of tuple (Symbol, Bool) where the bool
22-
# marks if the arg is const
23-
args = Any[]
24-
for i in 2:length(decl.args)
25-
arg = decl.args[i]
5+
def = splitdef(expr)
6+
name = def[:name]
7+
args = def[:args]
8+
9+
constargs = Array{Bool}(undef, length(args))
10+
for (i, arg) in enumerate(args)
2611
if isexpr(arg, :macrocall)
2712
if arg.args[1] === Symbol("@Const")
28-
# args[2] is a LineInfo node
29-
push!(args, (arg.args[3], true))
13+
# arg.args[2] is a LineInfo node
14+
args[i] = arg.args[3] # strip @Const
15+
constargs[i] = true
3016
continue
3117
end
3218
end
33-
push!(args, (arg, false))
19+
constargs[i] = false
3420
end
3521

36-
arglist = map(a->a[1], args)
37-
3822
# create two functions
3923
# 1. GPU function
4024
# 2. CPU function with work-group loops inserted
41-
gpu_name = Symbol(:gpu_, name)
42-
cpu_name = Symbol(:cpu_, name)
25+
#
26+
# Without the deepcopy we might accidentially modify expr shared between CPU and GPU
27+
def_cpu = deepcopy(def)
28+
def_gpu = deepcopy(def)
4329

44-
gpu_decl = Expr(:call, gpu_name, arglist...)
45-
cpu_decl = Expr(:call, cpu_name, arglist...)
30+
def_cpu[:name] = cpu_name = Symbol(:cpu_, name)
31+
def_gpu[:name] = gpu_name = Symbol(:gpu_, name)
4632

47-
if iswhere
48-
gpu_decl = Expr(:where, gpu_decl, whereargs...)
49-
cpu_decl = Expr(:where, cpu_decl, whereargs...)
50-
end
33+
transform_cpu!(def_cpu, constargs)
34+
transform_gpu!(def_gpu, constargs)
5135

52-
# Without the deepcopy we might accidentially modify expr shared between CPU and GPU
53-
gpu_body = transform_gpu(deepcopy(body), args)
54-
gpu_function = Expr(:function, gpu_decl, gpu_body)
55-
56-
cpu_body = transform_cpu(deepcopy(body), args)
57-
cpu_function = Expr(:function, cpu_decl, cpu_body)
36+
cpu_function = combinedef(def_cpu)
37+
gpu_function = combinedef(def_gpu)
5838

5939
# create constructor functions
6040
constructors = quote
61-
$name(dev::$Device) = $name(dev, $DynamicSize(), $DynamicSize())
62-
$name(dev::$Device, size) = $name(dev, $StaticSize(size), $DynamicSize())
63-
$name(dev::$Device, size, range) = $name(dev, $StaticSize(size), $StaticSize(range))
64-
function $name(::Device, ::S, ::NDRange) where {Device<:$CPU, S<:$_Size, NDRange<:$_Size}
65-
return $Kernel{Device, S, NDRange, typeof($cpu_name)}($cpu_name)
66-
end
67-
function $name(::Device, ::S, ::NDRange) where {Device<:$GPU, S<:$_Size, NDRange<:$_Size}
68-
return $Kernel{Device, S, NDRange, typeof($gpu_name)}($gpu_name)
41+
if !@isdefined($name)
42+
$name(dev::$Device) = $name(dev, $DynamicSize(), $DynamicSize())
43+
$name(dev::$Device, size) = $name(dev, $StaticSize(size), $DynamicSize())
44+
$name(dev::$Device, size, range) = $name(dev, $StaticSize(size), $StaticSize(range))
45+
function $name(::Device, ::S, ::NDRange) where {Device<:$CPU, S<:$_Size, NDRange<:$_Size}
46+
return $Kernel{Device, S, NDRange, typeof($cpu_name)}($cpu_name)
47+
end
48+
function $name(::Device, ::S, ::NDRange) where {Device<:$GPU, S<:$_Size, NDRange<:$_Size}
49+
return $Kernel{Device, S, NDRange, typeof($gpu_name)}($gpu_name)
50+
end
6951
end
7052
end
7153

7254
return Expr(:block, esc(cpu_function), esc(gpu_function), esc(constructors))
7355
end
7456

75-
# Transform function for GPU execution
76-
# This involves marking constant arguments
77-
function transform_gpu(expr, args)
57+
# The easy case, transform the function for GPU execution
58+
# - mark constant arguments by applying `constify`.
59+
function transform_gpu!(def, constargs)
7860
new_stmts = Expr[]
79-
for (arg, isconst) in args
80-
if isconst
61+
for (i, arg) in enumerate(def[:args])
62+
if constargs[i]
8163
push!(new_stmts, :($arg = $constify($arg)))
8264
end
8365
end
84-
return quote
66+
67+
def[:body] = quote
8568
if $__validindex()
8669
$(new_stmts...)
87-
$expr
70+
$(def[:body])
8871
end
8972
return nothing
9073
end
9174
end
9275

76+
# The hard case, transform the function for CPU execution
77+
# - mark constant arguments by applying `constify`.
78+
# - insert aliasscope markers
79+
# - insert implied loop bodys
80+
# - handle indicies
81+
# - hoist workgroup definitions
82+
# - hoist uniform variables
83+
function transform_cpu!(def, constargs)
84+
new_stmts = Expr[]
85+
for (i, arg) in enumerate(def[:args])
86+
if constargs[i]
87+
push!(new_stmts, :($arg = $constify($arg)))
88+
end
89+
end
90+
91+
body = MacroTools.flatten(def[:body])
92+
loops = split(body)
93+
94+
push!(new_stmts, Expr(:aliasscope))
95+
for loop in loops
96+
push!(new_stmts, emit(loop))
97+
end
98+
push!(new_stmts, Expr(:popaliasscope))
99+
push!(new_stmts, :(return nothing))
100+
def[:body] = Expr(:block, new_stmts...)
101+
end
102+
103+
struct WorkgroupLoop
104+
indicies :: Vector{Any}
105+
stmts :: Vector{Any}
106+
allocations :: Vector{Any}
107+
end
108+
109+
93110
function split(stmts)
94111
# 1. Split the code into blocks separated by `@synchronize`
95-
# 2. Aggregate the index and allocation expressions seen at the sync points
112+
# 2. Aggregate `@index` expressions
113+
# 3. Hoist allocations
114+
# 4. Hoist uniforms
115+
116+
current = Any[]
96117
indicies = Any[]
97118
allocations = Any[]
98-
loops = Any[]
99-
current = Any[]
100119

120+
loops = WorkgroupLoop[]
101121
for stmt in stmts.args
102122
if isexpr(stmt, :macrocall) && stmt.args[1] === Symbol("@synchronize")
103-
push!(loops, (current, deepcopy(indicies), allocations))
123+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations)
124+
push!(loops, loop)
104125
allocations = Any[]
105126
current = Any[]
106127
continue
@@ -111,64 +132,39 @@ function split(stmts)
111132
if callee === Symbol("@index")
112133
push!(indicies, stmt)
113134
continue
114-
elseif callee === Symbol("@localmem") || callee === Symbol("@private")
135+
elseif callee === Symbol("@localmem") ||
136+
callee === Symbol("@private") ||
137+
callee === Symbol("@uniform")
115138
push!(allocations, stmt)
116139
continue
117140
end
118141
end
119142
end
120143

121-
if isexpr(stmt, :block)
122-
# XXX: What about loops, let, ...
123-
@warn "Encountered a block at the top-level unclear semantics"
124-
end
125144
push!(current, stmt)
126145
end
127146

128147
# everything since the last `@synchronize`
129148
if !isempty(current)
130-
push!(loops, (current, copy(indicies), allocations))
149+
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations))
131150
end
132151
return loops
133152
end
134153

135-
function generate_cpu_code(loops)
136-
# Create loops
137-
new_stmts = Any[]
138-
for (body, indexes, allocations) in loops
139-
idx = gensym(:I)
154+
function emit(loop)
155+
idx = gensym(:I)
156+
for stmt in loop.indicies
140157
# splice index into the i = @index(Cartesian, $idx)
141-
for stmt in indexes
142-
@assert stmt.head === :(=)
143-
rhs = stmt.args[2]
144-
push!(rhs.args, idx)
145-
end
146-
loop = quote
147-
$(allocations...)
148-
for $idx in $__workitems_iterspace()
149-
$__validindex($idx) || continue
150-
$(indexes...)
151-
$(body...)
152-
end
153-
end
154-
push!(new_stmts, loop)
158+
@assert stmt.head === :(=)
159+
rhs = stmt.args[2]
160+
push!(rhs.args, idx)
155161
end
156-
return Expr(:block, new_stmts...)
157-
end
158-
159-
function transform_cpu(stmts, args)
160-
new_stmts = Expr[]
161-
for (arg, isconst) in args
162-
if isconst
163-
push!(new_stmts, :($arg = $constify($arg)))
162+
quote
163+
$(loop.allocations...)
164+
for $idx in $__workitems_iterspace()
165+
$__validindex($idx) || continue
166+
$(loop.indicies...)
167+
$(loop.stmts...)
164168
end
165169
end
166-
loops = split(stmts)
167-
body = generate_cpu_code(loops)
168-
169-
push!(new_stmts, Expr(:aliasscope))
170-
push!(new_stmts, body)
171-
push!(new_stmts, Expr(:popaliasscope))
172-
push!(new_stmts, :(return nothing))
173-
return Expr(:block, new_stmts...)
174170
end

test/localmem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ if has_cuda_gpu()
77
end
88

99
@kernel function localmem(A)
10+
N = @uniform prod(groupsize())
1011
I = @index(Global, Linear)
1112
i = @index(Local, Linear)
12-
lmem = @localmem Int groupsize() # Ok iff groupsize is static
13+
lmem = @localmem Int (N,) # Ok iff groupsize is static
1314
lmem[i] = i
1415
@synchronize
15-
A[I] = lmem[prod(groupsize()) - i + 1]
16+
A[I] = lmem[N - i + 1]
1617
end
1718

1819
function harness(backend, ArrayT)

test/unroll.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,17 @@ using KernelAbstractions.Extras
77
end
88
end
99

10+
@kernel function kernel_unroll!(a, ::Val{N}) where N
11+
@unroll for i in 1:N
12+
@inbounds a[i] = i
13+
end
14+
end
15+
1016
let
1117
a = zeros(5)
1218
kernel! = kernel_unroll!(CPU(), 1, 1)
1319
event = kernel!(a)
1420
wait(event)
21+
event = kernel!(a, Val(5))
22+
wait(event)
1523
end

0 commit comments

Comments
 (0)