Skip to content

Commit 7e3c80b

Browse files
committed
Fix closures
1 parent 9578366 commit 7e3c80b

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/stage1/forward.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ function shuffle_base(r)
107107
end
108108

109109
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
110+
# @info "∂☆internal{1}" args
110111
r = frule(#=DiffractorRuleConfig(),=# map(first_partial, args), map(primal, args)...)
111112
if r === nothing
112113
return ∂☆recurse{1}()(args...)
@@ -132,12 +133,12 @@ end
132133
function (::∂☆internal{N})(f::AbstractZeroBundle{N}, args::AbstractZeroBundle{N}...) where {N}
133134
f_v = primal(f)
134135
args_v = map(primal, args)
135-
return ZeroBundle{N}(f_v(args_v...))
136+
return zero_bundle{N}()(f_v(args_v...))
136137
end
137138
function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...)
138139
f_v = primal(f)
139140
args_v = map(primal, args)
140-
return ZeroBundle{1}(f_v(args_v...))
141+
return zero_bundle{1}()(f_v(args_v...))
141142
end
142143

143144
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}

test/forward_mutation.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,22 @@ end
6666
@test 🐇[TaylorTangentIndex(2)].x == 2.0 # returns 20
6767
end
6868

69+
@testset "closure" begin
70+
function bar(x)
71+
z = x + 1.0
72+
function foo!(y)
73+
z = z * y
74+
return z
75+
end
76+
77+
foo!(2)
78+
foo!(2)
79+
return z
80+
end
81+
82+
🥯 = ∂☆{1}()(ZeroBundle{1}(bar), TaylorBundle{1}(10.0, (1.0,)))
83+
@test 🥯[TaylorTangentIndex(1)] == 4.0
84+
end
85+
6986

7087
# end # module

0 commit comments

Comments
 (0)