@@ -4,6 +4,9 @@ module forward_diff_no_inf
4
4
const CC = Core. Compiler
5
5
6
6
using Diffractor, Test
7
+
8
+ # #################### Helpers:
9
+
7
10
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
8
11
identity_transform! (ir, ssa:: SSAValue , order, _) = ir[ssa]
9
12
function identity_transform! (ir, arg:: Core.Argument , order, _)
@@ -28,7 +31,23 @@ module forward_diff_no_inf
28
31
(rt, nothrow) = CC. _ir_abstract_constant_propagation (interp, irsv)
29
32
return rt
30
33
end
31
-
34
+
35
+ function isfully_inferred (ir)
36
+ for stmt in ir. stmts
37
+ inst = stmt[:inst ]
38
+ if Meta. isexpr (inst, :call ) || Meta. isexpr (inst, :invoke )
39
+ typ = stmt[:type ]
40
+ ! isa (typ, Type) && continue # If not a Type then something even more informed like a Const
41
+ if isabstracttype (typ) || typ <: Union || typ <: UnionAll
42
+ # @error "Not fully inferred" inst typ
43
+ return false
44
+ end
45
+ end
46
+ end
47
+ return true
48
+ end
49
+
50
+ # ############################## Actual tests:
32
51
33
52
@testset " Constructors in forward_diff_no_inf!" begin
34
53
struct Bar148
@@ -88,10 +107,8 @@ module forward_diff_no_inf
88
107
@test f (3.5 ) == 3.5 # this will segfault if we are not handling phi nodes correctly
89
108
end
90
109
91
- @testset " Eras mode" begin
110
+ @testset " Eras mode: $eras_mode " for eras_mode in ( false , true )
92
111
foo (x, y) = x* x + y* y
93
- # Eras mode should make this all inferable
94
- eras_mode = true
95
112
ir = first (only (Base. code_ircode (foo, Tuple{Any, Any})))
96
113
# @assert ir[SSAValue(1)][:inst].args[1].name == :literal_pow
97
114
@assert ir[SSAValue (3 )][:inst ]. args[1 ]. name == :+
@@ -100,14 +117,23 @@ module forward_diff_no_inf
100
117
# @assert ir[SSAValue(5)][:inst].args[1] == Diffractor.∂☆{1, eras_mode}()
101
118
# @assert ir[SSAValue(5)][:inst].args[2].primal == *
102
119
ir. argtypes[2 : end ] .= Float64
120
+ ir = CC. compact! (ir)
103
121
infer_ir! (ir)
122
+ CC. verify_ir (ir)
123
+ @test isfully_inferred (ir) # passes with and without eras mode
104
124
105
- Diffractor. forward_diff_no_inf! (ir, [SSAValue (3 )] .=> 1 ; transform! = identity_transform!, eras_mode= eras_mode)
106
- # TODO actually test things here.
107
-
108
-
125
+ Diffractor. forward_diff_no_inf! (ir, [SSAValue (3 )] .=> 1 ; transform! = identity_transform!, eras_mode)
109
126
ir = CC. compact! (ir)
127
+ infer_ir! (ir)
128
+
110
129
CC. verify_ir (ir)
111
- infer_ir! (ir)
130
+ if eras_mode
131
+ @test isfully_inferred (ir)
132
+ else
133
+ # if this passes outside era mode then this test is wrong
134
+ @assert ! isfully_inferred (ir)
135
+ end
112
136
end
113
- end
137
+
138
+ end # module
139
+
0 commit comments