1
- import Base . Meta : isexpr
1
+ import MacroTools : splitdef, combinedef, isexpr
2
2
3
3
# XXX : Proper errors
4
4
function __kernel (expr)
5
- @assert isexpr (expr, :function )
6
- decl = expr. args[1 ]
7
- body = expr. args[2 ]
8
-
9
- # parse decl
10
- # `@kernel fname(::T) where {T}`
11
- if isexpr (decl, :where )
12
- iswhere = true
13
- whereargs = decl. args[2 : end ]
14
- decl = decl. args[1 ]
15
- else
16
- iswhere = false
17
- end
18
- @assert isexpr (decl, :call )
19
- name = decl. args[1 ]
20
-
21
- # List of tuple (Symbol, Bool) where the bool
22
- # marks if the arg is const
23
- args = Any[]
24
- for i in 2 : length (decl. args)
25
- arg = decl. args[i]
5
+ def = splitdef (expr)
6
+ name = def[:name ]
7
+ args = def[:args ]
8
+
9
+ constargs = Array {Bool} (undef, length (args))
10
+ for (i, arg) in enumerate (args)
26
11
if isexpr (arg, :macrocall )
27
12
if arg. args[1 ] === Symbol (" @Const" )
28
- # args[2] is a LineInfo node
29
- push! (args, (arg. args[3 ], true ))
13
+ # arg.args[2] is a LineInfo node
14
+ args[i] = arg. args[3 ] # strip @Const
15
+ constargs[i] = true
30
16
continue
31
17
end
32
18
end
33
- push! (args, (arg, false ))
19
+ constargs[i] = false
34
20
end
35
21
36
- arglist = map (a-> a[1 ], args)
37
-
38
22
# create two functions
39
23
# 1. GPU function
40
24
# 2. CPU function with work-group loops inserted
41
- gpu_name = Symbol (:gpu_ , name)
42
- cpu_name = Symbol (:cpu_ , name)
43
-
44
- gpu_decl = Expr (:call , gpu_name, arglist... )
45
- cpu_decl = Expr (:call , cpu_name, arglist... )
25
+ #
26
+ # Without the deepcopy we might accidentally modify expressions shared between the CPU and GPU definitions
27
+ def_cpu = deepcopy (def)
28
+ def_gpu = deepcopy (def)
46
29
47
- if iswhere
48
- gpu_decl = Expr (:where , gpu_decl, whereargs... )
49
- cpu_decl = Expr (:where , cpu_decl, whereargs... )
50
- end
30
+ def_cpu[:name ] = cpu_name = Symbol (:cpu_ , name)
31
+ def_gpu[:name ] = gpu_name = Symbol (:gpu_ , name)
51
32
52
- # Without the deepcopy we might accidentially modify expr shared between CPU and GPU
53
- gpu_body = transform_gpu (deepcopy (body), args)
54
- gpu_function = Expr (:function , gpu_decl, gpu_body)
33
+ transform_cpu! (def_cpu, constargs)
34
+ transform_gpu! (def_gpu, constargs)
55
35
56
- cpu_body = transform_cpu ( deepcopy (body), args )
57
- cpu_function = Expr ( :function , cpu_decl, cpu_body )
36
+ cpu_function = combinedef (def_cpu )
37
+ gpu_function = combinedef (def_gpu )
58
38
59
39
# create constructor functions
60
40
constructors = quote
@@ -72,35 +52,74 @@ function __kernel(expr)
72
52
return Expr (:block , esc (cpu_function), esc (gpu_function), esc (constructors))
73
53
end
74
54
75
- # Transform function for GPU execution
76
- # This involves marking constant arguments
77
- function transform_gpu (expr, args )
55
# GPU lowering (the simple path). Rewrites `def[:body]` in place so that:
# - every argument flagged in `constargs` is rebound to `constify(arg)`,
# - the original body only runs when `__validindex()` holds,
# - the function always ends with `return nothing`.
#
# `def` is the dictionary produced by `splitdef`; `constargs[i]` is the flag
# for `def[:args][i]`. The new body expression is also the return value.
function transform_gpu!(def, constargs)
    # One `arg = constify(arg)` statement per flagged argument, in argument order.
    rebinds = Expr[:($a = $constify($a)) for (k, a) in enumerate(def[:args]) if constargs[k]]
    original_body = def[:body]

    def[:body] = quote
        if $__validindex()
            $(rebinds...)
            $original_body
        end
        return nothing
    end
