Skip to content

Commit 9771bf3

Browse files
committed
add tests
1 parent ac44a97 commit 9771bf3

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

test/forward_diff_no_inf.jl

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ module forward_diff_no_inf
44
const CC = Core.Compiler
55

66
using Diffractor, Test
7+
8+
##################### Helpers:
9+
710
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
811
identity_transform!(ir, ssa::SSAValue, order, _) = ir[ssa]
912
function identity_transform!(ir, arg::Core.Argument, order, _)
@@ -28,7 +31,23 @@ module forward_diff_no_inf
2831
(rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv)
2932
return rt
3033
end
31-
34+
35+
function isfully_inferred(ir)
36+
for stmt in ir.stmts
37+
inst = stmt[:inst]
38+
if Meta.isexpr(inst, :call) || Meta.isexpr(inst, :invoke)
39+
typ = stmt[:type]
40+
!isa(typ, Type) && continue # If not a Type then something even more informed like a Const
41+
if isabstracttype(typ) || typ <: Union || typ <: UnionAll
42+
#@error "Not fully inferred" inst typ
43+
return false
44+
end
45+
end
46+
end
47+
return true
48+
end
49+
50+
############################### Actual tests:
3251

3352
@testset "Constructors in forward_diff_no_inf!" begin
3453
struct Bar148
@@ -88,10 +107,8 @@ module forward_diff_no_inf
88107
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
89108
end
90109

91-
@testset "Eras mode" begin
110+
@testset "Eras mode: $eras_mode" for eras_mode in (false, true)
92111
foo(x, y) = x*x + y*y
93-
# Eras mode should make this all inferable
94-
eras_mode = true
95112
ir = first(only(Base.code_ircode(foo, Tuple{Any, Any})))
96113
#@assert ir[SSAValue(1)][:inst].args[1].name == :literal_pow
97114
@assert ir[SSAValue(3)][:inst].args[1].name == :+
@@ -100,14 +117,23 @@ module forward_diff_no_inf
100117
#@assert ir[SSAValue(5)][:inst].args[1] == Diffractor.∂☆{1, eras_mode}()
101118
#@assert ir[SSAValue(5)][:inst].args[2].primal == *
102119
ir.argtypes[2:end] .= Float64
120+
ir = CC.compact!(ir)
103121
infer_ir!(ir)
122+
CC.verify_ir(ir)
123+
@test isfully_inferred(ir) # passes with and without eras mode
104124

105-
Diffractor.forward_diff_no_inf!(ir, [SSAValue(3)] .=> 1; transform! = identity_transform!, eras_mode=eras_mode)
106-
# TODO actually test things here.
107-
108-
125+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(3)] .=> 1; transform! = identity_transform!, eras_mode)
109126
ir = CC.compact!(ir)
127+
infer_ir!(ir)
128+
110129
CC.verify_ir(ir)
111-
infer_ir!(ir)
130+
if eras_mode
131+
@test isfully_inferred(ir)
132+
else
133+
# if this passes outside era mode then this test is wrong
134+
@assert !isfully_inferred(ir)
135+
end
112136
end
113-
end
137+
138+
end # module
139+

0 commit comments

Comments
 (0)