Skip to content

Commit c687cae

Browse files
authored
Merge pull request #335 from lcw/lcw/always_inline
Add `always_inline` target property to ptx backend
2 parents 9b8f0f2 + 9c1a20b commit c687cae

File tree

6 files changed

+81
-11
lines changed

6 files changed

+81
-11
lines changed

src/interface.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
192192

193193
# provide a specific interpreter to use.
194194
# Build the interpreter for `job` from its code cache, method table and world age,
# together with the job's inference and optimization parameters.
function get_interpreter(@nospecialize(job::CompilerJob))
    return GPUInterpreter(ci_cache(job), method_table(job), job.source.world,
                          inference_params(job), optimization_params(job))
end
196196

197197
# does this target support throwing Julia exceptions with jl_throw?
198198
# if not, calls to throw will be replaced with calls to the GPU runtime
@@ -265,11 +265,34 @@ link_libraries!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
265265
# Default: no function pointer is reported as valid for any job.
function valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid})
    return false
end
266266

267267
# the codeinfo cache to use
268-
ci_cache(@nospecialize(job::CompilerJob)) = GLOBAL_CI_CACHE
268+
# Look up (creating on first use) the `CodeCache` shared by all jobs that have the same
# target type, inference parameters and optimization parameters.
# The global table is only touched while holding `GLOBAL_CI_CACHES_LOCK`.
function ci_cache(@nospecialize(job::CompilerJob))
    return lock(GLOBAL_CI_CACHES_LOCK) do
        key = (typeof(job.target), inference_params(job), optimization_params(job))
        get!(CodeCache, GLOBAL_CI_CACHES, key)
    end
end
269276

270277
# the method table to use
271278
# Default: every job uses the global method table.
function method_table(@nospecialize(job::CompilerJob))
    return GLOBAL_METHOD_TABLE
end
272279

280+
# the inference parameters to use when constructing the GPUInterpreter
281+
# Default inference parameters for any job: `unoptimize_throw_blocks` is switched off.
inference_params(@nospecialize(job::CompilerJob)) =
    InferenceParams(; unoptimize_throw_blocks=false)
284+
285+
# the optimization parameters to use when constructing the GPUInterpreter
286+
# Default optimization parameters for any job.
# Before Julia 1.8.0-DEV.486 `unoptimize_throw_blocks=false` is passed explicitly;
# from that version on the stock `OptimizationParams()` is used
# (presumably the keyword is no longer accepted there — confirm against Core.Compiler).
function optimization_params(@nospecialize(job::CompilerJob))
    if VERSION < v"1.8.0-DEV.486"
        return OptimizationParams(; unoptimize_throw_blocks=false)
    end
    return OptimizationParams()
end
295+
273296
# how much debuginfo to emit
274297
function llvm_debug_info(@nospecialize(job::CompilerJob))
275298
if Base.JLOptions().debug_level == 0

src/jlgen.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## cache
44

5-
using Core.Compiler: CodeInstance, MethodInstance
5+
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams
66

77
struct CodeCache
88
dict::Dict{MethodInstance,Vector{CodeInstance}}
@@ -39,7 +39,8 @@ end
3939

4040
# Drop every entry from the cache by emptying its backing dictionary.
function Base.empty!(cc::CodeCache)
    return empty!(cc.dict)
end
4141

42-
const GLOBAL_CI_CACHE = CodeCache()
42+
const GLOBAL_CI_CACHES = Dict{Tuple{DataType, InferenceParams, OptimizationParams}, CodeCache}()
43+
const GLOBAL_CI_CACHES_LOCK = ReentrantLock()
4344

4445

4546
## method invalidations
@@ -182,7 +183,8 @@ struct GPUInterpreter <: AbstractInterpreter
182183
inf_params::InferenceParams
183184
opt_params::OptimizationParams
184185

185-
function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt)
186+
187+
function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams)
186188
@assert world <= Base.get_world_counter()
187189

188190
return new(
@@ -196,9 +198,8 @@ struct GPUInterpreter <: AbstractInterpreter
196198
world,
197199

198200
# parameters for inference and optimization
199-
InferenceParams(unoptimize_throw_blocks=false),
200-
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
201-
OptimizationParams(unoptimize_throw_blocks=false),
201+
ip,
202+
op
202203
)
203204
end
204205
end

src/ptx.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
2121
maxthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
2222
blocks_per_sm::Union{Nothing,Int} = nothing
2323
maxregs::Union{Nothing,Int} = nothing
24+
always_inline::Bool = false
2425
end
2526

