Skip to content

Commit d5aba5d

Browse files
committed
improve optimizability of forward_demand.jl
1 parent 3ea59af commit d5aba5d

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

src/codegen/forward_demand.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff:
2121
end
2222
=#
2323

24-
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, order::Int; custom_diff!, diff_cache)
24+
# TODO interp::AbstractADInterpreter instead interp::AbstractInterpreter?
25+
26+
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
27+
ssa::SSAValue, order::Int;
28+
custom_diff!, diff_cache)
2529
if haskey(diff_cache, ssa)
2630
return diff_cache[ssa]
2731
end
@@ -34,9 +38,19 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSA
3438
end
3539
return Δssa
3640
end
37-
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}, order::Int; custom_diff!, diff_cache) = zero(val)
38-
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg), order::Int; custom_diff!, diff_cache) = ChainRulesCore.NoTangent()
39-
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument, order::Int; custom_diff!, diff_cache)
41+
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
42+
val::Union{Integer, AbstractFloat}, order::Int;
43+
custom_diff!, diff_cache)
44+
return zero(val)
45+
end
46+
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
47+
@nospecialize(arg), order::Int;
48+
custom_diff!, diff_cache)
49+
return ChainRulesCore.NoTangent()
50+
end
51+
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
52+
arg::Argument, order::Int;
53+
custom_diff!, diff_cache)
4054
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
4155
val = custom_diff!(ir, SSAValue(0), arg, recurse)
4256
if val !== nothing
@@ -45,7 +59,9 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
4559
return ChainRulesCore.NoTangent()
4660
end
4761

48-
function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int; custom_diff!, diff_cache)
62+
function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
63+
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
64+
custom_diff!, diff_cache)
4965
stmt = inst[:inst]
5066
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
5167
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
@@ -105,8 +121,7 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
105121
argtypes = Any[argextype(arg, ir) for arg in Δtpl.args[2:end]]
106122
tup_T = CC.tuple_tfunc(CC.typeinf_lattice(interp), argtypes)
107123

108-
Δ = insert_node!(ir, ssa, NewInstruction(
109-
Δtpl, tup_T))
124+
Δ = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_T))
110125

111126
# Now that we know the arguments, do a proper typeinf for this particular callsite
112127
new_spec_types = Tuple{typeof(ChainRulesCore.frule), widenconst(tup_T), (widenconst(argextype(arg, ir)) for arg in args)...}
@@ -175,15 +190,15 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
175190
error()
176191
end
177192
end
178-
forward_visit!(ir::IRCode, _, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) = nothing
193+
forward_visit!(::IRCode, @nospecialize(x), ::Int, ::Vector{Pair{Int, Bool}}, _) = nothing
179194
function forward_visit!(ir::IRCode, a::Argument, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!)
180195
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
181196
return visit_custom!(ir, a, order, recurse)
182197
end
183198

184199

185200
"""
186-
forward_diff_no_inf!(ir, to_diff; visit_custom!, transform)
201+
forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; visit_custom!, transform!)
187202
188203
Internal method which generates the code for forward mode diffentiation
189204
@@ -192,13 +207,14 @@ Internal method which generates the code for forward mode diffentiation
192207
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
193208
paired with the order (first deriviative, second derivative etc)
194209
195-
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
210+
- `visit_custom!(ir::IRCode, stmt, order::Int, recurse::Bool) -> Bool`:
196211
decides if the custom `transform!` should be applied to a `stmt` or not
197212
Default: `false` for all statements
198-
- `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
213+
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
199214
"""
200-
function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
201-
visit_custom! = (args...)->false, transform! = (args...)->error())
215+
function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
216+
visit_custom! = (@nospecialize args...)->false,
217+
transform! = (@nospecialize args...)->error())
202218
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
203219
ssa_orders = [0=>false for i = 1:length(ir.stmts)]
204220
for (ssa, order) in to_diff
@@ -208,7 +224,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
208224
truncation_map = Dict{Pair{SSAValue, Int}, SSAValue}()
209225

210226
# Step 2: Transform
211-
function maparg(arg, ssa, order)
227+
function maparg(@nospecialize(arg), ssa::SSAValue, order::Int)
212228
if isa(arg, SSAValue)
213229
if arg.id > length(ssa_orders)
214230
# This is possible if the custom transform touched another statement.
@@ -259,10 +275,16 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
259275
inst = ir[SSAValue(ssa)]
260276
stmt = inst[:inst]
261277
if isexpr(stmt, :invoke)
262-
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...)
278+
newargs = map(stmt.args[2:end]) do @nospecialize arg
279+
maparg(arg, SSAValue(ssa), order)
280+
end
281+
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
263282
inst[:type] = Any
264283
elseif isexpr(stmt, :call)
265-
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args)...)
284+
newargs = map(stmt.args) do @nospecialize arg
285+
maparg(arg, SSAValue(ssa), order)
286+
end
287+
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
266288
inst[:type] = Any
267289
elseif isa(stmt, PiNode)
268290
# TODO: New PiNode that discriminates based on primal?
@@ -288,7 +310,6 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
288310
end
289311
end
290312

291-
292313
function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::MethodInstance,
293314
to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
294315
forward_diff_no_inf!(ir, to_diff; kwargs...)

0 commit comments

Comments
 (0)