Skip to content

Commit 998267b

Browse files
authored
Merge pull request #148 from JuliaDiff/sf/fix_new_zero_bundle_error
Forward mode handingly for `Expr(:new, )` in `forward_diff_no_inf!`
2 parents fb4e03d + 0d6b465 commit 998267b

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/codegen/forward_demand.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
183183
foreach(recurse, stmt.args)
184184
elseif isa(stmt, SSAValue)
185185
recurse(stmt)
186+
elseif isexpr(stmt, :code_coverage_effect)
187+
return
186188
elseif !isa(stmt, Expr)
187189
return
188190
else
@@ -280,11 +282,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
280282
end
281283
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
282284
inst[:type] = Any
283-
elseif isexpr(stmt, :call)
285+
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
284286
newargs = map(stmt.args) do @nospecialize arg
285287
maparg(arg, SSAValue(ssa), order)
286288
end
287-
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
289+
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
290+
inst[:inst] = Expr(:call, f, newargs...)
288291
inst[:type] = Any
289292
elseif isa(stmt, PiNode)
290293
# TODO: New PiNode that discriminates based on primal?

test/stage2_fwd.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,23 @@ module stage2_fwd
4343
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
4444
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
4545
end
46+
47+
@testset "Constructors in forward_diff_no_inf!" begin
48+
struct Bar148
49+
v
50+
end
51+
foo_148(x) = Bar148(x)
52+
53+
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
54+
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
55+
function identity_transform!(ir, arg::Core.Argument, order)
56+
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
57+
end
58+
59+
ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
60+
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
61+
ir2 = Core.Compiler.compact!(ir)
62+
f = Core.OpaqueClosure(ir2; do_compile=false)
63+
@test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right
64+
end
4665
end

0 commit comments

Comments
 (0)