Skip to content

Commit 9e1b91c

Browse files
committed
split function based on taylor_or_bust
1 parent c48bfd5 commit 9e1b91c

File tree

3 files changed

+57
-42
lines changed

3 files changed

+57
-42
lines changed

src/codegen/forward_demand.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,10 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
290290
end
291291
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order, eras_mode}(), newargs...))
292292
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
293-
newargs = map(stmt.args) do @nospecialize arg
293+
newargs = map(stmt.args) do @nospecialize argq
294294
maparg(arg, SSAValue(ssa), order)
295295
end
296-
f = isexpr(stmt, :call) ? ∂☆{order, eras_mode}() : ∂☆new{order, eras_mode}()
296+
f = isexpr(stmt, :call) ? ∂☆{order, eras_mode}() : ∂☆new{order}()
297297
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
298298
elseif isa(stmt, PiNode)
299299
# TODO: New PiNode that discriminates based on primal?

src/stage1/forward.jl

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -74,48 +74,63 @@ function taylor_failure_values(r::TaylorBundle{<:Any, Tuple{Any,Any}}, fail_orde
7474
return partial(r, i+1)[1], partial(r, i)[2]
7575
end
7676

77-
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}, ::Val{taylor_or_bust}) where {B1,B2, taylor_or_bust}
78-
z₀ = primal(r)[1]
79-
z₁ = partial(r, 1)[1]
80-
z₂ = primal(r)[2]
81-
z₁₂ = partial(r, 1)[2]
82-
83-
taylor_fail_order = find_taylor_incompatibility(r)
84-
if taylor_fail_order < 0
85-
return TaylorBundle{2}(z₀, (z₁, z₁₂))
86-
elseif taylor_or_bust
87-
@assert taylor_fail_order == 0 # can't be higher
88-
throw(TaylorRequired(taylor_fail_order, z₁, z₂))
89-
else
90-
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
77+
for taylor_or_bust in (false, true)
78+
@eval function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}, ::Val{$taylor_or_bust}) where {B1,B2}
79+
z₀ = primal(r)[1]
80+
z₁ = partial(r, 1)[1]
81+
z₂ = primal(r)[2]
82+
z₁₂ = partial(r, 1)[2]
83+
84+
taylor_fail_order = find_taylor_incompatibility(r)
85+
if taylor_fail_order < 0
86+
return TaylorBundle{2}(z₀, (z₁, z₁₂))
87+
else
88+
$(
89+
if taylor_or_bust
90+
quote
91+
@assert taylor_fail_order == 0 # can't be higher
92+
throw(TaylorRequired(taylor_fail_order, z₁, z₂))
93+
end
94+
else
95+
:(return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂)))
96+
end
97+
)
98+
end
9199
end
92-
end
93100

94-
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}, ::Val{taylor_or_bust}) where {N, B1,B2, taylor_or_bust}
95-
the_primal = primal(r)[1]
96-
taylor_fail_order = find_taylor_incompatibility(r)
97-
if taylor_fail_order(r) < 0
98-
the_partials = ntuple(N+1) do i
99-
if i <= N
100-
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
101-
else # ii = N+1
102-
partial(r, i-1)[2]
101+
@eval function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}, ::Val{$taylor_or_bust}) where {N, B1,B2}
102+
the_primal = primal(r)[1]
103+
taylor_fail_order = find_taylor_incompatibility(r)
104+
if taylor_fail_order(r) < 0
105+
the_partials = ntuple(N+1) do i
106+
if i <= N
107+
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
108+
else # ii = N+1
109+
partial(r, i-1)[2]
110+
end
103111
end
112+
return TaylorBundle{N+1}(the_primal, the_partials)
113+
else
114+
$(
115+
if taylor_or_bust
116+
quote
117+
@assert taylor_fail_order < N
118+
throw(TaylorRequired(taylor_fail_order, taylor_failure_values(r, taylor_fail_order)...))
119+
end
120+
else
121+
quote
122+
#XXX: am dubious of the correctness of this
123+
a_partials = ntuple(i->partial(r, i)[1], N)
124+
b_partials = ntuple(i->partial(r, i)[2], N)
125+
the_partials = (a_partials..., primal_b, b_partials...)
126+
return TangentBundle{N+1}(the_primal, the_partials)
127+
end
128+
end
129+
)
104130
end
105-
return TaylorBundle{N+1}(the_primal, the_partials)
106-
elseif taylor_or_bust
107-
@assert taylor_fail_order < N
108-
throw(TaylorRequired(taylor_fail_order, taylor_failure_values(r, taylor_fail_order)...))
109-
else
110-
#XXX: am dubious of the correctness of this
111-
a_partials = ntuple(i->partial(r, i)[1], N)
112-
b_partials = ntuple(i->partial(r, i)[2], N)
113-
the_partials = (a_partials..., primal_b, b_partials...)
114-
return TangentBundle{N+1}(the_primal, the_partials)
115131
end
116132
end
117133

118-
119134
function shuffle_up(r::UniformBundle{N, B, U}, _::Val) where {N, B, U}
120135
(a, b) = primal(r)
121136
if r.tangent.val === b

test/forward_diff_no_inf.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,14 @@ module forward_diff_no_inf
100100
#@assert ir[SSAValue(5)][:inst].args[1] == Diffractor.∂☆{1, eras_mode}()
101101
#@assert ir[SSAValue(5)][:inst].args[2].primal == *
102102
ir.argtypes[2:end] .= Float64
103-
@assert infer_ir!(ir) == Float64
104-
105-
Diffractor.forward_diff_no_inf!(ir, [SSAValue(6)] .=> 1; transform! = identity_transform!, eras_mode=eras_mode)
106-
ir = CC.compact!(ir)
107-
CC.verify_ir(ir)
108103
infer_ir!(ir)
109104

105+
Diffractor.forward_diff_no_inf!(ir, [SSAValue(3)] .=> 1; transform! = identity_transform!, eras_mode=eras_mode)
106+
# TODO actually test things here.
110107

111108

109+
ir = CC.compact!(ir)
110+
CC.verify_ir(ir)
111+
infer_ir!(ir)
112112
end
113113
end

0 commit comments

Comments
 (0)