1
1
2
2
module forward_diff_no_inf
3
+ using Core. Compiler: SSAValue
4
+ const CC = Core. Compiler
5
+
3
6
using Diffractor, Test
4
7
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
5
- identity_transform! (ir, ssa:: Core. SSAValue , order, _) = ir[ssa]
8
+ identity_transform! (ir, ssa:: SSAValue , order, _) = ir[ssa]
6
9
function identity_transform! (ir, arg:: Core.Argument , order, _)
7
- return Core . Compiler . insert_node! (ir, Core . SSAValue (1 ), Core . Compiler . NewInstruction (Expr (:call , Diffractor. ZeroBundle {1 }, arg), Any))
10
+ return CC . insert_node! (ir, SSAValue (1 ), CC . NewInstruction (Expr (:call , Diffractor. zero_bundle {1} () , arg), Any))
8
11
end
9
12
13
+
14
+ function infer_ir! (ir)
15
+ interp = CC. NativeInterpreter ()
16
+ mi = ccall (:jl_new_method_instance_uninit , Ref{Core. MethodInstance}, ());
17
+ mi. specTypes = Tuple{map (CC. widenconst, ir. argtypes)... }
18
+ mi. def = @__MODULE__
19
+
20
+ for i in 1 : length (ir. stmts) # For testuing purposes we are going to refine everything
21
+ ir[SSAValue (i)][:flag ] |= CC. IR_FLAG_REFINED
22
+ end
23
+
24
+ method_info = CC. MethodInfo (#= propagate_inbounds=# true , nothing )
25
+ min_world = world = (interp). world
26
+ max_world = Diffractor. get_world_counter ()
27
+ irsv = CC. IRInterpretationState (interp, method_info, ir, mi, ir. argtypes, world, min_world, max_world)
28
+ (rt, nothrow) = CC. _ir_abstract_constant_propagation (interp, irsv)
29
+ return rt
30
+ end
31
+
32
+
10
33
@testset " Constructors in forward_diff_no_inf!" begin
11
34
struct Bar148
12
35
v
13
36
end
14
37
foo_148 (x) = Bar148 (x)
15
38
16
39
ir = first (only (Base. code_ircode (foo_148, Tuple{Float64})))
17
- Diffractor. forward_diff_no_inf! (ir, [Core . SSAValue (1 ) => 1 ]; transform! = identity_transform!)
18
- ir2 = Core . Compiler . compact! (ir)
40
+ Diffractor. forward_diff_no_inf! (ir, [SSAValue (1 ) => 1 ]; transform! = identity_transform!)
41
+ ir2 = CC . compact! (ir)
19
42
f = Core. OpaqueClosure (ir2; do_compile= false )
20
43
@test f (1.0 ) == Bar148 (1.0 ) # This would error if we were not handling constructors (%new) right
21
44
end
@@ -25,9 +48,9 @@ module forward_diff_no_inf
25
48
plus_a_global (x) = x + _coeff
26
49
27
50
ir = first (only (Base. code_ircode (plus_a_global, Tuple{Float64})))
28
- Diffractor. forward_diff_no_inf! (ir, [Core . SSAValue (1 ) => 1 ]; transform! = identity_transform!)
29
- ir2 = Core . Compiler . compact! (ir)
30
- Core . Compiler . verify_ir (ir2) # This would error if we were not handling nonconst globals correctly
51
+ Diffractor. forward_diff_no_inf! (ir, [SSAValue (1 ) => 1 ]; transform! = identity_transform!)
52
+ ir2 = CC . compact! (ir)
53
+ CC . verify_ir (ir2) # This would error if we were not handling nonconst globals correctly
31
54
# Assert that the reference to `Main._coeff` is properly typed
32
55
stmt_idx = findfirst (stmt -> isa (stmt[:inst ], GlobalRef), collect (ir2. stmts))
33
56
stmt = ir2. stmts[stmt_idx]
@@ -51,17 +74,40 @@ module forward_diff_no_inf
51
74
input_ir = first (only (Base. code_ircode (phi_run, Tuple{Float64})))
52
75
ir = copy (input_ir)
53
76
# Workout where to diff to trigger error
54
- diff_ssa = Core . SSAValue[]
77
+ diff_ssa = SSAValue[]
55
78
for idx in 1 : length (ir. stmts)
56
79
if ir. stmts[idx][:inst ] isa Core. PhiNode
57
- push! (diff_ssa, Core . SSAValue (idx))
80
+ push! (diff_ssa, SSAValue (idx))
58
81
end
59
82
end
60
83
61
84
Diffractor. forward_diff_no_inf! (ir, diff_ssa .=> 1 ; transform! = identity_transform!)
62
- ir2 = Core . Compiler . compact! (ir)
63
- Core . Compiler . verify_ir (ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
85
+ ir2 = CC . compact! (ir)
86
+ CC . verify_ir (ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
64
87
f = Core. OpaqueClosure (ir2; do_compile= false )
65
88
@test f (3.5 ) == 3.5 # this will segfault if we are not handling phi nodes correctly
66
89
end
90
+
91
+ @testset " Eras mode" begin
92
+ foo (x, y) = x* x + y* y
93
+ # Eras mode should make this all inferable
94
+ eras_mode = true
95
+ ir = first (only (Base. code_ircode (foo, Tuple{Any, Any})))
96
+ # @assert ir[SSAValue(1)][:inst].args[1].name == :literal_pow
97
+ @assert ir[SSAValue (3 )][:inst ]. args[1 ]. name == :+
98
+ Diffractor. forward_diff_no_inf! (ir, [SSAValue (1 )] .=> 1 ; transform! = identity_transform!, eras_mode)
99
+ ir = CC. compact! (ir)
100
+ # @assert ir[SSAValue(5)][:inst].args[1] == Diffractor.∂☆{1, eras_mode}()
101
+ # @assert ir[SSAValue(5)][:inst].args[2].primal == *
102
+ ir. argtypes[2 : end ] .= Float64
103
+ @assert infer_ir! (ir) == Float64
104
+
105
+ Diffractor. forward_diff_no_inf! (ir, [SSAValue (6 )] .=> 1 ; transform! = identity_transform!, eras_mode= eras_mode)
106
+ ir = CC. compact! (ir)
107
+ CC. verify_ir (ir)
108
+ infer_ir! (ir)
109
+
110
+
111
+
112
+ end
67
113
end
0 commit comments