Skip to content

Commit f65a912

Browse files
committed
use Macrotools
1 parent db2f496 commit f65a912

File tree

2 files changed

+87
-93
lines changed

2 files changed

+87
-93
lines changed

src/KernelAbstractions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ export @kernel
44
export @Const, @localmem, @private, @synchronize, @index, groupsize
55
export Device, GPU, CPU, CUDA
66

7+
using MacroTools
78
using StaticArrays
89
using Cassette
910
using Requires

src/macros.jl

Lines changed: 86 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,40 @@
1-
import Base.Meta: isexpr
1+
import MacroTools: splitdef, combinedef, isexpr
22

33
# XXX: Proper errors
44
function __kernel(expr)
5-
@assert isexpr(expr, :function)
6-
decl = expr.args[1]
7-
body = expr.args[2]
8-
9-
# parse decl
10-
# `@kernel fname(::T) where {T}`
11-
if isexpr(decl, :where)
12-
iswhere = true
13-
whereargs = decl.args[2:end]
14-
decl = decl.args[1]
15-
else
16-
iswhere = false
17-
end
18-
@assert isexpr(decl, :call)
19-
name = decl.args[1]
20-
21-
# List of tuple (Symbol, Bool) where the bool
22-
# marks if the arg is const
23-
args = Any[]
24-
for i in 2:length(decl.args)
25-
arg = decl.args[i]
5+
def = splitdef(expr)
6+
name = def[:name]
7+
args = def[:args]
8+
9+
constargs = Array{Bool}(undef, length(args))
10+
for (i, arg) in enumerate(args)
2611
if isexpr(arg, :macrocall)
2712
if arg.args[1] === Symbol("@Const")
28-
# args[2] is a LineInfo node
29-
push!(args, (arg.args[3], true))
13+
# arg.args[2] is a LineNumberNode
14+
args[i] = arg.args[3] # strip @Const
15+
constargs[i] = true
3016
continue
3117
end
3218
end
33-
push!(args, (arg, false))
19+
constargs[i] = false
3420
end
3521

36-
arglist = map(a->a[1], args)
37-
3822
# create two functions
3923
# 1. GPU function
4024
# 2. CPU function with work-group loops inserted
41-
gpu_name = Symbol(:gpu_, name)
42-
cpu_name = Symbol(:cpu_, name)
43-
44-
gpu_decl = Expr(:call, gpu_name, arglist...)
45-
cpu_decl = Expr(:call, cpu_name, arglist...)
25+
#
26+
# Without the deepcopy we might accidentally modify an expression shared between the CPU and GPU definitions
27+
def_cpu = deepcopy(def)
28+
def_gpu = deepcopy(def)
4629

47-
if iswhere
48-
gpu_decl = Expr(:where, gpu_decl, whereargs...)
49-
cpu_decl = Expr(:where, cpu_decl, whereargs...)
50-
end
30+
def_cpu[:name] = cpu_name = Symbol(:cpu_, name)
31+
def_gpu[:name] = gpu_name = Symbol(:gpu_, name)
5132

52-
# Without the deepcopy we might accidentally modify expr shared between CPU and GPU
53-
gpu_body = transform_gpu(deepcopy(body), args)
54-
gpu_function = Expr(:function, gpu_decl, gpu_body)
33+
transform_cpu!(def_cpu, constargs)
34+
transform_gpu!(def_gpu, constargs)
5535

56-
cpu_body = transform_cpu(deepcopy(body), args)
57-
cpu_function = Expr(:function, cpu_decl, cpu_body)
36+
cpu_function = combinedef(def_cpu)
37+
gpu_function = combinedef(def_gpu)
5838

5939
# create constructor functions
6040
constructors = quote
@@ -72,35 +52,74 @@ function __kernel(expr)
7252
return Expr(:block, esc(cpu_function), esc(gpu_function), esc(constructors))
7353
end
7454

75-
# Transform function for GPU execution
76-
# This involves marking constant arguments
77-
function transform_gpu(expr, args)
55+
# The easy case, transform the function for GPU execution
56+
# - mark constant arguments by applying `constify`.
57+
function transform_gpu!(def, constargs)
7858
new_stmts = Expr[]
79-
for (arg, isconst) in args
80-
if isconst
59+
for (i, arg) in enumerate(def[:args])
60+
if constargs[i]
8161
push!(new_stmts, :($arg = $constify($arg)))
8262
end
8363
end
84-
return quote
64+
65+
def[:body] = quote
8566
if $__validindex()
8667
$(new_stmts...)
87-
$expr
68+
$(def[:body])
8869
end
8970
return nothing
9071
end
9172
end
9273

