Skip to content

Commit e0ff39b

Browse files
authored
Merge pull request #225 from JuliaDiff/ox/tangent_type_polish
Add polish to Bundle Types
2 parents 4aab1d7 + d5670d2 commit e0ff39b

File tree

3 files changed

+122
-34
lines changed

3 files changed

+122
-34
lines changed

src/stage1/forward.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
22
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
33
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
44
partial(x::UniformTangent, i) = getfield(x, :val)
5-
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
65
partial(x::AbstractZero, i) = x
76

87

src/tangent.jl

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex
7474
i::Int
7575
end
7676

77-
function Base.getindex(a::AbstractTangentBundle, b::TaylorTangentIndex)
78-
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
79-
end
80-
8177
abstract type AbstractTangentSpace; end
78+
Base.:(==)(x::AbstractTangentSpace, y::AbstractTangentSpace) = ==(promote(x, y)...)
8279

8380
"""
8481
struct ExplicitTangent{P}
@@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials.
8986
struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
9087
partials::P
9188
end
89+
Base.:(==)(a::ExplicitTangent, b::ExplicitTangent) = a.partials == b.partials
90+
Base.hash(tt::ExplicitTangent, h::UInt64) = hash(tt.partials, h)
91+
92+
Base.getindex(tangent::ExplicitTangent, b::CanonicalTangentIndex) = tangent.partials[b.i]
93+
function Base.getindex(tangent::ExplicitTangent, b::TaylorTangentIndex)
94+
if lastindex(tangent.partials) == exp2(b.i) - 1
95+
return tangent.partials[end]
96+
end
97+
# TODO: should we also allow other indexes if all the partials at that level are equal up regardless of order?
98+
throw(DomainError(b, "$(typeof(tangent)) is not taylor-like. Taylor indexing is ambiguous"))
99+
end
100+
92101

93102
@eval struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
94103
coeffs::C
95104
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
96105
end
97-
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
98-
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)
99106

100107
"""
101108
struct TaylorTangent{C}
@@ -122,15 +129,13 @@ by analogy with the (truncated) Taylor series
122129
"""
123130
TaylorTangent
124131

125-
"""
126-
struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
132+
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
133+
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)
134+
135+
136+
Base.getindex(tangent::TaylorTangent, tti::TaylorTangentIndex) = tangent.coeffs[tti.i]
137+
Base.getindex(tangent::TaylorTangent, tti::CanonicalTangentIndex) = tangent.coeffs[count_ones(tti.i)]
127138

128-
Represents the product space of the given representations of the
129-
tangent space.
130-
"""
131-
struct ProductTangent{T <: Tuple} <: AbstractTangentSpace
132-
factors::T
133-
end
134139

135140
"""
136141
struct UniformTangent
@@ -141,6 +146,28 @@ useful for representing singleton values.
141146
struct UniformTangent{U} <: AbstractTangentSpace
142147
val::U
143148
end
149+
Base.hash(t::UniformTangent, h::UInt64) = hash(t.val, h)
150+
Base.:(==)(t1::UniformTangent, t2::UniformTangent) = t1.val == t2.val
151+
152+
Base.getindex(tangent::UniformTangent, ::Any) = tangent.val
153+
154+
# Conversion and promotion
155+
Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:AbstractTangentSpace}) = et
156+
Base.promote_rule(tt::Type{<:TaylorTangent}, ::Type{<:AbstractTangentSpace}) = tt
157+
Base.promote_rule(et::Type{<:ExplicitTangent}, ::Type{<:TaylorTangent}) = et
158+
Base.promote_rule(::Type{<:TaylorTangent}, et::Type{<:ExplicitTangent}) = et
159+
160+
num_partials(::Type{TaylorTangent{P}}) where P = fieldcount(P)
161+
num_partials(::Type{ExplicitTangent{P}}) where P = fieldcount(P)
162+
Base.eltype(::Type{TaylorTangent{P}}) where P = eltype(P)
163+
Base.eltype(::Type{ExplicitTangent{P}}) where P = eltype(P)
164+
function Base.convert(::Type{T}, ut::UniformTangent) where {T<:Union{TaylorTangent, ExplicitTangent}}
165+
# can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
166+
T_wrapper = T<:TaylorTangent ? TaylorTangent : ExplicitTangent
167+
T_wrapper(ntuple(_->convert(eltype(T), ut.val), num_partials(T)))
168+
end
169+
Base.convert(T::Type{<:ExplicitTangent}, tt::TaylorTangent) = ExplicitTangent(ntuple(i->tt[CanonicalTangentIndex(i)], num_partials(T)))
170+
#TODO: Should we define the reverse: Explict->Taylor for the cases where that is actually defined?
144171

145172
function _TangentBundle end
146173

