diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 54dd00cf..750a5943 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -282,16 +282,11 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; else inst = ir[SSAValue(ssa)] stmt = inst[:inst] - if isexpr(stmt, :invoke) - newargs = map(stmt.args[2:end]) do @nospecialize arg + if isexpr(stmt, :invoke) || isexpr(stmt, :call) || isexpr(stmt, :new) + newargs = map(@view stmt.args[isexpr(stmt, :invoke) + 1:end]) do @nospecialize arg maparg(arg, SSAValue(ssa), order) end - replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...)) - elseif isexpr(stmt, :call) || isexpr(stmt, :new) - newargs = map(stmt.args) do @nospecialize arg - maparg(arg, SSAValue(ssa), order) - end - f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}() + f = isexpr(stmt, :new) ? ∂☆new{order}() : ∂☆{order}() replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...)) elseif isa(stmt, PiNode) # TODO: New PiNode that discriminates based on primal?