Skip to content

Commit 50a6bd9

Browse files
committed
Add inference and optimization params to interface
This adds the `inference_params` and `optimization_params` functions to the interface. This allows each backend to pass different inference and optimization parameters to the interpreter.
1 parent 9b8f0f2 commit 50a6bd9

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

src/interface.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
192192

193193
# provide a specific interpreter to use.
194194
get_interpreter(@nospecialize(job::CompilerJob)) =
195-
GPUInterpreter(ci_cache(job), method_table(job), job.source.world)
195+
GPUInterpreter(ci_cache(job), method_table(job), job.source.world, inference_params(job), optimization_params(job))
196196

197197
# does this target support throwing Julia exceptions with jl_throw?
198198
# if not, calls to throw will be replaced with calls to the GPU runtime
@@ -270,6 +270,22 @@ ci_cache(@nospecialize(job::CompilerJob)) = GLOBAL_CI_CACHE
270270
# the method table to use
271271
method_table(@nospecialize(job::CompilerJob)) = GLOBAL_METHOD_TABLE
272272

273+
# the inference parameters to use when constructing the GPUInterpreter
274+
function inference_params(@nospecialize(job::CompilerJob))
275+
return InferenceParams(;unoptimize_throw_blocks=false)
276+
end
277+
278+
# the optimization parameters to use when constructing the GPUInterpreter
279+
function optimization_params(@nospecialize(job::CompilerJob))
280+
kwargs = NamedTuple()
281+
282+
if VERSION < v"1.8.0-DEV.486"
283+
kwargs = (kwargs..., unoptimize_throw_blocks=false)
284+
end
285+
286+
return OptimizationParams(;kwargs...)
287+
end
288+
273289
# how much debuginfo to emit
274290
function llvm_debug_info(@nospecialize(job::CompilerJob))
275291
if Base.JLOptions().debug_level == 0

src/jlgen.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ struct GPUInterpreter <: AbstractInterpreter
182182
inf_params::InferenceParams
183183
opt_params::OptimizationParams
184184

185-
function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt)
185+
186+
function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams)
186187
@assert world <= Base.get_world_counter()
187188

188189
return new(
@@ -196,9 +197,8 @@ struct GPUInterpreter <: AbstractInterpreter
196197
world,
197198

198199
# parameters for inference and optimization
199-
InferenceParams(unoptimize_throw_blocks=false),
200-
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
201-
OptimizationParams(unoptimize_throw_blocks=false),
200+
ip,
201+
op
202202
)
203203
end
204204
end

src/validation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ function check_method(@nospecialize(job::CompilerJob))
3232
if job.source.kernel
3333
cache = ci_cache(job)
3434
mt = method_table(job)
35-
interp = GPUInterpreter(cache, mt, world)
35+
ip = inference_params(job)
36+
op = optimization_params(job)
37+
interp = GPUInterpreter(cache, mt, world, ip, op)
3638
rt = return_type(only(ms); interp)
3739

3840
if rt != Nothing

0 commit comments

Comments
 (0)