@@ -154,15 +181,17 @@ end
154181
struct TangentBundle{N, B, P}
155182
156183
Represents a tangent bundle as an explicit primal together
157-
with some representation of (potentially a product of) the tangent space.
184+
with some representation of the tangent space.
158185
"""
159186
TangentBundle
160187

161188
TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
162189
_TangentBundle(Val{N}(), primal, tangent)
163190

164191
Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h)
165-
Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent)
192+
Base.:(==)(a::TangentBundle, b::TangentBundle) = false # different orders
193+
Base.:(==)(a::TangentBundle{N}, b::TangentBundle{N}) where {N} = (a.primal == b.primal) && (a.tangent == b.tangent)
194+
Base.getindex(tbun::TangentBundle, x) = getindex(tbun.tangent, x)
166195

167196
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
168197

@@ -197,12 +226,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle)
197226
length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃")
198227
end
199228

200-
function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N}
201-
if b.i === N
202-
return a.tangent.partials[end]
203-
end
204-
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
205-
end
229+
206230

207231
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208232

@@ -233,11 +257,6 @@ function Base.show(io::IO, x::TaylorBundle{1})
233257
print(io, x.coeffs[1], " ∂₁")
234258
end
235259

236-
Base.getindex(tb::TaylorBundle, tti::TaylorTangentIndex) = tb.tangent.coeffs[tti.i]
237-
function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
238-
tb.tangent.coeffs[count_ones(tti.i)]
239-
end
240-
241260
"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
242261
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
243262
return ntuple(fieldcount(B)) do field_ii
@@ -307,8 +326,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N
307326
print(io, ")")
308327
end
309328

329+
# Conversion and promotion
330+
function Base.promote_rule(::Type{TangentBundle{N, B, P1}}, ::Type{TangentBundle{N, B, P2}}) where {N,B,P1,P2}
331+
return TangentBundle{N, B, promote_type(P1, P2)}
332+
end
333+
334+
function Base.convert(::Type{T}, tbun::TangentBundle{N, B}) where {N, B, P, T<:TangentBundle{N,B,P}}
335+
the_primal = convert(B, primal(tbun))
336+
the_partials = convert(P, tbun.tangent)
337+
return _TangentBundle(Val{N}(), the_primal, the_partials)
338+
end
310339

311-
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val
340+
# StructureArrays helpers
312341

313342
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
314343
expand_singleton_to_array(asize, a::AbstractArray) = a

test/tangent.jl

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
module tagent
1+
module tangent
22
using Diffractor
3-
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
4-
using Diffractor: TaylorBundle, TaylorTangentIndex
3+
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle, TaylorBundle, ExplicitTangentBundle
4+
using Diffractor:TaylorTangentIndex, CanonicalTangentIndex
55
using Diffractor: ExplicitTangent, TaylorTangent, truncate
66
using ChainRulesCore
77
using Test
@@ -46,11 +46,71 @@ end
4646
end
4747
end
4848

49-
@testset "== and hash" begin
49+
@testset "getindex" begin
50+
tt = TaylorBundle{2}(1.5, (1.0, 2.0))
51+
@test tt[TaylorTangentIndex(1)] == 1.0
52+
@test tt[TaylorTangentIndex(2)] == 2.0
53+
@test tt[CanonicalTangentIndex(1)] == 1.0
54+
@test tt[CanonicalTangentIndex(2)] == 1.0
55+
@test tt[CanonicalTangentIndex(3)] == 2.0
56+
57+
et = ExplicitTangentBundle{2}(1.5, (1.0, 2.0, 3.0))
58+
@test_throws DomainError et[TaylorTangentIndex(1)] == 1.0
59+
@test et[TaylorTangentIndex(2)] == 3.0
60+
@test et[CanonicalTangentIndex(1)] == 1.0
61+
@test et[CanonicalTangentIndex(2)] == 2.0
62+
@test et[CanonicalTangentIndex(3)] == 3.0
63+
64+
zb = ZeroBundle{2}(1.5)
65+
@test zb[TaylorTangentIndex(1)] == ZeroTangent()
66+
@test zb[TaylorTangentIndex(2)] == ZeroTangent()
67+
@test zb[CanonicalTangentIndex(1)] == ZeroTangent()
68+
@test zb[CanonicalTangentIndex(2)] == ZeroTangent()
69+
@test zb[CanonicalTangentIndex(3)] == ZeroTangent()
70+
end
71+
72+
@testset "promote" begin
73+
@test promote_type(
74+
typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))),
75+
typeof(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)))
76+
) <: ExplicitTangentBundle{1, Vector{Float64}}
77+
78+
@test promote_type(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: TaylorBundle{1, Float64, Tuple{Float64}}
79+
@test promote_type(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1, Float64}) <: ExplicitTangentBundle{1, Float64, Tuple{Float64}}
80+
end
81+
@testset "convert" begin
82+
@test convert(TaylorBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == TaylorBundle{1}(1.4, (0.0,))
83+
@test convert(ExplicitTangentBundle{1, Float64, Tuple{Float64}}, ZeroBundle{1}(1.4)) == ExplicitTangentBundle{1}(1.4, (0.0,))
84+
85+
@test convert(
86+
typeof(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))),
87+
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
88+
) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))
89+
90+
@test convert(
91+
typeof(ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,))),
92+
TaylorBundle{2}(1.5, (10.0, 20.0))
93+
) === ExplicitTangentBundle{2}(1.5, (10.0, 10.0, 20.0,))
94+
end
95+
@testset "==" begin
5096
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
51-
@test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0)
97+
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],))
98+
99+
@test ZeroBundle{3}(1.5) == ZeroBundle{3}(1.5)
100+
@test ZeroBundle{3}(1.5) == TaylorBundle{3}(1.5, (0.0, 0.0, 0.0))
101+
@test ZeroBundle{3}(1.5) == ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
52102
end
53103

104+
@testset "hash" begin
105+
@test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)))
106+
@test hash(TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))) == hash(ExplicitTangentBundle{1}([2.0, 4.0], ([20.0, 200.0],)))
107+
108+
@test hash(ZeroBundle{3}(1.5)) == hash(ZeroBundle{3}(1.5))
109+
@test hash(ZeroBundle{3}(1.5)) == hash(TaylorBundle{3}(1.5, (0.0, 0.0, 0.0)))
110+
@test hash(ZeroBundle{3}(1.5)) == hash(ExplicitTangentBundle{3}(1.5, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)))
111+
end
112+
113+
54114
@testset "truncate" begin
55115
tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
56116
@test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))

0 commit comments

Comments
 (0)