1
- import Base . Meta : isexpr
1
+ import MacroTools : splitdef, combinedef, isexpr
2
2
3
3
# XXX : Proper errors
4
4
function __kernel (expr)
5
- @assert isexpr (expr, :function )
6
- decl = expr. args[1 ]
7
- body = expr. args[2 ]
8
-
9
- # parse decl
10
- # `@kernel fname(::T) where {T}`
11
- if isexpr (decl, :where )
12
- iswhere = true
13
- whereargs = decl. args[2 : end ]
14
- decl = decl. args[1 ]
15
- else
16
- iswhere = false
17
- end
18
- @assert isexpr (decl, :call )
19
- name = decl. args[1 ]
20
-
21
- # List of tuple (Symbol, Bool) where the bool
22
- # marks if the arg is const
23
- args = Any[]
24
- for i in 2 : length (decl. args)
25
- arg = decl. args[i]
5
+ def = splitdef (expr)
6
+ name = def[:name ]
7
+ args = def[:args ]
8
+
9
+ constargs = Array {Bool} (undef, length (args))
10
+ for (i, arg) in enumerate (args)
26
11
if isexpr (arg, :macrocall )
27
12
if arg. args[1 ] === Symbol (" @Const" )
28
- # args[2] is a LineInfo node
29
- push! (args, (arg. args[3 ], true ))
13
+ # arg.args[2] is a LineInfo node
14
+ args[i] = arg. args[3 ] # strip @Const
15
+ constargs[i] = true
30
16
continue
31
17
end
32
18
end
33
- push! (args, (arg, false ))
19
+ constargs[i] = false
34
20
end
35
21
36
- arglist = map (a-> a[1 ], args)
37
-
38
22
# create two functions
39
23
# 1. GPU function
40
24
# 2. CPU function with work-group loops inserted
41
- gpu_name = Symbol (:gpu_ , name)
42
- cpu_name = Symbol (:cpu_ , name)
25
+ #
26
+ # Without the deepcopy we might accidentially modify expr shared between CPU and GPU
27
+ def_cpu = deepcopy (def)
28
+ def_gpu = deepcopy (def)
43
29
44
- gpu_decl = Expr ( :call , gpu_name, arglist ... )
45
- cpu_decl = Expr ( :call , cpu_name, arglist ... )
30
+ def_cpu[ :name ] = cpu_name = Symbol ( :cpu_ , name )
31
+ def_gpu[ :name ] = gpu_name = Symbol ( :gpu_ , name )
46
32
47
- if iswhere
48
- gpu_decl = Expr (:where , gpu_decl, whereargs... )
49
- cpu_decl = Expr (:where , cpu_decl, whereargs... )
50
- end
33
+ transform_cpu! (def_cpu, constargs)
34
+ transform_gpu! (def_gpu, constargs)
51
35
52
- # Without the deepcopy we might accidentially modify expr shared between CPU and GPU
53
- gpu_body = transform_gpu (deepcopy (body), args)
54
- gpu_function = Expr (:function , gpu_decl, gpu_body)
55
-
56
- cpu_body = transform_cpu (deepcopy (body), args)
57
- cpu_function = Expr (:function , cpu_decl, cpu_body)
36
+ cpu_function = combinedef (def_cpu)
37
+ gpu_function = combinedef (def_gpu)
58
38
59
39
# create constructor functions
60
40
constructors = quote
61
- $ name (dev:: $Device ) = $ name (dev, $ DynamicSize (), $ DynamicSize ())
62
- $ name (dev:: $Device , size) = $ name (dev, $ StaticSize (size), $ DynamicSize ())
63
- $ name (dev:: $Device , size, range) = $ name (dev, $ StaticSize (size), $ StaticSize (range))
64
- function $name (:: Device , :: S , :: NDRange ) where {Device<: $CPU , S<: $_Size , NDRange<: $_Size }
65
- return $ Kernel {Device, S, NDRange, typeof($cpu_name)} ($ cpu_name)
66
- end
67
- function $name (:: Device , :: S , :: NDRange ) where {Device<: $GPU , S<: $_Size , NDRange<: $_Size }
68
- return $ Kernel {Device, S, NDRange, typeof($gpu_name)} ($ gpu_name)
41
+ if ! @isdefined ($ name)
42
+ $ name (dev:: $Device ) = $ name (dev, $ DynamicSize (), $ DynamicSize ())
43
+ $ name (dev:: $Device , size) = $ name (dev, $ StaticSize (size), $ DynamicSize ())
44
+ $ name (dev:: $Device , size, range) = $ name (dev, $ StaticSize (size), $ StaticSize (range))
45
+ function $name (:: Device , :: S , :: NDRange ) where {Device<: $CPU , S<: $_Size , NDRange<: $_Size }
46
+ return $ Kernel {Device, S, NDRange, typeof($cpu_name)} ($ cpu_name)
47
+ end
48
+ function $name (:: Device , :: S , :: NDRange ) where {Device<: $GPU , S<: $_Size , NDRange<: $_Size }
49
+ return $ Kernel {Device, S, NDRange, typeof($gpu_name)} ($ gpu_name)
50
+ end
69
51
end
70
52
end
71
53
72
54
return Expr (:block , esc (cpu_function), esc (gpu_function), esc (constructors))
73
55
end
74
56
75
- # Transform function for GPU execution
76
- # This involves marking constant arguments
77
- function transform_gpu (expr, args )
57
+ # The easy case, transform the function for GPU execution
58
+ # - mark constant arguments by applying `constify`.
59
+ function transform_gpu! (def, constargs )
78
60
new_stmts = Expr[]
79
- for (arg, isconst ) in args
80
- if isconst
61
+ for (i, arg ) in enumerate (def[ : args])
62
+ if constargs[i]
81
63
push! (new_stmts, :($ arg = $ constify ($ arg)))
82
64
end
83
65
end
84
- return quote
66
+
67
+ def[:body ] = quote
85
68
if $ __validindex ()
86
69
$ (new_stmts... )
87
- $ expr
70
+ $ (def[ :body ])
88
71
end
89
72
return nothing
90
73
end
91
74
end
92
75
76
+ # The hard case, transform the function for CPU execution
77
+ # - mark constant arguments by applying `constify`.
78
+ # - insert aliasscope markers
79
+ # - insert implied loop bodys
80
+ # - handle indicies
81
+ # - hoist workgroup definitions
82
+ # - hoist uniform variables
83
+ function transform_cpu! (def, constargs)
84
+ new_stmts = Expr[]
85
+ for (i, arg) in enumerate (def[:args ])
86
+ if constargs[i]
87
+ push! (new_stmts, :($ arg = $ constify ($ arg)))
88
+ end
89
+ end
90
+
91
+ body = MacroTools. flatten (def[:body ])
92
+ loops = split (body)
93
+
94
+ push! (new_stmts, Expr (:aliasscope ))
95
+ for loop in loops
96
+ push! (new_stmts, emit (loop))
97
+ end
98
+ push! (new_stmts, Expr (:popaliasscope ))
99
+ push! (new_stmts, :(return nothing ))
100
+ def[:body ] = Expr (:block , new_stmts... )
101
+ end
102
+
103
+ struct WorkgroupLoop
104
+ indicies :: Vector{Any}
105
+ stmts :: Vector{Any}
106
+ allocations :: Vector{Any}
107
+ end
108
+
109
+
93
110
function split (stmts)
94
111
# 1. Split the code into blocks separated by `@synchronize`
95
- # 2. Aggregate the index and allocation expressions seen at the sync points
112
+ # 2. Aggregate `@index` expressions
113
+ # 3. Hoist allocations
114
+ # 4. Hoist uniforms
115
+
116
+ current = Any[]
96
117
indicies = Any[]
97
118
allocations = Any[]
98
- loops = Any[]
99
- current = Any[]
100
119
120
+ loops = WorkgroupLoop[]
101
121
for stmt in stmts. args
102
122
if isexpr (stmt, :macrocall ) && stmt. args[1 ] === Symbol (" @synchronize" )
103
- push! (loops, (current, deepcopy (indicies), allocations))
123
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations)
124
+ push! (loops, loop)
104
125
allocations = Any[]
105
126
current = Any[]
106
127
continue
@@ -111,64 +132,39 @@ function split(stmts)
111
132
if callee === Symbol (" @index" )
112
133
push! (indicies, stmt)
113
134
continue
114
- elseif callee === Symbol (" @localmem" ) || callee === Symbol (" @private" )
135
+ elseif callee === Symbol (" @localmem" ) ||
136
+ callee === Symbol (" @private" ) ||
137
+ callee === Symbol (" @uniform" )
115
138
push! (allocations, stmt)
116
139
continue
117
140
end
118
141
end
119
142
end
120
143
121
- if isexpr (stmt, :block )
122
- # XXX : What about loops, let, ...
123
- @warn " Encountered a block at the top-level unclear semantics"
124
- end
125
144
push! (current, stmt)
126
145
end
127
146
128
147
# everything since the last `@synchronize`
129
148
if ! isempty (current)
130
- push! (loops, (current, copy (indicies), allocations))
149
+ push! (loops, WorkgroupLoop ( deepcopy (indicies), current , allocations))
131
150
end
132
151
return loops
133
152
end
134
153
135
- function generate_cpu_code (loops)
136
- # Create loops
137
- new_stmts = Any[]
138
- for (body, indexes, allocations) in loops
139
- idx = gensym (:I )
154
+ function emit (loop)
155
+ idx = gensym (:I )
156
+ for stmt in loop. indicies
140
157
# splice index into the i = @index(Cartesian, $idx)
141
- for stmt in indexes
142
- @assert stmt. head === :(= )
143
- rhs = stmt. args[2 ]
144
- push! (rhs. args, idx)
145
- end
146
- loop = quote
147
- $ (allocations... )
148
- for $ idx in $ __workitems_iterspace ()
149
- $ __validindex ($ idx) || continue
150
- $ (indexes... )
151
- $ (body... )
152
- end
153
- end
154
- push! (new_stmts, loop)
158
+ @assert stmt. head === :(= )
159
+ rhs = stmt. args[2 ]
160
+ push! (rhs. args, idx)
155
161
end
156
- return Expr (:block , new_stmts... )
157
- end
158
-
159
- function transform_cpu (stmts, args)
160
- new_stmts = Expr[]
161
- for (arg, isconst) in args
162
- if isconst
163
- push! (new_stmts, :($ arg = $ constify ($ arg)))
162
+ quote
163
+ $ (loop. allocations... )
164
+ for $ idx in $ __workitems_iterspace ()
165
+ $ __validindex ($ idx) || continue
166
+ $ (loop. indicies... )
167
+ $ (loop. stmts... )
164
168
end
165
169
end
166
- loops = split (stmts)
167
- body = generate_cpu_code (loops)
168
-
169
- push! (new_stmts, Expr (:aliasscope ))
170
- push! (new_stmts, body)
171
- push! (new_stmts, Expr (:popaliasscope ))
172
- push! (new_stmts, :(return nothing ))
173
- return Expr (:block , new_stmts... )
174
170
end
0 commit comments