Skip to content

Commit 084212a

Browse files
committed
define construstinstance over UniformBundles so differentiating map over closure wrt closed variable works
1 parent 960e74f commit 084212a

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/tangent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
360360
unbundle(atb), Δ->throw(Δ)
361361
end
362362

363-
function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
363+
function StructArrays.createinstance(T::Type{<:UniformBundle}, args...)
364364
T(args[1], args[2])
365365
end
366366

test/forward.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ end
9393
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
9494
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
9595
)
96+
97+
98+
# map over all closure, wrt the closed variable
99+
mulby(x) = y->x*y
100+
🐇 = ∂☆{1}()(
101+
ZeroBundle{1}(x->(map(mulby(x), [2.0, 4.0]))),
102+
TaylorBundle{1}(2.0, (10.0,))
103+
)
104+
@test 🐇 == TaylorBundle{1}([4.0, 8.0], ([20.0, 40.0],))
105+
96106
end
97107

98108

0 commit comments

Comments
 (0)