Skip to content

Commit c48bfd5

Browse files
committed
only transform nested recurse in Eras mode
1 parent 48c4317 commit c48bfd5

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/stage1/forward.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,11 @@ function (::∂☆internal{N, E})(args::AbstractTangentBundle{N}...) where {N, E
240240
end
241241
end
242242

243+
# Converts nested AD into Taylor AD (in Eras mode only)
244+
# Note: It does not matter if the inner ∂☆recurse was Eras mode or not, only the outer ∂☆.
243245
# TODO: Generalize to N,M
244-
@inline function (::∂☆{1,E})(rec::AbstractZeroBundle{1, ∂☆recurse{1, E}}, args::ATB{1}...) where E
245-
return shuffle_down_bundle(∂☆recurse{2,E}()(map(shuffle_up_bundle, args)...))
246+
@inline function (::∂☆{1,true})(rec::AbstractZeroBundle{1, ∂☆recurse{1}}, args::ATB{1}...)
247+
return shuffle_down_bundle(∂☆recurse{2,true}()(map(shuffle_up_bundle, args)...))
246248
end
247249

248250
(::∂☆{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆internal{N,E}()(args...)

0 commit comments

Comments
 (0)