Skip to content

Commit 81c5ce7

Browse files
authored
Fixes for LLVM 17 (#584)
* Switch deferred codegen to use integer input. All consumers seem to rely on this. * Don't create typed pointer when not needed. * Fix pass manager for LICM. * Fix NVVM Reflection on LLVM 17.
1 parent e18cdd2 commit 81c5ce7

File tree

2 files changed

+42
-15
lines changed

2 files changed

+42
-15
lines changed

src/driver.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,9 @@ end
184184
# generated functions so use the current world counter, which may be too new
185185
# for the world we're compiling for.
186186

187-
pseudo_ptr = reinterpret(Ptr{Cvoid}, id)
188187
quote
189188
# TODO: add an edge to this method instance to support method redefinitions
190-
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $pseudo_ptr)
189+
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
191190
end
192191
end
193192

@@ -277,7 +276,10 @@ const __llvm_initialized = Ref(false)
277276
for call in worklist[dyn_job]
278277
@dispose builder=IRBuilder() begin
279278
position!(builder, call)
280-
fptr = if VERSION >= v"1.12.0-DEV.225"
279+
fptr = if LLVM.version() >= v"17"
280+
T_ptr = LLVM.PointerType()
281+
bitcast!(builder, dyn_entry, T_ptr)
282+
elseif VERSION >= v"1.12.0-DEV.225"
281283
T_ptr = LLVM.PointerType(LLVM.Int8Type())
282284
bitcast!(builder, dyn_entry, T_ptr)
283285
else

src/ptx.jl

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
167167
# but Julia's pass sequence only invokes the simple unroller.
168168
add!(fpm, LoopUnrollPass(LoopUnrollOptions(; job.config.opt_level)))
169169
add!(fpm, InstCombinePass()) # clean-up redundancy
170-
add!(fpm, NewPMLoopPassManager) do lpm
170+
add!(fpm, NewPMLoopPassManager, #=UseMemorySSA=#true) do lpm
171171
add!(lpm, LICMPass()) # the inner runtime check might be
172172
# outer loop invariant
173173
end
@@ -439,20 +439,42 @@ function nvvm_reflect!(fun::LLVM.Function)
439439
for use in uses(reflect_function)
440440
call = user(use)
441441
isa(call, LLVM.CallInst) || continue
442-
length(operands(call)) == 2 || error("Wrong number of operands to __nvvm_reflect function")
442+
if length(operands(call)) != 2
443+
@error """Unrecognized format of __nvvm_reflect call:
444+
$(string(call))
445+
Wrong number of operands: expected 2, got $(length(operands(call)))."""
446+
continue
447+
end
443448

444449
# decode the string argument
445-
str = operands(call)[1]
446-
isa(str, LLVM.ConstantExpr) || error("Format of __nvvm__reflect function not recognized")
447-
sym = operands(str)[1]
448-
if isa(sym, LLVM.ConstantExpr) && opcode(sym) == LLVM.API.LLVMGetElementPtr
449-
# CUDA 11.0 or below
450-
sym = operands(sym)[1]
450+
if LLVM.version() >= v"17"
451+
sym = operands(call)[1]
452+
else
453+
str = operands(call)[1]
454+
if !isa(str, LLVM.ConstantExpr) || opcode(str) != LLVM.API.LLVMGetElementPtr
455+
@safe_error """Unrecognized format of __nvvm_reflect call:
456+
$(string(call))
457+
Operand should be a GEP instruction, got a $(typeof(str)). Please file an issue."""
458+
continue
459+
end
460+
sym = operands(str)[1]
461+
if isa(sym, LLVM.ConstantExpr) && opcode(sym) == LLVM.API.LLVMGetElementPtr
462+
# CUDA 11.0 or below
463+
sym = operands(sym)[1]
464+
end
465+
end
466+
if !isa(sym, LLVM.GlobalVariable)
467+
@safe_error """Unrecognized format of __nvvm_reflect call:
468+
$(string(call))
469+
Operand should be a global variable, got a $(typeof(sym)). Please file an issue."""
470+
continue
451471
end
452-
isa(sym, LLVM.GlobalVariable) || error("Format of __nvvm__reflect function not recognized")
453472
sym_op = operands(sym)[1]
454-
isa(sym_op, LLVM.ConstantArray) || isa(sym_op, LLVM.ConstantDataArray) ||
455-
error("Format of __nvvm__reflect function not recognized")
473+
if !isa(sym_op, LLVM.ConstantArray) && !isa(sym_op, LLVM.ConstantDataArray)
474+
@safe_error """Unrecognized format of __nvvm_reflect call:
475+
$(string(call))
476+
Operand should be a constant array, got a $(typeof(sym_op)). Please file an issue."""
477+
end
456478
chars = convert.(Ref(UInt8), collect(sym_op))
457479
reflect_arg = String(chars[1:end-1])
458480

@@ -477,7 +499,10 @@ function nvvm_reflect!(fun::LLVM.Function)
477499
elseif reflect_arg == "__CUDA_ARCH"
478500
ConstantInt(reflect_typ, job.config.target.cap.major*100 + job.config.target.cap.minor*10)
479501
else
480-
@warn "Unknown __nvvm_reflect argument: $reflect_arg. Please file an issue."
502+
@safe_error """Unrecognized format of __nvvm_reflect call:
503+
$(string(call))
504+
Unknown argument $reflect_arg. Please file an issue."""
505+
continue
481506
end
482507

483508
replace_uses!(call, reflect_val)

0 commit comments

Comments
 (0)