Skip to content

Commit 56033ad

Browse files
authored
Metal: Change approach to adding ptr arg ASs. (#340)
It turns out that using module cloning doesn't work here. If we change the address space of the arguments and simply remap those arguments using a value mapper, the change is lost after, e.g., a bitcast from the argument pointer to a different type: the destination type won't carry an address space, resulting in invalid operations. We were working around this limitation with a type remapper that rewrote all pointer types to contain an address space, but that is invalid: stack space (i.e. `alloca`s) must still reside in address space 0. So instead we take the approach we already use for byval lowering and store the arguments in a temporary stack slot. We still change the function's pointer arguments to include an address space, but after loading the value through the argument and storing it in a stack slot, that slot gives us a valid pointer without an address space. We can then reuse the existing IR with that pointer, without having to rewrite it. This simplifies the pass a lot, because we no longer have to worry about rewriting intrinsics such as memcpy.
1 parent 40b663a commit 56033ad

File tree

2 files changed

+75
-93
lines changed

2 files changed

+75
-93
lines changed

src/metal.jl

Lines changed: 75 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
109109
entry_fn = LLVM.name(entry)
110110

111111
if job.source.kernel
112-
add_address_spaces!(mod, entry)
112+
add_address_spaces!(job, mod, entry)
113113
end
114114

115115
return functions(mod)[entry_fn]
@@ -147,103 +147,95 @@ end
147147

148148
# generic pointer removal
149149
#
150-
# every pointer argument (e.g. byref objs) to a kernel needs an address space attached.
151-
function add_address_spaces!(mod::LLVM.Module, f::LLVM.Function)
150+
# every pointer argument (i.e. byref objs) to a kernel needs an address space attached.
151+
# this pass rewrites pointers to reference arguments to be located in address space 1.
152+
#
153+
# NOTE: this pass only rewrites byref objs, not plain pointers being passed; the user is
154+
# responsible for making sure these pointers have an address space attached (using LLVMPtr).
155+
#
156+
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
157+
# be executed after optimization (where Julia's address spaces are stripped). If we ever
158+
# want to execute it earlier, adapt remapType to rewrite all pointer types.
159+
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
152160
ctx = context(mod)
153161
ft = eltype(llvmtype(f))
154162

163+
# find the byref parameters
164+
byref = BitVector(undef, length(parameters(ft)))
165+
let args = classify_arguments(job, ft)
166+
filter!(args) do arg
167+
arg.cc != GHOST
168+
end
169+
for arg in args
170+
byref[arg.codegen.i] = (arg.cc == BITS_REF)
171+
end
172+
end
173+
155174
function remapType(src)
156-
# TODO: recurse in structs
175+
# TODO: cache?
176+
# TODO: recurse in structs?
177+
# TODO: when wrapping non-AS1 pointers, shouldn't the parent object use the same AS?
157178
dst = if src isa LLVM.PointerType && addrspace(src) == 0
158179
LLVM.PointerType(remapType(eltype(src)), #=device=# 1)
159180
else
160181
src
161182
end
162-
# TODO: cache
163183
return dst
164184
end
165185

166186
# generate the new function type & definition
167-
new_types = LLVMType[remapType(typ) for typ in parameters(ft)]
187+
new_types = LLVMType[]
188+
for (i, param) in enumerate(parameters(ft))
189+
if byref[i]
190+
push!(new_types, remapType(param::LLVM.PointerType))
191+
else
192+
push!(new_types, param)
193+
end
194+
end
168195
new_ft = LLVM.FunctionType(LLVM.return_type(ft), new_types)
169196
new_f = LLVM.Function(mod, "", new_ft)
170197
linkage!(new_f, linkage(f))
171198
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
172199
LLVM.name!(new_arg, LLVM.name(arg))
173200
end
174201

175-
# map the parameters
176-
value_map = Dict{LLVM.Value, LLVM.Value}(
177-
param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
178-
)
179-
value_map[f] = new_f
180-
181-
# before D96531 (part of LLVM 13), clone_into! wants to duplicate debug metadata
182-
# when the functions are part of the same module. that is invalid, because it
183-
# results in desynchronized debug intrinsics (GPUCompiler#284), so remove those.
184-
if LLVM.version() < v"13"
185-
removals = LLVM.Instruction[]
186-
for bb in blocks(f), inst in instructions(bb)
187-
if inst isa LLVM.CallInst && LLVM.name(called_value(inst)) == "llvm.dbg.declare"
188-
push!(removals, inst)
202+
# we cannot simply remap the function arguments, because that will not propagate the
203+
# address space changes across, e.g., bitcasts (the dest would still be in AS 0).
204+
# using a type remapper on the other hand changes too much, including unrelated insts.
205+
# so instead, we load the arguments in stack slots and dereference them so that we can
206+
# keep on using the original IR that assumed pointers without address spaces
207+
new_args = LLVM.Value[]
208+
@dispose builder=Builder(ctx) begin
209+
entry = BasicBlock(new_f, "conversion"; ctx)
210+
position!(builder, entry)
211+
212+
# perform argument conversions
213+
for (i, param) in enumerate(parameters(ft))
214+
if byref[i]
215+
# load the argument in a stack slot
216+
val = load!(builder, parameters(new_f)[i])
217+
ptr = alloca!(builder, eltype(param))
218+
store!(builder, val, ptr)
219+
push!(new_args, ptr)
220+
else
221+
push!(new_args, parameters(new_f)[i])
222+
end
223+
for attr in collect(parameter_attributes(f, i))
224+
push!(parameter_attributes(new_f, i), attr)
189225
end
190226
end
191-
for inst in removals
192-
@assert isempty(uses(inst))
193-
unsafe_delete!(LLVM.parent(inst), inst)
194-
end
195-
changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
196-
else
197-
changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly
198-
end
199227

200-
function type_mapper(typ)
201-
remapType(typ)
202-
end
203-
204-
clone_into!(new_f, f; value_map, changes, type_mapper)
205-
206-
# update calls to overloaded intrinsic, re-mangling their names
207-
# XXX: shouldn't clone_into! do this?
208-
LLVM.@dispose builder=Builder(ctx) begin
209-
for bb in blocks(new_f), inst in instructions(bb)
210-
if inst isa LLVM.CallBase
211-
callee_f = called_value(inst)
212-
LLVM.isintrinsic(callee_f) || continue
213-
intr = Intrinsic(callee_f)
214-
isoverloaded(intr) || continue
215-
216-
# get an appropriately-overloaded intrinsic instantiation
217-
# NOTE: the overload types differs from the argument types
218-
intr_f = if intr == Intrinsic("llvm.memcpy")
219-
LLVM.Function(mod, intr, llvmtype.(arguments(inst)[1:end-1]))
220-
elseif intr == Intrinsic("llvm.lifetime.start") ||
221-
intr == Intrinsic("llvm.lifetime.end")
222-
LLVM.Function(mod, intr, [llvmtype(arguments(inst)[end])])
223-
else
224-
# TODO: use matchIntrinsicSignature to do this generically
225-
error("""Unsupported intrinsic call:
226-
$inst.
227-
Please file an issue with at https://github.com/JuliaGPU/GPUCompiler.jl""")
228-
end
228+
# map the arguments
229+
value_map = Dict{LLVM.Value, LLVM.Value}(
230+
param => new_args[i] for (i,param) in enumerate(parameters(f))
231+
)
229232

230-
# create a call to the new intrinsic
231-
# TODO: wrap setCalledFunction instead of using an IRBuilder
232-
position!(builder, inst)
233-
new_inst = if inst isa LLVM.CallInst
234-
call!(builder, intr_f, arguments(inst), operand_bundles(inst))
235-
else
236-
# TODO: invoke and callbr
237-
error("Rewrite of $(typeof(inst))-based calls is not implemented: $inst")
238-
end
239-
callconv!(new_inst, callconv(inst))
233+
value_map[f] = new_f
234+
clone_into!(new_f, f; value_map,
235+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
240236

241-
# replace the old call
242-
replace_uses!(inst, new_inst)
243-
@assert isempty(uses(inst))
244-
unsafe_delete!(LLVM.parent(inst), inst)
245-
end
246-
end
237+
# fall through
238+
br!(builder, blocks(new_f)[2])
247239
end
248240

249241
# remove the old function
@@ -253,6 +245,16 @@ function add_address_spaces!(mod::LLVM.Module, f::LLVM.Function)
253245
unsafe_delete!(mod, f)
254246
LLVM.name!(new_f, fn)
255247

248+
# clean-up after this pass (which runs after optimization)
249+
@dispose pm=ModulePassManager() begin
250+
cfgsimplification!(pm)
251+
scalar_repl_aggregates!(pm)
252+
early_cse!(pm)
253+
instruction_combining!(pm)
254+
255+
run!(pm, mod)
256+
end
257+
256258
return new_f
257259
end
258260

test/metal.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,6 @@ end
3131
@test occursin(r"@.*julia.*kernel.*\(({ i64 }|\[1 x i64\]) addrspace\(1\)\*", ir)
3232
end
3333

34-
@testset "byref aggregates with memcpy" begin
35-
ir = """
36-
declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1 immarg)
37-
38-
define void @kernel(i8* %0, i8* %1) {
39-
entry:
40-
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %0, i8* %1, i64 0, i1 false)
41-
ret void
42-
}
43-
"""
44-
Context() do ctx
45-
mod = parse(LLVM.Module, ir; ctx)
46-
f = functions(mod)["kernel"]
47-
GPUCompiler.add_address_spaces!(mod, f)
48-
LLVM.verify(mod)
49-
end
50-
return
51-
52-
end
53-
5434
@testset "byref primitives" begin
5535
kernel(x) = return
5636

0 commit comments

Comments
 (0)