Skip to content

Commit 960e74f

Browse files
committed
add missing TaylorBundle constructor, thus fixing map
1 parent 17c2264 commit 960e74f

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/tangent.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,20 +206,25 @@ end
206206

207207
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208208

209+
210+
function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P}
211+
check_taylor_invariants(coeffs, primal, N)
212+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
213+
end
209214
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
210215
check_taylor_invariants(coeffs, primal, N)
211216
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
212217
end
218+
function TaylorBundle{N}(primal, coeffs) where {N}
219+
check_taylor_invariants(coeffs, primal, N)
220+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
221+
end
213222

214223
function check_taylor_invariants(coeffs, primal, N)
215224
@assert length(coeffs) == N
216-
217225
end
218226
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
219227

220-
function TaylorBundle{N}(primal, coeffs) where {N}
221-
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
222-
end
223228

224229
function Base.show(io::IO, x::TaylorBundle{1})
225230
print(io, x.primal)

test/forward.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
module forward_tests
1+
#module forward_tests
22
using Diffractor
3-
using Diffractor: TaylorBundle, ZeroBundle
3+
using Diffractor: TaylorBundle, ZeroBundle, ∂☆
44
using ChainRules
55
using ChainRulesCore
66
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
@@ -61,7 +61,7 @@ end
6161
end
6262

6363
# Special case if there is no derivative information at all:
64-
@test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
64+
@test ∂☆{1}()(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
6565
@test frule_calls[] == 0
6666
@test primal_calls[] == 1
6767
end
@@ -88,6 +88,14 @@ end
8888
end
8989

9090

91+
@testset "map" begin
92+
@test ==(
93+
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
94+
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
95+
)
96+
end
97+
98+
9199
@testset "structs" begin
92100
struct IDemo
93101
x::Float64

0 commit comments

Comments
 (0)