1
1
using Core. Compiler: IRInterpretationState, construct_postdomtree, PiNode,
2
- is_known_call, argextype, postdominates
2
+ is_known_call, argextype, postdominates, userefs
3
3
4
4
#=
5
5
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
@@ -93,12 +93,6 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
93
93
return Δtangent
94
94
else # general frule handling
95
95
info = inst[:info ]
96
- if ! isa (info, FRuleCallInfo)
97
- @show info
98
- @show inst[:inst ]
99
- display (ir)
100
- error ()
101
- end
102
96
if isexpr (stmt, :invoke )
103
97
args = stmt. args[2 : end ]
104
98
else
@@ -196,22 +190,50 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
196
190
forward_visit! (ir, ssa, order, ssa_orders, visit_custom!)
197
191
end
198
192
193
+ truncation_map = Dict {Pair{SSAValue, Int}, SSAValue} ()
194
+
199
195
# Step 2: Transform
200
196
function maparg (arg, ssa, order)
201
- if isa (arg, Argument)
197
+ if isa (arg, SSAValue)
198
+ if arg. id > length (ssa_orders)
199
+ # This is possible if the custom transform touched another statement.
200
+ # In that case just pass this through and assume the `transform!` did
201
+ # it correctly.
202
+ return arg
203
+ end
204
+ (argorder, _) = ssa_orders[arg. id]
205
+ if argorder != order
206
+ @assert order < argorder
207
+ return get! (truncation_map, arg=> order) do
208
+ # TODO : Other orders
209
+ @assert order == 0
210
+ insert_node! (ir, arg, NewInstruction (Expr (:call , primal, arg), Any), #= attach_after=# true )
211
+ end
212
+ end
213
+ return arg
214
+ elseif order == 0
215
+ return arg
216
+ elseif isa (arg, Argument)
202
217
# TODO : Should we remember whether the callbacks wanted the arg?
203
218
return transform! (ir, arg, order)
204
- elseif isa (arg, SSAValue)
205
- # TODO : Bundle truncation if necessary
206
- return arg
219
+ elseif isa (arg, GlobalRef)
220
+ return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg), Any))
221
+ elseif isa (arg, QuoteNode)
222
+ return ZeroBundle {order} (arg. value)
207
223
end
208
224
@assert ! isa (arg, Expr)
209
- return insert_node! (ir, ssa, NewInstruction ( Expr ( :call , ZeroBundle{order}, arg), Any) )
225
+ return ZeroBundle {order} ( arg)
210
226
end
211
227
212
228
for (ssa, (order, custom)) in enumerate (ssa_orders)
213
229
if order == 0
214
- # TODO : Bundle truncation?
230
+ inst = ir[SSAValue (ssa)]
231
+ stmt = inst[:inst ]
232
+ urs = userefs (stmt)
233
+ for ur in urs
234
+ ur[] = maparg (ur[], SSAValue (ssa), order)
235
+ end
236
+ inst[:inst ] = urs[]
215
237
continue
216
238
end
217
239
if custom
@@ -222,12 +244,16 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
222
244
if isexpr (stmt, :invoke )
223
245
inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
224
246
inst[:type ] = Any
225
- elseif ! isa (stmt, Expr )
226
- inst[:inst ] = maparg (stmt, ssa, order)
247
+ elseif isexpr (stmt, :call )
248
+ inst[:inst ] = Expr ( :call , ∂☆ {order} (), map (arg -> maparg (arg, SSAValue ( ssa) , order), stmt . args) ... )
227
249
inst[:type ] = Any
228
250
else
229
- @show stmt
230
- error ()
251
+ urs = userefs (stmt)
252
+ for ur in urs
253
+ ur[] = maparg (ur[], SSAValue (ssa), order)
254
+ end
255
+ inst[:inst ] = urs[]
256
+ inst[:type ] = Any
231
257
end
232
258
end
233
259
end
0 commit comments