Skip to content

Commit 699b393

Browse files
authored
Metal: re-mangle memcpy intrinsics when changing pointer address spaces. (#337)
1 parent b3df434 commit 699b393

File tree

2 files changed

+62
-4
lines changed

2 files changed

+62
-4
lines changed

src/metal.jl

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
109109
entry_fn = LLVM.name(entry)
110110

111111
if job.source.kernel
112-
add_address_spaces!(job, mod, entry)
112+
add_address_spaces!(mod, entry)
113113
end
114114

115115
return functions(mod)[entry_fn]
@@ -148,10 +148,9 @@ end
148148
# generic pointer removal
149149
#
150150
# every pointer argument (e.g. byref objs) to a kernel needs an address space attached.
151-
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
151+
function add_address_spaces!(mod::LLVM.Module, f::LLVM.Function)
152152
ctx = context(mod)
153153
ft = eltype(llvmtype(f))
154-
@compiler_assert LLVM.return_type(ft) == LLVM.VoidType(ctx) job
155154

156155
function remapType(src)
157156
# TODO: recurse in structs
@@ -204,6 +203,45 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
204203

205204
clone_into!(new_f, f; value_map, changes, type_mapper)
206205

206+
# update calls to overloaded intrinsics, re-mangling their names
207+
# XXX: shouldn't clone_into! do this?
208+
LLVM.@dispose builder=Builder(ctx) begin
209+
for bb in blocks(new_f), inst in instructions(bb)
210+
if inst isa LLVM.CallBase
211+
callee_f = called_value(inst)
212+
LLVM.isintrinsic(callee_f) || continue
213+
intr = Intrinsic(callee_f)
214+
isoverloaded(intr) || continue
215+
216+
# get an appropriately-overloaded intrinsic instantiation
217+
# XXX: apparently it differs per intrinsic which arguments to take into
218+
# consideration when generating an overload? for example, with memcpy
219+
# the trailing i1 argument is not included in the overloaded name.
220+
intr_f = if intr == Intrinsic("llvm.memcpy")
221+
LLVM.Function(mod, intr, llvmtype.(arguments(inst)[1:end-1]))
222+
else
223+
error("Unsupported intrinsic; please file an issue.")
224+
end
225+
226+
# create a call to the new intrinsic
227+
# TODO: wrap setCalledFunction instead of using an IRBuilder
228+
position!(builder, inst)
229+
new_inst = if inst isa LLVM.CallInst
230+
call!(builder, intr_f, arguments(inst), operand_bundles(inst))
231+
else
232+
# TODO: invoke and callbr
233+
error("Rewrite of $(typeof(inst))-based calls is not implemented: $inst")
234+
end
235+
callconv!(new_inst, callconv(inst))
236+
237+
# replace the old call
238+
replace_uses!(inst, new_inst)
239+
@assert isempty(uses(inst))
240+
unsafe_delete!(LLVM.parent(inst), inst)
241+
end
242+
end
243+
end
244+
207245
# remove the old function
208246
fn = LLVM.name(f)
209247
@assert isempty(uses(f))

test/metal.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Metal_LLVM_Tools_jll
1+
using Metal_LLVM_Tools_jll, LLVM
22

33
include("definitions/metal.jl")
44

@@ -31,6 +31,26 @@ end
3131
@test occursin(r"@.*julia.*kernel.*\(({ i64 }|\[1 x i64\]) addrspace\(1\)\*", ir)
3232
end
3333

34+
@testset "byref aggregates with memcpy" begin
35+
ir = """
36+
declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1 immarg)
37+
38+
define void @kernel(i8* %0, i8* %1) {
39+
entry:
40+
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %0, i8* %1, i64 0, i1 false)
41+
ret void
42+
}
43+
"""
44+
Context() do ctx
45+
mod = parse(LLVM.Module, ir; ctx)
46+
f = functions(mod)["kernel"]
47+
GPUCompiler.add_address_spaces!(mod, f)
48+
LLVM.verify(mod)
49+
end
50+
return
51+
52+
end
53+
3454
@testset "byref primitives" begin
3555
kernel(x) = return
3656

0 commit comments

Comments
 (0)