@@ -21,7 +21,11 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff:
21
21
end
22
22
=#
23
23
24
- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , order:: Int ; custom_diff!, diff_cache)
24
+ # TODO interp::AbstractADInterpreter instead interp::AbstractInterpreter?
25
+
26
+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
27
+ ssa:: SSAValue , order:: Int ;
28
+ custom_diff!, diff_cache)
25
29
if haskey (diff_cache, ssa)
26
30
return diff_cache[ssa]
27
31
end
@@ -34,9 +38,19 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSA
34
38
end
35
39
return Δssa
36
40
end
37
- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} , order:: Int ; custom_diff!, diff_cache) = zero (val)
38
- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg), order:: Int ; custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
39
- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument , order:: Int ; custom_diff!, diff_cache)
41
+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
42
+ val:: Union{Integer, AbstractFloat} , order:: Int ;
43
+ custom_diff!, diff_cache)
44
+ return zero (val)
45
+ end
46
+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
47
+ @nospecialize (arg), order:: Int ;
48
+ custom_diff!, diff_cache)
49
+ return ChainRulesCore. NoTangent ()
50
+ end
51
+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
52
+ arg:: Argument , order:: Int ;
53
+ custom_diff!, diff_cache)
40
54
recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
41
55
val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
42
56
if val != = nothing
@@ -45,7 +59,9 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
45
59
return ChainRulesCore. NoTangent ()
46
60
end
47
61
48
- function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ; custom_diff!, diff_cache)
62
+ function forward_diff_uncached! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
63
+ ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ;
64
+ custom_diff!, diff_cache)
49
65
stmt = inst[:inst ]
50
66
recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache)
51
67
if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
@@ -105,8 +121,7 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
105
121
argtypes = Any[argextype (arg, ir) for arg in Δtpl. args[2 : end ]]
106
122
tup_T = CC. tuple_tfunc (CC. typeinf_lattice (interp), argtypes)
107
123
108
- Δ = insert_node! (ir, ssa, NewInstruction (
109
- Δtpl, tup_T))
124
+ Δ = insert_node! (ir, ssa, NewInstruction (Δtpl, tup_T))
110
125
111
126
# Now that we know the arguments, do a proper typeinf for this particular callsite
112
127
new_spec_types = Tuple{typeof (ChainRulesCore. frule), widenconst (tup_T), (widenconst (argextype (arg, ir)) for arg in args). .. }
@@ -175,15 +190,15 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
175
190
error ()
176
191
end
177
192
end
178
- forward_visit! (ir :: IRCode , _, order :: Int , ssa_orders :: Vector{Pair{Int, Bool}} , visit_custom! ) = nothing
193
+ forward_visit! (:: IRCode , @nospecialize (x), :: Int , :: Vector{Pair{Int, Bool}} , _ ) = nothing
179
194
function forward_visit! (ir:: IRCode , a:: Argument , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
180
195
recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
181
196
return visit_custom! (ir, a, order, recurse)
182
197
end
183
198
184
199
185
200
"""
186
- forward_diff_no_inf!(ir, to_diff; visit_custom!, transform)
201
+ forward_diff_no_inf!(ir::IRCode , to_diff::Vector{Pair{SSAValue,Int}} ; visit_custom!, transform! )
187
202
188
203
Internal method which generates the code for forward mode diffentiation
189
204
@@ -192,13 +207,14 @@ Internal method which generates the code for forward mode diffentiation
192
207
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
193
208
paired with the order (first deriviative, second derivative etc)
194
209
195
- - `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
210
+ - `visit_custom!(ir::IRCode , stmt, order::Int, recurse::Bool) -> Bool `:
196
211
decides if the custom `transform!` should be applied to a `stmt` or not
197
212
Default: `false` for all statements
198
- - `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
213
+ - `transform!(ir::IRCode , ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
199
214
"""
200
- function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue, Int}} ;
201
- visit_custom! = (args... )-> false , transform! = (args... )-> error ())
215
+ function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue,Int}} ;
216
+ visit_custom! = (@nospecialize args... )-> false ,
217
+ transform! = (@nospecialize args... )-> error ())
202
218
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
203
219
ssa_orders = [0 => false for i = 1 : length (ir. stmts)]
204
220
for (ssa, order) in to_diff
@@ -208,7 +224,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
208
224
truncation_map = Dict {Pair{SSAValue, Int}, SSAValue} ()
209
225
210
226
# Step 2: Transform
211
- function maparg (arg, ssa, order)
227
+ function maparg (@nospecialize ( arg) , ssa:: SSAValue , order:: Int )
212
228
if isa (arg, SSAValue)
213
229
if arg. id > length (ssa_orders)
214
230
# This is possible if the custom transform touched another statement.
@@ -259,10 +275,16 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
259
275
inst = ir[SSAValue (ssa)]
260
276
stmt = inst[:inst ]
261
277
if isexpr (stmt, :invoke )
262
- inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
278
+ newargs = map (stmt. args[2 : end ]) do @nospecialize arg
279
+ maparg (arg, SSAValue (ssa), order)
280
+ end
281
+ inst[:inst ] = Expr (:call , ∂☆ {order} (), newargs... )
263
282
inst[:type ] = Any
264
283
elseif isexpr (stmt, :call )
265
- inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args)... )
284
+ newargs = map (stmt. args) do @nospecialize arg
285
+ maparg (arg, SSAValue (ssa), order)
286
+ end
287
+ inst[:inst ] = Expr (:call , ∂☆ {order} (), newargs... )
266
288
inst[:type ] = Any
267
289
elseif isa (stmt, PiNode)
268
290
# TODO : New PiNode that discriminates based on primal?
@@ -288,7 +310,6 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
288
310
end
289
311
end
290
312
291
-
292
313
function forward_diff! (interp:: ADInterpreter , ir:: IRCode , src:: CodeInfo , mi:: MethodInstance ,
293
314
to_diff:: Vector{Pair{SSAValue, Int}} ; kwargs... )
294
315
forward_diff_no_inf! (ir, to_diff; kwargs... )
0 commit comments