
Commit 26edcbd

Streamline creation of GPUCompiler objects.
1 parent: 557def9

7 files changed, with 316 additions and 309 deletions.

7 files changed

+316
-309
lines changed

src/CUDA.jl

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ include("../lib/cupti/CUPTI.jl")
 export CUPTI

 # compiler implementation
-include("compiler/gpucompiler.jl")
+include("compiler/compilation.jl")
 include("compiler/execution.jl")
 include("compiler/exceptions.jl")
 include("compiler/reflection.jl")

src/compiler/compilation.jl

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
## gpucompiler interface implementation

struct CUDACompilerParams <: AbstractCompilerParams end
const CUDACompilerConfig = CompilerConfig{PTXCompilerTarget, CUDACompilerParams}
const CUDACompilerJob = CompilerJob{PTXCompilerTarget, CUDACompilerParams}

GPUCompiler.runtime_module(@nospecialize(job::CUDACompilerJob)) = CUDA

# filter out functions from libdevice and cudadevrt
GPUCompiler.isintrinsic(@nospecialize(job::CUDACompilerJob), fn::String) =
    invoke(GPUCompiler.isintrinsic,
           Tuple{CompilerJob{PTXCompilerTarget}, typeof(fn)},
           job, fn) ||
    fn == "__nvvm_reflect" || startswith(fn, "cuda")

function GPUCompiler.link_libraries!(@nospecialize(job::CUDACompilerJob), mod::LLVM.Module,
                                     undefined_fns::Vector{String})
    invoke(GPUCompiler.link_libraries!,
           Tuple{CompilerJob{PTXCompilerTarget}, typeof(mod), typeof(undefined_fns)},
           job, mod, undefined_fns)
    link_libdevice!(mod, job.config.target.cap, undefined_fns)
end

GPUCompiler.method_table(@nospecialize(job::CUDACompilerJob)) = method_table

GPUCompiler.kernel_state_type(job::CUDACompilerJob) = KernelState


## compiler implementation (cache, configure, compile, and link)

# cache of compilation caches, per context
const _compiler_caches = Dict{CuContext, Dict{UInt, Any}}()
function compiler_cache(ctx::CuContext)
    cache = get(_compiler_caches, ctx, nothing)
    if cache === nothing
        cache = Dict{UInt, Any}()
        _compiler_caches[ctx] = cache
    end
    return cache
end

# cache of compiler configurations, per device (but additionally configurable via kwargs)
const _toolchain = Ref{Any}()
const _compiler_configs = Dict{UInt, CUDACompilerConfig}()
function compiler_config(dev; kwargs...)
    h = hash(dev, hash(kwargs))
    config = get(_compiler_configs, h, nothing)
    if config === nothing
        config = _compiler_config(dev; kwargs...)
        _compiler_configs[h] = config
    end
    return config
end
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...)
    # determine the toolchain (cached, because this is slow)
    if !isassigned(_toolchain)
        _toolchain[] = supported_toolchain()
    end
    toolchain = _toolchain[]::@NamedTuple{cap::Vector{VersionNumber}, ptx::Vector{VersionNumber}}

    # select the highest capability that is supported by both the toolchain and the device
    caps = filter(toolchain_cap -> toolchain_cap <= capability(dev), toolchain.cap)
    # NOTE: `CUDA.name` must be qualified, as the bare `name` is shadowed by the kwarg
    isempty(caps) &&
        error("Your $(CUDA.name(dev)) GPU with capability v$(capability(dev)) is not supported by the available toolchain")
    cap = maximum(caps)

    # select the PTX ISA we assume to be available
    # (we actually only need 6.2, but NVPTX doesn't support that)
    ptx = v"6.3"

    # we need to take care when emitting LLVM instructions like `unreachable`, which
    # may result in thread-divergent control flow that older `ptxas` doesn't like.
    # see e.g. JuliaGPU/CUDAnative.jl#4
    unreachable = true
    if cap < v"7" || runtime_version() < v"11.3"
        unreachable = false
    end

    # there have been issues with emitting PTX `exit` instead of `trap` as well,
    # see e.g. JuliaGPU/CUDA.jl#431 and NVIDIA bug #3231266 (but since switching
    # to the toolkit's `ptxas` that specific machine/GPU now _requires_ `exit`...)
    exitable = true
    if cap < v"7"
        exitable = false
    end

    # NVIDIA bug #3600554: ptxas segfaults with our debug info; fixed in CUDA 11.7
    debuginfo = runtime_version() >= v"11.7"

    # create GPUCompiler objects
    target = PTXCompilerTarget(; cap, ptx, debuginfo, unreachable, exitable, kwargs...)
    params = CUDACompilerParams()
    CompilerConfig(target, params; kernel, name, always_inline)
end

# compile to executable machine code
function compile(@nospecialize(job::CompilerJob))
    # TODO: on 1.9, this actually creates a context. cache those.
    JuliaContext() do ctx
        compile(job, ctx)
    end
end
function compile(@nospecialize(job::CompilerJob), ctx)
    # lower to PTX
    mi, mi_meta = GPUCompiler.emit_julia(job)
    ir, ir_meta = GPUCompiler.emit_llvm(job, mi; ctx)
    asm, asm_meta = GPUCompiler.emit_asm(job, ir; format=LLVM.API.LLVMAssemblyFile)

    # remove extraneous debug info on lower debug levels
    if Base.JLOptions().debug_level < 2
        # LLVM sets `.target debug` as soon as the debug emission kind isn't NoDebug. this
        # is unwanted, as the flag makes `ptxas` behave as if `--device-debug` were set.
        # ideally, we'd need something like LocTrackingOnly/EmitDebugInfo from D4234, but
        # that got removed in favor of NoDebug in D18808, seemingly breaking the use case
        # of only emitting `.loc` instructions...
        #
        # according to NVIDIA, "it is fine for PTX producers to produce debug info but not
        # set `.target debug`, and if `--device-debug` isn't passed, PTXAS will compile in
        # release mode".
        asm = replace(asm, r"(\.target .+), debug" => s"\1")
    end

    # check if we'll need the device runtime
    undefined_fs = filter(collect(functions(ir))) do f
        isdeclaration(f) && !LLVM.isintrinsic(f)
    end
    intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
                     "__nvvm_reflect" #= TODO: should have been optimized away =#]
    needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))

    # find externally-initialized global variables; we'll access those using CUDA APIs.
    external_gvars = filter(isextinit, collect(globals(ir))) .|> LLVM.name

    # prepare invocations of CUDA compiler tools
    ptxas_opts = String[]
    nvlink_opts = String[]
    ## debug flags
    if Base.JLOptions().debug_level == 1
        push!(ptxas_opts, "--generate-line-info")
    elseif Base.JLOptions().debug_level >= 2
        push!(ptxas_opts, "--device-debug")
        push!(nvlink_opts, "--debug")
    end
    ## relocatable device code
    if needs_cudadevrt
        push!(ptxas_opts, "--compile-only")
    end

    arch = "sm_$(job.config.target.cap.major)$(job.config.target.cap.minor)"

    # compile to machine code
    # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
    ptx_input = tempname(cleanup=false) * ".ptx"
    ptxas_output = tempname(cleanup=false) * ".cubin"
    write(ptx_input, asm)

    # we could use the driver's embedded JIT compiler, but that has several disadvantages:
    # 1. fixes and improvements are slower to arrive; by using `ptxas` we only need to
    #    upgrade the toolkit to get a newer compiler;
    # 2. version checking is simpler, as we otherwise need to use NVML to query the driver
    #    version, which is hard to correlate to PTX JIT improvements;
    # 3. if we want to be able to use newer minor releases of the CUDA toolkit on an
    #    older driver, we should use the newer compiler to ensure compatibility.
    append!(ptxas_opts, [
        "--verbose",
        "--gpu-name", arch,
        "--output-file", ptxas_output,
        ptx_input
    ])
    proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
    log = strip(log)
    if !success(proc)
        reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
                                       "ptxas exited with code $(proc.exitcode)"
        msg = "Failed to compile PTX code ($reason)"
        msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
        if !isempty(log)
            msg *= "\n" * log
        end
        msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
        error(msg)
    elseif !isempty(log)
        @debug "PTX compiler log:\n" * log
    end
    rm(ptx_input)

    # link device libraries, if necessary
    #
    # this requires relocatable device code, which prevents certain optimizations and
    # hurts performance. as such, we only do so when absolutely necessary.
    # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`;
    #       currently fails with `Ignoring -lto option because no LTO objects found`
    if needs_cudadevrt
        nvlink_output = tempname(cleanup=false) * ".cubin"
        append!(nvlink_opts, [
            "--verbose", "--extra-warnings",
            "--arch", arch,
            "--library-path", dirname(libcudadevrt),
            "--library", "cudadevrt",
            "--output-file", nvlink_output,
            ptxas_output
        ])
        proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`)
        log = strip(log)
        if !success(proc)
            reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
                                           "nvlink exited with code $(proc.exitcode)"
            msg = "Failed to link PTX code ($reason)"
            msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
            if !isempty(log)
                msg *= "\n" * log
            end
            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
            error(msg)
        elseif !isempty(log)
            @debug "PTX linker info log:\n" * log
        end
        rm(ptxas_output)

        image = read(nvlink_output)
        rm(nvlink_output)
    else
        image = read(ptxas_output)
        rm(ptxas_output)
    end

    return (image, entry=LLVM.name(ir_meta.entry), external_gvars)
end

# link into an executable kernel
function link(@nospecialize(job::CompilerJob), compiled)
    # load as an executable kernel object
    ctx = context()
    mod = CuModule(compiled.image)
    CuFunction(mod, compiled.entry)
end


## helpers

# run a binary and collect all relevant output
function run_and_collect(cmd)
    stdout = Pipe()
    proc = run(pipeline(ignorestatus(cmd); stdout, stderr=stdout), wait=false)
    close(stdout.in)

    reader = Threads.@spawn String(read(stdout))
    Base.wait(proc)
    log = strip(fetch(reader))

    return proc, log
end
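
For context, the pieces above compose into a compile-and-load pipeline. The following is a minimal sketch, not code from this commit; in particular, how the `source` argument (a method instance/function spec) is obtained depends on the GPUCompiler.jl version in use:

    # hypothetical driver code; `source` stands for a GPUCompiler.jl method instance/spec
    config = compiler_config(device(); always_inline=true)  # cached per device and kwargs
    job = CompilerJob(source, config)
    compiled = compile(job)       # PTX -> cubin via ptxas (and nvlink if cudadevrt is needed)
    kernel = link(job, compiled)  # load into the current context; returns a CuFunction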
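
The capability selection in `_compiler_config` simply takes the highest capability supported by both sides. For example (illustrative values, not from the commit), a toolchain supporting capabilities 6.0, 7.0, and 7.5 paired with a capability-7.0 device yields:

    julia> filter(cap -> cap <= v"7.0", [v"6.0", v"7.0", v"7.5"])
    2-element Vector{VersionNumber}:
     v"6.0"
     v"7.0"

    julia> maximum(ans)  # selected capability; passed to ptxas as "sm_70"
    v"7.0"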

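Note that `run_and_collect` merges stdout and stderr into a single log and, thanks to `ignorestatus`, never throws when the process fails, leaving error handling to the caller. A hypothetical invocation:

    proc, log = run_and_collect(`$(ptxas()) --version`)
    success(proc) || error("ptxas failed:\n" * log)
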