2627
function Base.hash(target::PTXCompilerTarget, h::UInt)
@@ -35,6 +36,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
3536
h = hash(target.maxthreads, h)
3637
h = hash(target.blocks_per_sm, h)
3738
h = hash(target.maxregs, h)
39+
h = hash(target.always_inline, h)
3840

3941
h
4042
end
@@ -74,6 +76,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
7476
job.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.target.maxthreads)")
7577
job.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.target.blocks_per_sm)")
7678
job.target.maxregs !== nothing && print(io, ", maxregs=$(job.target.maxregs)")
79+
# `always_inline` is a `Bool`, never `nothing`: test the flag itself so it is only
# mentioned when enabled, like the optional settings printed above.
job.target.always_inline && print(io, ", always_inline=$(job.target.always_inline)")
7780
end
7881

7982
const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
@@ -86,6 +89,20 @@ runtime_slug(@nospecialize(job::CompilerJob{PTXCompilerTarget})) =
8689
"-debuginfo=$(Int(llvm_debug_info(job)))" *
8790
"-exitable=$(job.target.exitable)"
8891

92+
# Optimization parameters for PTX jobs.
# Starts from the same version-dependent settings as the generic method and, when the
# target requests `always_inline`, raises the inlining cost threshold to `typemax(Int)`.
function optimization_params(@nospecialize(job::CompilerJob{PTXCompilerTarget}))
    opts = NamedTuple()

    # only passed on Julia versions older than 1.8.0-DEV.486
    VERSION < v"1.8.0-DEV.486" &&
        (opts = merge(opts, (unoptimize_throw_blocks=false,)))

    job.target.always_inline &&
        (opts = merge(opts, (inline_cost_threshold=typemax(Int),)))

    return OptimizationParams(; opts...)
end
105+
89106
function process_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), mod::LLVM.Module)
90107
ctx = context(mod)
91108

src/validation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ function check_method(@nospecialize(job::CompilerJob))
3232
if job.source.kernel
3333
cache = ci_cache(job)
3434
mt = method_table(job)
35-
interp = GPUInterpreter(cache, mt, world)
35+
ip = inference_params(job)
36+
op = optimization_params(job)
37+
interp = GPUInterpreter(cache, mt, world, ip, op)
3638
rt = return_type(only(ms); interp)
3739

3840
if rt != Nothing

test/definitions/ptx.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ GPUCompiler.runtime_module(::PTXCompilerJob) = PTXTestRuntime
3939

4040
# Test helper: build a PTX `CompilerJob` for `func` applied to `types`.
# Recognised keywords configure the kernel flag and the PTX target (fixed at capability
# 7.0); all remaining keywords are returned untouched as the second element.
function ptx_job(@nospecialize(func), @nospecialize(types); kernel::Bool=false,
                 minthreads=nothing, maxthreads=nothing, blocks_per_sm=nothing,
                 maxregs=nothing, always_inline=false, kwargs...)
    target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm,
                               maxregs, always_inline)
    source = FunctionSpec(func, Base.to_tuple_type(types), kernel)
    job = CompilerJob(target, source, TestCompilerParams())
    return job, kwargs
end

test/ptx.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,32 @@ end
174174
end
175175
end
176176

177+
@testset "always_inline" begin
    # callee made of 100 nested `sink` calls; the checks without `always_inline` below
    # expect it to remain a separate `.func` in the generated PTX
    @eval f_expensive(x) = $(foldl((e, _) -> :(sink($e) + sink(x)), 1:100; init=:x))

    # two distinct kernels that wrap the same callee
    g(x) = (f_expensive(x); nothing)
    h(x) = (f_expensive(x); nothing)

    # PTX assembly of kernel `f` for an `Int64` argument
    ptx_asm(f; kwargs...) = sprint() do io
        ptx_code_native(io, f, Tuple{Int64}; kernel=true, kwargs...)
    end
    # does the assembly still contain a standalone `f_expensive` function?
    has_callee(asm) = occursin(r"\.func .*julia_f_expensive", asm)

    # `g` is compiled without and then with the flag, `h` in the opposite order —
    # presumably so that code cached for one setting cannot leak into the other
    @test has_callee(ptx_asm(g))
    @test !has_callee(ptx_asm(g; always_inline=true))

    @test !has_callee(ptx_asm(h; always_inline=true))
    @test has_callee(ptx_asm(h))
end
202+
177203
@testset "child function reuse" begin
178204
# bug: depending on a child function from multiple parents resulted in
179205
# the child only being present once

0 commit comments

Comments
 (0)