Skip to content

Commit 9794666

Browse files
committed
fix and test taylor_compatible
1 parent 34cedf8 commit 9794666

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

src/stage1/forward.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
5555
end
5656

5757
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
58-
partial(r, 1)[1] = primal(r)[2] || return false
59-
return all(1:N-1) do ii
58+
partial(r, 1)[1] == primal(r)[2] || return false
59+
return all(1:N-1) do i
6060
partial(r, i+1)[1] == partial(r, i)[2]
6161
end
6262
end
6363
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
6464
the_primal = primal(r)[1]
6565
if taylor_compatible(r)
6666
the_partials = ntuple(N+1) do i
67-
if ii <= N
67+
if i <= N
6868
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
6969
else # ii = N+1
7070
partial(r, i-1)[2]

test/forward.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
module forward_tests
22
using Diffractor
3-
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle
3+
using Diffractor: TaylorBundle
44
using ChainRules
55
using ChainRulesCore
66
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
77
using LinearAlgebra
8-
98
using Test
109

11-
const fwd = Diffractor.PrimeDerivativeFwd
12-
const bwd = Diffractor.PrimeDerivativeBack
13-
1410

1511

1612
# Minimal 2-nd order forward smoke test
@@ -131,4 +127,41 @@ end
131127
end
132128
end
133129

130+
131+
@testset "taylor_compatible" begin
132+
taylor_compatible = Diffractor.taylor_compatible
133+
134+
@test taylor_compatible(
135+
TaylorBundle{1}(10.0, (20.0,)),
136+
TaylorBundle{1}(20.0, (30.0,))
137+
)
138+
@test !taylor_compatible(
139+
TaylorBundle{1}(10.0, (20.0,)),
140+
TaylorBundle{1}(21.0, (30.0,))
141+
)
142+
@test taylor_compatible(
143+
TaylorBundle{2}(10.0, (20.0, 30.)),
144+
TaylorBundle{2}(20.0, (30.0, 40.))
145+
)
146+
@test !taylor_compatible(
147+
TaylorBundle{2}(10.0, (20.0, 30.0)),
148+
TaylorBundle{2}(20.0, (31.0, 40.0))
149+
)
150+
151+
152+
tuptan(args...) = Tangent{typeof(args)}(args...)
153+
@test taylor_compatible(
154+
TaylorBundle{1}((10.0, 20.0), (tuptan(20.0, 30.0),)),
155+
)
156+
@test taylor_compatible(
157+
TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 30.0),tuptan(30.0, 40.0))),
158+
)
159+
@test !taylor_compatible(
160+
TaylorBundle{1}((10.0, 20.0), (tuptan(21.0, 30.0),)),
161+
)
162+
@test !taylor_compatible(
163+
TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 31.0),tuptan(30.0, 40.0))),
164+
)
165+
end
166+
134167
end

0 commit comments

Comments
 (0)