74+
# The hard case, transform the function for CPU execution
75+
# - mark constant arguments by applying `constify`.
76+
# - insert aliasscope markers
77+
# - insert implied loop bodies
78+
# - handle indices
79+
# - hoist workgroup definitions
80+
# - hoist uniform variables
81+
function transform_cpu!(def, constargs)
82+
new_stmts = Expr[]
83+
for (i, arg) in enumerate(def[:args])
84+
if constargs[i]
85+
push!(new_stmts, :($arg = $constify($arg)))
86+
end
87+
end
88+
89+
body = MacroTools.flatten(def[:body])
90+
loops = split(body)
91+
92+
push!(new_stmts, Expr(:aliasscope))
93+
for loop in loops
94+
push!(new_stmts, emit(loop))
95+
end
96+
push!(new_stmts, Expr(:popaliasscope))
97+
push!(new_stmts, :(return nothing))
98+
def[:body] = Expr(:block, new_stmts...)
99+
end
100+
101+
struct WorkgroupLoop
102+
indicies :: Vector{Any}
103+
stmts :: Vector{Any}
104+
allocations :: Vector{Any}
105+
end
106+
107+
93108
function split(stmts)
94109
# 1. Split the code into blocks separated by `@synchronize`
95-
# 2. Aggregate the index and allocation expressions seen at the sync points
110+
# 2. Aggregate `@index` expressions
111+
# 3. Hoist allocations
112+
# 4. Hoist uniforms
113+
114+
current = Any[]
96115
indicies = Any[]
97116
allocations = Any[]
98-
loops = Any[]
99-
current = Any[]
100117

118+
loops = WorkgroupLoop[]
101119
for stmt in stmts.args
102120
if isexpr(stmt, :macrocall) && stmt.args[1] === Symbol("@synchronize")
103-
push!(loops, (current, deepcopy(indicies), allocations))
121+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations)
122+
push!(loops, loop)
104123
allocations = Any[]
105124
current = Any[]
106125
continue
@@ -111,64 +130,38 @@ function split(stmts)
111130
if callee === Symbol("@index")
112131
push!(indicies, stmt)
113132
continue
114-
elseif callee === Symbol("@localmem") || callee === Symbol("@private")
133+
elseif callee === Symbol("@localmem") ||
134+
callee === Symbol("@private")
115135
push!(allocations, stmt)
116136
continue
117137
end
118138
end
119139
end
120140

121-
if isexpr(stmt, :block)
122-
# XXX: What about loops, let, ...
123-
@warn "Encountered a block at the top-level unclear semantics"
124-
end
125141
push!(current, stmt)
126142
end
127143

128144
# everything since the last `@synchronize`
129145
if !isempty(current)
130-
push!(loops, (current, copy(indicies), allocations))
146+
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations))
131147
end
132148
return loops
133149
end
134150

135-
function generate_cpu_code(loops)
136-
# Create loops
137-
new_stmts = Any[]
138-
for (body, indexes, allocations) in loops
139-
idx = gensym(:I)
151+
function emit(loop)
152+
idx = gensym(:I)
153+
for stmt in loop.indicies
140154
# splice the loop index into each `i = @index(Cartesian, $idx)` statement
141-
for stmt in indexes
142-
@assert stmt.head === :(=)
143-
rhs = stmt.args[2]
144-
push!(rhs.args, idx)
145-
end
146-
loop = quote
147-
$(allocations...)
148-
for $idx in $__workitems_iterspace()
149-
$__validindex($idx) || continue
150-
$(indexes...)
151-
$(body...)
152-
end
153-
end
154-
push!(new_stmts, loop)
155+
@assert stmt.head === :(=)
156+
rhs = stmt.args[2]
157+
push!(rhs.args, idx)
155158
end
156-
return Expr(:block, new_stmts...)
157-
end
158-
159-
function transform_cpu(stmts, args)
160-
new_stmts = Expr[]
161-
for (arg, isconst) in args
162-
if isconst
163-
push!(new_stmts, :($arg = $constify($arg)))
159+
quote
160+
$(loop.allocations...)
161+
for $idx in $__workitems_iterspace()
162+
$__validindex($idx) || continue
163+
$(loop.indicies...)
164+
$(loop.stmts...)
164165
end
165166
end
166-
loops = split(stmts)
167-
body = generate_cpu_code(loops)
168-
169-
push!(new_stmts, Expr(:aliasscope))
170-
push!(new_stmts, body)
171-
push!(new_stmts, Expr(:popaliasscope))
172-
push!(new_stmts, :(return nothing))
173-
return Expr(:block, new_stmts...)
174167
end

0 commit comments

Comments
 (0)