end
92
73
74
# CPU lowering (the involved path). Rewrites `def[:body]` in place:
# - rebind every argument flagged in `constargs` to `constify(arg)`,
# - wrap the generated code in `aliasscope` / `popaliasscope` markers,
# - flatten the body, `split` it into work-group loops, and `emit` each loop
#   (index handling and hoisting of allocations happen in `split` / `emit`),
# - finish with `return nothing`.
#
# `def` is the dictionary produced by `splitdef`; `constargs[i]` is the flag
# for `def[:args][i]`. The new body expression is also the return value.
function transform_cpu!(def, constargs)
    # One `arg = constify(arg)` statement per flagged argument, in argument order.
    stmts = Expr[:($a = $constify($a)) for (k, a) in enumerate(def[:args]) if constargs[k]]

    # Flatten nested blocks before partitioning the body into loops.
    flat_body = MacroTools.flatten(def[:body])
    workgroup_loops = split(flat_body)

    push!(stmts, Expr(:aliasscope))
    foreach(wl -> push!(stmts, emit(wl)), workgroup_loops)
    push!(stmts, Expr(:popaliasscope))
    push!(stmts, :(return nothing))

    def[:body] = Expr(:block, stmts...)
end
100
+
101
# One work-group loop produced by `split` and consumed by `emit`.
struct WorkgroupLoop
    # `@index` assignment statements collected for this loop
    indicies::Vector{Any}
    # statements that form the loop body
    stmts::Vector{Any}
    # `@localmem` / `@private` statements placed ahead of the loop by `emit`
    allocations::Vector{Any}
end
106
+
107
+
93
108
function split (stmts)
94
109
# 1. Split the code into blocks separated by `@synchronize`
95
- # 2. Aggregate the index and allocation expressions seen at the sync points
110
+ # 2. Aggregate `@index` expressions
111
+ # 3. Hoist allocations
112
+ # 4. Hoist uniforms
113
+
114
+ current = Any[]
96
115
indicies = Any[]
97
116
allocations = Any[]
98
- loops = Any[]
99
- current = Any[]
100
117
118
+ loops = WorkgroupLoop[]
101
119
for stmt in stmts. args
102
120
if isexpr (stmt, :macrocall ) && stmt. args[1 ] === Symbol (" @synchronize" )
103
- push! (loops, (current, deepcopy (indicies), allocations))
121
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations)
122
+ push! (loops, loop)
104
123
allocations = Any[]
105
124
current = Any[]
106
125
continue
@@ -111,64 +130,38 @@ function split(stmts)
111
130
if callee === Symbol (" @index" )
112
131
push! (indicies, stmt)
113
132
continue
114
- elseif callee === Symbol (" @localmem" ) || callee === Symbol (" @private" )
133
+ elseif callee === Symbol (" @localmem" ) ||
134
+ callee === Symbol (" @private" )
115
135
push! (allocations, stmt)
116
136
continue
117
137
end
118
138
end
119
139
end
120
140
121
- if isexpr (stmt, :block )
122
- # XXX : What about loops, let, ...
123
- @warn " Encountered a block at the top-level unclear semantics"
124
- end
125
141
push! (current, stmt)
126
142
end
127
143
128
144
# everything since the last `@synchronize`
129
145
if ! isempty (current)
130
- push! (loops, (current, copy (indicies), allocations))
146
+ push! (loops, WorkgroupLoop ( deepcopy (indicies), current , allocations))
131
147
end
132
148
return loops
133
149
end
134
150
135
- function generate_cpu_code (loops)
136
- # Create loops
137
- new_stmts = Any[]
138
- for (body, indexes, allocations) in loops
139
- idx = gensym (:I )
151
# Generate the code for one work-group loop.
#
# The result is a block that first runs `loop.allocations`, then iterates a
# fresh loop variable over `__workitems_iterspace()`. Iterations for which
# `__validindex(idx)` is false are skipped; otherwise the index statements
# and then `loop.stmts` are executed.
#
# `loop` is not modified: the loop variable is spliced into copies of the
# index statements, so emitting the same loop more than once does not append
# the variable repeatedly.
function emit(loop)
    idx = gensym(:I)

    # Turn each `i = @index(Kind, ...)` into `i = @index(Kind, ..., idx)`.
    indicies = map(loop.indicies) do stmt
        @assert stmt isa Expr && stmt.head === :(=) "index statement must be an assignment"
        stmt = deepcopy(stmt)
        rhs = stmt.args[2]
        push!(rhs.args, idx)
        stmt
    end

    quote
        $(loop.allocations...)
        for $idx in $__workitems_iterspace()
            $__validindex($idx) || continue
            $(indicies...)
            $(loop.stmts...)
        end
    end
end
0 commit comments