1
1
using Core. Compiler: IRInterpretationState, construct_postdomtree, PiNode,
2
2
is_known_call, argextype, postdominates
3
3
4
- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , pantelides:: Vector{SSAValue} ; custom_diff! = (args... )-> nothing , diff_cache= Dict {SSAValue, SSAValue} ())
4
+ #=
5
+ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
5
6
Δs = SSAValue[]
6
7
rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst)
7
8
postdomtree = construct_postdomtree(ir.cfg.blocks)
8
- for ssa in pantelides
9
- Δssa = forward_diff! (ir, interp, irsv, ssa; custom_diff!, diff_cache)
9
+ for ( ssa, order) in to_diff
10
+ Δssa = forward_diff!(ir, interp, irsv, ssa, order ; custom_diff!, diff_cache)
10
11
Δblock = block_for_inst(ir, Δssa.id)
11
12
for idx in rets
12
13
retblock = block_for_inst(ir, idx)
@@ -18,31 +19,24 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelid
18
19
end
19
20
return (ir, Δs)
20
21
end
22
+ =#
21
23
22
- function diff_unassigned_variable! (ir, ssa)
23
- return insert_node! (ir, ssa, NewInstruction (
24
- Expr (:call , GlobalRef (Intrinsics, :state_ddt ), ssa), Float64), #= attach_after=# true )
25
- end
26
-
27
- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue ; custom_diff!, diff_cache)
24
+ function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , order:: Int ; custom_diff!, diff_cache)
28
25
if haskey (diff_cache, ssa)
29
26
return diff_cache[ssa]
30
27
end
31
28
inst = ir[ssa]
32
29
stmt = inst[:inst ]
33
- if isa (stmt, SSAValue)
34
- return forward_diff! (ir, interp, irsv, stmt; custom_diff!, diff_cache)
35
- end
36
- Δssa = forward_diff_uncached! (ir, interp, irsv, ssa, inst; custom_diff!, diff_cache)
30
+ Δssa = forward_diff_uncached! (ir, interp, irsv, ssa, inst, order:: Int ; custom_diff!, diff_cache)
37
31
@assert Δssa != = nothing
38
32
if isa (Δssa, SSAValue)
39
33
diff_cache[ssa] = Δssa
40
34
end
41
35
return Δssa
42
36
end
43
- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} ; custom_diff!, diff_cache) = zero (val)
44
- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg); custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
45
- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument ; custom_diff!, diff_cache)
37
+ forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} , order :: Int ; custom_diff!, diff_cache) = zero (val)
38
+ forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg), order :: Int ; custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
39
+ function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument , order :: Int ; custom_diff!, diff_cache)
46
40
recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
47
41
val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
48
42
if val != = nothing
@@ -51,13 +45,15 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
51
45
return ChainRulesCore. NoTangent ()
52
46
end
53
47
54
- function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction ; custom_diff!, diff_cache)
48
+ function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction , order :: Int ; custom_diff!, diff_cache)
55
49
stmt = inst[:inst ]
56
- recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
50
+ recurse (x) = forward_diff! (ir, interp, irsv, x, order ; custom_diff!, diff_cache)
57
51
if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
58
52
return val
59
53
elseif isa (stmt, PiNode)
60
54
return recurse (stmt. val)
55
+ elseif isa (stmt, SSAValue)
56
+ return recurse (stmt)
61
57
elseif isa (stmt, PhiNode)
62
58
Δphi = PhiNode (copy (stmt. edges), similar (stmt. values))
63
59
T = Union{}
@@ -152,3 +148,108 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
152
148
return Δssa
153
149
end
154
150
end
151
+
152
+ function forward_visit! (ir:: IRCode , ssa:: SSAValue , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
153
+ if ssa_orders[ssa. id][1 ] >= order
154
+ return
155
+ end
156
+ ssa_orders[ssa. id] = order => ssa_orders[ssa. id][2 ]
157
+ inst = ir[ssa]
158
+ stmt = inst[:inst ]
159
+ recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
160
+ if visit_custom! (ir, stmt, order, recurse)
161
+ ssa_orders[ssa. id] = order => true
162
+ return
163
+ elseif isa (stmt, PiNode)
164
+ return recurse (stmt. val)
165
+ elseif isa (stmt, PhiNode)
166
+ for i = 1 : length (stmt. values)
167
+ isassigned (stmt. values, i) || continue
168
+ recurse (stmt. values[i])
169
+ end
170
+ return
171
+ elseif isexpr (stmt, :new ) || isexpr (stmt, :invoke )
172
+ foreach (recurse, stmt. args[2 : end ])
173
+ elseif isexpr (stmt, :call )
174
+ foreach (recurse, stmt. args)
175
+ elseif isa (stmt, SSAValue)
176
+ recurse (stmt)
177
+ elseif ! isa (stmt, Expr)
178
+ return
179
+ else
180
+ @show stmt
181
+ error ()
182
+ end
183
+ end
184
+ forward_visit! (ir:: IRCode , _, order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!) = nothing
185
+ function forward_visit! (ir:: IRCode , a:: Argument , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
186
+ recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
187
+ return visit_custom! (ir, a, order, recurse)
188
+ end
189
+
190
+
191
+ function forward_diff_no_inf! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ;
192
+ visit_custom! = (args... )-> false , transform! = (args... )-> error ())
193
+ # Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
194
+ ssa_orders = [0 => false for i = 1 : length (ir. stmts)]
195
+ for (ssa, order) in to_diff
196
+ forward_visit! (ir, ssa, order, ssa_orders, visit_custom!)
197
+ end
198
+
199
+ # Step 2: Transform
200
+ function maparg (arg, ssa, order)
201
+ if isa (arg, Argument)
202
+ # TODO : Should we remember whether the callbacks wanted the arg?
203
+ return transform! (ir, arg, order)
204
+ elseif isa (arg, SSAValue)
205
+ # TODO : Bundle truncation if necessary
206
+ return arg
207
+ end
208
+ @assert ! isa (arg, Expr)
209
+ return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg), Any))
210
+ end
211
+
212
+ for (ssa, (order, custom)) in enumerate (ssa_orders)
213
+ if order == 0
214
+ # TODO : Bundle truncation?
215
+ continue
216
+ end
217
+ if custom
218
+ transform! (ir, SSAValue (ssa), order)
219
+ else
220
+ inst = ir[SSAValue (ssa)]
221
+ stmt = inst[:inst ]
222
+ if isexpr (stmt, :invoke )
223
+ inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
224
+ inst[:type ] = Any
225
+ elseif ! isa (stmt, Expr)
226
+ inst[:inst ] = maparg (stmt, ssa, order)
227
+ inst[:type ] = Any
228
+ else
229
+ @show stmt
230
+ error ()
231
+ end
232
+ end
233
+ end
234
+
235
+ end
236
+
237
+ function forward_diff! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ; kwargs... )
238
+ forward_diff_no_inf! (ir, interp, mi, world, to_diff; kwargs... )
239
+
240
+ # Step 3: Re-inference
241
+ ir = compact! (ir)
242
+
243
+ extra_reprocess = CC. BitSet ()
244
+ for i = 1 : length (ir. stmts)
245
+ if ir[SSAValue (i)][:type ] == Any
246
+ CC. push! (extra_reprocess, i)
247
+ end
248
+ end
249
+
250
+ interp′ = enable_reinference (interp)
251
+ irsv = IRInterpretationState (interp′, ir, mi, world, ir. argtypes[1 : mi. def. nargs])
252
+ rt = CC. _ir_abstract_constant_propagation (interp′, irsv; extra_reprocess)
253
+
254
+ return ir
255
+ end
0 commit comments