Skip to content

Commit 48c4317

Browse files
committed
Hook up eras mode in stage 2
1 parent c1a6703 commit 48c4317

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

src/codegen/forward_demand.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
4040
end
4141
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
4242
val, order::Int;
43-
custom_diff!, diff_cache)
43+
custom_diff!, diff_cache, eras_mode)
4444
return ChainRulesCore.zero_tangent(val)
4545
end
4646
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
4747
arg::Argument, order::Int;
48-
custom_diff!, diff_cache)
49-
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
48+
custom_diff!, diff_cache, eras_mode)
49+
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache, eras_mode)
5050
val = custom_diff!(ir, SSAValue(0), arg, recurse)
5151
if val !== nothing
5252
return val
@@ -56,9 +56,9 @@ end
5656

5757
function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
5858
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
59-
custom_diff!, diff_cache)
59+
custom_diff!, diff_cache, eras_mode)
6060
stmt = inst[:inst]
61-
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
61+
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode)
6262
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
6363
return val
6464
elseif isa(stmt, PiNode)
@@ -212,8 +212,10 @@ Internal method which generates the code for forward mode diffentiation
212212
decides if the custom `transform!` should be applied to a `stmt` or not
213213
Default: `false` for all statements
214214
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
215+
- `eras_mode`: determines if to error if not all derivatives are taylor
215216
"""
216217
function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
218+
eras_mode = false,
217219
visit_custom! = (@nospecialize args...)->false,
218220
transform! = (@nospecialize args...)->error())
219221
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
@@ -286,12 +288,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
286288
newargs = map(stmt.args[2:end]) do @nospecialize arg
287289
maparg(arg, SSAValue(ssa), order)
288290
end
289-
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
291+
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order, eras_mode}(), newargs...))
290292
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
291293
newargs = map(stmt.args) do @nospecialize arg
292294
maparg(arg, SSAValue(ssa), order)
293295
end
294-
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
296+
f = isexpr(stmt, :call) ? ∂☆{order, eras_mode}() : ∂☆new{order, eras_mode}()
295297
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
296298
elseif isa(stmt, PiNode)
297299
# TODO: New PiNode that discriminates based on primal?

src/stage2/forward.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
# Engineering entry point for the 2nd-order forward AD functionality. This is
2121
# unlikely to be the actual interface. For now, it is used for testing.
22-
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
22+
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false)
2323
interp = ADInterpreter(; forward=true, backward=false)
2424
match = Base._which(tt)
2525
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
@@ -82,7 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
8282
end
8383
end
8484

85-
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)
85+
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode)
8686

8787
return OpaqueClosure(ir)
8888
end

test/forward_diff_no_inf.jl

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,44 @@
11

22
module forward_diff_no_inf
3+
using Core.Compiler: SSAValue
4+
const CC = Core.Compiler
5+
36
using Diffractor, Test
47
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
5-
identity_transform!(ir, ssa::Core.SSAValue, order, _) = ir[ssa]
8+
identity_transform!(ir, ssa::SSAValue, order, _) = ir[ssa]
69
function identity_transform!(ir, arg::Core.Argument, order, _)
7-
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
10+
return CC.insert_node!(ir, SSAValue(1), CC.NewInstruction(Expr(:call, Diffractor.zero_bundle{1}(), arg), Any))
811
end
912

13+
14+
function infer_ir!(ir)
15+
interp = CC.NativeInterpreter()
16+
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ());
17+
mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...}
18+
mi.def = @__MODULE__
19+
20+
for i in 1:length(ir.stmts) # For testuing purposes we are going to refine everything
21+
ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED
22+
end
23+
24+
method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing)
25+
min_world = world = (interp).world
26+
max_world = Diffractor.get_world_counter()
27+
irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world)
28+
(rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv)
29+
return rt
30+
end
31+
32+
1033
@testset "Constructors in forward_diff_no_inf!" begin
1134
struct Bar148
1235
v
1336
end
1437
foo_148(x) = Bar148(x)
1538

1639
ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
17-
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
18-
ir2 = Core.Compiler.compact!(ir)
40+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!)
41+
ir2 = CC.compact!(ir)
1942
f = Core.OpaqueClosure(ir2; do_compile=false)
2043
@test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right
2144
end
@@ -25,9 +48,9 @@ module forward_diff_no_inf
2548
plus_a_global(x) = x + _coeff
2649

2750
ir = first(only(Base.code_ircode(plus_a_global, Tuple{Float64})))
28-
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
29-
ir2 = Core.Compiler.compact!(ir)
30-
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
51+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!)
52+
ir2 = CC.compact!(ir)
53+
CC.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
3154
# Assert that the reference to `Main._coeff` is properly typed
3255
stmt_idx = findfirst(stmt -> isa(stmt[:inst], GlobalRef), collect(ir2.stmts))
3356
stmt = ir2.stmts[stmt_idx]
@@ -51,17 +74,40 @@ module forward_diff_no_inf
5174
input_ir = first(only(Base.code_ircode(phi_run, Tuple{Float64})))
5275
ir = copy(input_ir)
5376
#Workout where to diff to trigger error
54-
diff_ssa = Core.SSAValue[]
77+
diff_ssa = SSAValue[]
5578
for idx in 1:length(ir.stmts)
5679
if ir.stmts[idx][:inst] isa Core.PhiNode
57-
push!(diff_ssa, Core.SSAValue(idx))
80+
push!(diff_ssa, SSAValue(idx))
5881
end
5982
end
6083

6184
Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
62-
ir2 = Core.Compiler.compact!(ir)
63-
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
85+
ir2 = CC.compact!(ir)
86+
CC.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
6487
f = Core.OpaqueClosure(ir2; do_compile=false)
6588
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
6689
end
90+
91+
@testset "Eras mode" begin
92+
foo(x, y) = x*x + y*y
93+
# Eras mode should make this all inferable
94+
eras_mode = true
95+
ir = first(only(Base.code_ircode(foo, Tuple{Any, Any})))
96+
#@assert ir[SSAValue(1)][:inst].args[1].name == :literal_pow
97+
@assert ir[SSAValue(3)][:inst].args[1].name == :+
98+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(1)] .=> 1; transform! = identity_transform!, eras_mode)
99+
ir = CC.compact!(ir)
100+
#@assert ir[SSAValue(5)][:inst].args[1] == Diffractor.∂☆{1, eras_mode}()
101+
#@assert ir[SSAValue(5)][:inst].args[2].primal == *
102+
ir.argtypes[2:end] .= Float64
103+
@assert infer_ir!(ir) == Float64
104+
105+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(6)] .=> 1; transform! = identity_transform!, eras_mode=eras_mode)
106+
ir = CC.compact!(ir)
107+
CC.verify_ir(ir)
108+
infer_ir!(ir)
109+
110+
111+
112+
end
67113
end

0 commit comments

Comments
 (0)