Skip to content

Commit ee8af85

Browse files
authored
Merge pull request #215 from pepijndevos/pv/exptang
add truncate for explicit tangent
2 parents 7a9e9ba + ee36daa commit ee8af85

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/tangent.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ abstract type AbstractTangentSpace; end
8484
struct ExplicitTangent{P}
8585
8686
A fully explicit coordinate representation of the tangent space,
87-
represented by a vector of `2^(N-1)` partials.
87+
represented by a vector of `2^N-1` partials.
8888
"""
8989
struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
9090
partials::P
@@ -242,6 +242,14 @@ function truncate(tb::TangentBundle, order::Val)
242242
_TangentBundle(order, tb.primal, truncate(tb.tangent, order))
243243
end
244244

245+
function truncate(tb::ExplicitTangent, order::Val{N}) where {N}
246+
ExplicitTangent(tb.partials[1:2^N-1])
247+
end
248+
249+
function truncate(et::ExplicitTangent, order::Val{1})
250+
TaylorTangent(et.partials[1:1])
251+
end
252+
245253
const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}}
246254
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
247255
UniformBundle{N, B, U}(primal::B) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))

test/tangent.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module tagent
22
using Diffractor
33
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
44
using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle
5+
using Diffractor: ExplicitTangent, TaylorTangent, truncate
56
using ChainRulesCore
67
using Test
78

@@ -44,4 +45,12 @@ end
4445
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
4546
end
4647

48+
@testset "truncate" begin
49+
tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
50+
@test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))
51+
et = ExplicitTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
52+
@test truncate(et, Val(2)) == ExplicitTangent((1.0,2.0,3.0))
53+
@test truncate(et, Val(1)) == TaylorTangent((1.0,))
54+
end
55+
4756
end # module

0 commit comments

Comments
 (0)