@@ -109,7 +109,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
109
109
entry_fn = LLVM. name (entry)
110
110
111
111
if job. source. kernel
112
- add_address_spaces! (mod, entry)
112
+ add_address_spaces! (job, mod, entry)
113
113
end
114
114
115
115
return functions (mod)[entry_fn]
@@ -147,103 +147,95 @@ end
147
147
148
148
# generic pointer removal
149
149
#
150
- # every pointer argument (e.g. byref objs) to a kernel needs an address space attached.
151
- function add_address_spaces! (mod:: LLVM.Module , f:: LLVM.Function )
150
+ # every pointer argument (i.e. byref objs) to a kernel needs an address space attached.
151
+ # this pass rewrites pointers to reference arguments to be located in address space 1.
152
+ #
153
+ # NOTE: this pass only rewrites byref objs, not plain pointers being passed; the user is
154
+ # responsible for making sure these pointers have an address space attached (using LLVMPtr).
155
+ #
156
+ # NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
157
+ # be executed after optimization (where Julia's address spaces are stripped). If we ever
158
+ # want to execute it earlier, adapt remapType to rewrite all pointer types.
159
+ function add_address_spaces! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module , f:: LLVM.Function )
152
160
ctx = context (mod)
153
161
ft = eltype (llvmtype (f))
154
162
163
+ # find the byref parameters
164
+ byref = BitVector (undef, length (parameters (ft)))
165
+ let args = classify_arguments (job, ft)
166
+ filter! (args) do arg
167
+ arg. cc != GHOST
168
+ end
169
+ for arg in args
170
+ byref[arg. codegen. i] = (arg. cc == BITS_REF)
171
+ end
172
+ end
173
+
155
174
function remapType (src)
156
- # TODO : recurse in structs
175
+ # TODO : cache?
176
+ # TODO : recurse in structs?
177
+ # TODO : when wrapping non-AS1 pointers, shouldn't the parent object use the same AS?
157
178
dst = if src isa LLVM. PointerType && addrspace (src) == 0
158
179
LLVM. PointerType (remapType (eltype (src)), #= device=# 1 )
159
180
else
160
181
src
161
182
end
162
- # TODO : cache
163
183
return dst
164
184
end
165
185
166
186
# generate the new function type & definition
167
- new_types = LLVMType[remapType (typ) for typ in parameters (ft)]
187
+ new_types = LLVMType[]
188
+ for (i, param) in enumerate (parameters (ft))
189
+ if byref[i]
190
+ push! (new_types, remapType (param:: LLVM.PointerType ))
191
+ else
192
+ push! (new_types, param)
193
+ end
194
+ end
168
195
new_ft = LLVM. FunctionType (LLVM. return_type (ft), new_types)
169
196
new_f = LLVM. Function (mod, " " , new_ft)
170
197
linkage! (new_f, linkage (f))
171
198
for (arg, new_arg) in zip (parameters (f), parameters (new_f))
172
199
LLVM. name! (new_arg, LLVM. name (arg))
173
200
end
174
201
175
- # map the parameters
176
- value_map = Dict {LLVM.Value, LLVM.Value} (
177
- param => new_param for (param, new_param) in zip (parameters (f), parameters (new_f))
178
- )
179
- value_map[f] = new_f
180
-
181
- # before D96531 (part of LLVM 13), clone_into! wants to duplicate debug metadata
182
- # when the functions are part of the same module. that is invalid, because it
183
- # results in desynchronized debug intrinsics (GPUCompiler#284), so remove those.
184
- if LLVM. version () < v " 13"
185
- removals = LLVM. Instruction[]
186
- for bb in blocks (f), inst in instructions (bb)
187
- if inst isa LLVM. CallInst && LLVM. name (called_value (inst)) == " llvm.dbg.declare"
188
- push! (removals, inst)
202
+ # we cannot simply remap the function arguments, because that will not propagate the
203
+ # address space changes across, e.g., bitcasts (the dest would still be in AS 0).
204
+ # using a type remapper on the other hand changes too much, including unrelated insts.
205
+ # so instead, we load each byref argument, store the value in a stack slot, and pass
206
+ # that slot to the original IR, which assumes pointers without address spaces.
207
+ new_args = LLVM. Value[]
208
+ @dispose builder= Builder (ctx) begin
209
+ entry = BasicBlock (new_f, " conversion" ; ctx)
210
+ position! (builder, entry)
211
+
212
+ # perform argument conversions
213
+ for (i, param) in enumerate (parameters (ft))
214
+ if byref[i]
215
+ # load the argument in a stack slot
216
+ val = load! (builder, parameters (new_f)[i])
217
+ ptr = alloca! (builder, eltype (param))
218
+ store! (builder, val, ptr)
219
+ push! (new_args, ptr)
220
+ else
221
+ push! (new_args, parameters (new_f)[i])
222
+ end
223
+ for attr in collect (parameter_attributes (f, i))
224
+ push! (parameter_attributes (new_f, i), attr)
189
225
end
190
226
end
191
- for inst in removals
192
- @assert isempty (uses (inst))
193
- unsafe_delete! (LLVM. parent (inst), inst)
194
- end
195
- changes = LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges
196
- else
197
- changes = LLVM. API. LLVMCloneFunctionChangeTypeLocalChangesOnly
198
- end
199
227
200
- function type_mapper (typ)
201
- remapType (typ)
202
- end
203
-
204
- clone_into! (new_f, f; value_map, changes, type_mapper)
205
-
206
- # update calls to overloaded intrinsic, re-mangling their names
207
- # XXX : shouldn't clone_into! do this?
208
- LLVM. @dispose builder= Builder (ctx) begin
209
- for bb in blocks (new_f), inst in instructions (bb)
210
- if inst isa LLVM. CallBase
211
- callee_f = called_value (inst)
212
- LLVM. isintrinsic (callee_f) || continue
213
- intr = Intrinsic (callee_f)
214
- isoverloaded (intr) || continue
215
-
216
- # get an appropriately-overloaded intrinsic instantiation
217
- # NOTE: the overload types differs from the argument types
218
- intr_f = if intr == Intrinsic (" llvm.memcpy" )
219
- LLVM. Function (mod, intr, llvmtype .(arguments (inst)[1 : end - 1 ]))
220
- elseif intr == Intrinsic (" llvm.lifetime.start" ) ||
221
- intr == Intrinsic (" llvm.lifetime.end" )
222
- LLVM. Function (mod, intr, [llvmtype (arguments (inst)[end ])])
223
- else
224
- # TODO : use matchIntrinsicSignature to do this generically
225
- error (""" Unsupported intrinsic call:
226
- $inst .
227
- Please file an issue with at https://github.com/JuliaGPU/GPUCompiler.jl""" )
228
- end
228
+ # map the arguments
229
+ value_map = Dict {LLVM.Value, LLVM.Value} (
230
+ param => new_args[i] for (i,param) in enumerate (parameters (f))
231
+ )
229
232
230
- # create a call to the new intrinsic
231
- # TODO : wrap setCalledFunction instead of using an IRBuilder
232
- position! (builder, inst)
233
- new_inst = if inst isa LLVM. CallInst
234
- call! (builder, intr_f, arguments (inst), operand_bundles (inst))
235
- else
236
- # TODO : invoke and callbr
237
- error (" Rewrite of $(typeof (inst)) -based calls is not implemented: $inst " )
238
- end
239
- callconv! (new_inst, callconv (inst))
233
+ value_map[f] = new_f
234
+ clone_into! (new_f, f; value_map,
235
+ changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
240
236
241
- # replace the old call
242
- replace_uses! (inst, new_inst)
243
- @assert isempty (uses (inst))
244
- unsafe_delete! (LLVM. parent (inst), inst)
245
- end
246
- end
237
+ # fall through
238
+ br! (builder, blocks (new_f)[2 ])
247
239
end
248
240
249
241
# remove the old function
@@ -253,6 +245,16 @@ function add_address_spaces!(mod::LLVM.Module, f::LLVM.Function)
253
245
unsafe_delete! (mod, f)
254
246
LLVM. name! (new_f, fn)
255
247
248
+ # clean-up after this pass (which runs after optimization)
249
+ @dispose pm= ModulePassManager () begin
250
+ cfgsimplification! (pm)
251
+ scalar_repl_aggregates! (pm)
252
+ early_cse! (pm)
253
+ instruction_combining! (pm)
254
+
255
+ run! (pm, mod)
256
+ end
257
+
256
258
return new_f
257
259
end
258
260
0 commit comments