@@ -125,7 +125,7 @@ function add_backedge!(li::CodeInstance, caller::OptimizationState)
125
125
nothing
126
126
end
127
127
128
- function isinlineable (m:: Method , me:: OptimizationState , params:: OptimizationParams , bonus:: Int = 0 )
128
+ function isinlineable (m:: Method , me:: OptimizationState , params:: OptimizationParams , union_penalties :: Bool , bonus:: Int = 0 )
129
129
# compute the cost (size) of inlining this code
130
130
inlineable = false
131
131
cost_threshold = params. inline_cost_threshold
@@ -143,7 +143,7 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
143
143
end
144
144
end
145
145
if ! inlineable
146
- inlineable = inline_worthy (me. src. code, me. src, me. sptypes, me. slottypes, params, cost_threshold + bonus)
146
+ inlineable = inline_worthy (me. src. code, me. src, me. sptypes, me. slottypes, params, union_penalties, cost_threshold + bonus)
147
147
end
148
148
return inlineable
149
149
end
@@ -219,15 +219,14 @@ function optimize(opt::OptimizationState, params::OptimizationParams, @nospecial
219
219
replace_code_newstyle! (opt. src, ir, nargs)
220
220
221
221
# determine and cache inlineability
222
+ union_penalties = false
222
223
if ! force_noinline
223
- # don't keep ASTs for functions specialized on a Union argument
224
- # TODO : this helps avoid a type-system bug mis-computing sparams during intersection
225
224
sig = unwrap_unionall (opt. linfo. specTypes)
226
225
if isa (sig, DataType) && sig. name === Tuple. name
227
226
for P in sig. parameters
228
227
P = unwrap_unionall (P)
229
228
if isa (P, Union)
230
- force_noinline = true
229
+ union_penalties = true
231
230
break
232
231
end
233
232
end
@@ -252,7 +251,7 @@ function optimize(opt::OptimizationState, params::OptimizationParams, @nospecial
252
251
# For functions declared @inline, increase the cost threshold 20x
253
252
bonus += params. inline_cost_threshold* 19
254
253
end
255
- opt. src. inlineable = isinlineable (def, opt, params, bonus)
254
+ opt. src. inlineable = isinlineable (def, opt, params, union_penalties, bonus)
256
255
end
257
256
end
258
257
nothing
@@ -281,7 +280,9 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
281
280
# known return type
282
281
# Known return type predicate: true when `T` is the bottom type `Union{}`,
# is an instance of `Const`, or widens (via `widenconst`) to a concrete type.
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
283
282
284
- function statement_cost (ex:: Expr , line:: Int , src:: CodeInfo , sptypes:: Vector{Any} , slottypes:: Vector{Any} , params:: OptimizationParams , error_path:: Bool = false )
283
+ function statement_cost (ex:: Expr , line:: Int , src:: CodeInfo , sptypes:: Vector{Any} ,
284
+ slottypes:: Vector{Any} , union_penalties:: Bool ,
285
+ params:: OptimizationParams , error_path:: Bool = false )
285
286
head = ex. head
286
287
if is_meta_expr_head (head)
287
288
return 0
@@ -314,6 +315,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
314
315
# tuple iteration/destructuring makes that impossible
315
316
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
316
317
return 0
318
+ elseif f === Main. Core. isa
319
+ # If we're in a union context, we penalize type computations
320
+ # on union types. In such cases, it is usually better to perform
321
+ # union splitting on the outside.
322
+ if union_penalties && isa (argextype (ex. args[2 ], src, sptypes, slottypes), Union)
323
+ return params. inline_nonleaf_penalty
324
+ end
317
325
elseif (f === Main. Core. arrayref || f === Main. Core. const_arrayref) && length (ex. args) >= 3
318
326
atyp = argextype (ex. args[3 ], src, sptypes, slottypes)
319
327
return isknowntype (atyp) ? 4 : error_path ? params. inline_error_path_cost : params. inline_nonleaf_penalty
@@ -362,10 +370,12 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
362
370
return 0
363
371
end
364
372
365
- function statement_or_branch_cost (@nospecialize (stmt), line:: Int , src:: CodeInfo , sptypes:: Vector{Any} , slottypes:: Vector{Any} , params:: OptimizationParams , throw_blocks:: Union{Nothing,BitSet} )
373
+ function statement_or_branch_cost (@nospecialize (stmt), line:: Int , src:: CodeInfo , sptypes:: Vector{Any} ,
374
+ slottypes:: Vector{Any} , union_penalties:: Bool , params:: OptimizationParams ,
375
+ throw_blocks:: Union{Nothing,BitSet} )
366
376
thiscost = 0
367
377
if stmt isa Expr
368
- thiscost = statement_cost (stmt, line, src, sptypes, slottypes, params,
378
+ thiscost = statement_cost (stmt, line, src, sptypes, slottypes, union_penalties, params,
369
379
params. unoptimize_throw_blocks && line in throw_blocks):: Int
370
380
elseif stmt isa GotoNode
371
381
# loops are generally always expensive
@@ -379,24 +389,24 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo,
379
389
end
380
390
381
391
# Decide whether `body` is cheap enough to be inlined.
#
# Walks the statements of `body` in order, asks `statement_or_branch_cost` for
# the cost of each one, and accumulates the costs with `plus_saturate`. Returns
# `false` as soon as the running total exceeds `cost_threshold`, and `true` if
# the whole body stays within the threshold.
#
# Arguments:
# - `body`: the statement list being costed.
# - `src`, `sptypes`, `slottypes`: forwarded unchanged to `statement_or_branch_cost`.
# - `params`: supplies `unoptimize_throw_blocks` and the default threshold
#   (`inline_cost_threshold`).
# - `union_penalties`: forwarded to `statement_or_branch_cost`; defaults to `false`.
# - `cost_threshold`: maximum accumulated cost that is still considered inlineable.
#
# NOTE(review): `union_penalties` precedes `cost_threshold` positionally, so a
# caller that wants a custom threshold must pass `union_penalties` explicitly.
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
                       params::OptimizationParams, union_penalties::Bool=false,
                       cost_threshold::Integer=params.inline_cost_threshold)
    bodycost::Int = 0
    # Only compute the throw-block set when the parameters ask for it.
    throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(body) : nothing
    for line = 1:length(body)
        stmt = body[line]
        thiscost = statement_or_branch_cost(stmt, line, src, sptypes, slottypes,
                                            union_penalties, params, throw_blocks)
        bodycost = plus_saturate(bodycost, thiscost)
        # Bail out early once the budget is exceeded.
        bodycost > cost_threshold && return false
    end
    return true
end
393
403
394
- function statement_costs! (cost:: Vector{Int} , body:: Vector{Any} , src:: CodeInfo , sptypes:: Vector{Any} , params:: OptimizationParams )
404
+ function statement_costs! (cost:: Vector{Int} , body:: Vector{Any} , src:: CodeInfo , sptypes:: Vector{Any} , unionpenalties :: Bool , params:: OptimizationParams )
395
405
throw_blocks = params. unoptimize_throw_blocks ? find_throw_blocks (body) : nothing
396
406
maxcost = 0
397
407
for line = 1 : length (body)
398
408
stmt = body[line]
399
- thiscost = statement_or_branch_cost (stmt, line, src, sptypes, src. slottypes, params, throw_blocks)
409
+ thiscost = statement_or_branch_cost (stmt, line, src, sptypes, src. slottypes, unionpenalties, params, throw_blocks)
400
410
cost[line] = thiscost
401
411
if thiscost > maxcost
402
412
maxcost = thiscost
0 commit comments