Skip to content

Commit ade0c3d

Browse files
committed
Support types that have no tangent space in zero_tangent
1 parent 59fc470 commit ade0c3d

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,11 @@ struct NoTangent <: AbstractZero end
9696
zero_tangent(primal)
9797
9898
This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
99-
For mutable composites types this is a structural []`MutableTangent`](@ref)
99+
For mutable composites types this is a structural [`MutableTangent`](@ref)
100100
For `Array`s, it is applied recursively for each element.
101-
For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply.
102-
(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything)
101+
For other types, in particular immutable types, we do not make promises beyond that it will be `iszero`
102+
and suitable for accumulating against.
103+
In general though, it is more likely to produce a structural tangent.
103104
104105
!!! warning Exprimental
105106
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
@@ -110,7 +111,10 @@ function zero_tangent end
110111

111112
zero_tangent(x::Number) = zero(x)
112113

114+
zero_tangent(::Type) = NoTangent()
115+
113116
@generated function zero_tangent(primal)
117+
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
114118
zfield_exprs = map(fieldnames(primal)) do fname
115119
fval = :(
116120
if isdefined(primal, $(QuoteNode(fname)))

test/tangent_types/abstract_zero.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,25 +162,38 @@
162162
end
163163

164164
@testset "zero_tangent" begin
165-
@test zero_tangent(1) === 0
166-
@test zero_tangent(1.0) === 0.0
167-
mutable struct MutDemo
168-
x::Float64
169-
end
170-
struct Demo
171-
x::Float64
172-
end
173-
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
174-
@test iszero(zero_tangent(MutDemo(1.5)))
165+
@testset "basics" begin
166+
@test zero_tangent(1) === 0
167+
@test zero_tangent(1.0) === 0.0
168+
mutable struct MutDemo
169+
x::Float64
170+
end
171+
struct Demo
172+
x::Float64
173+
end
174+
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
175+
@test iszero(zero_tangent(MutDemo(1.5)))
175176

176-
@test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))}
177-
@test zero_tangent(Demo(1.2)) isa Tangent{Demo}
178-
@test zero_tangent(Demo(1.2)).x === 0.0
177+
@test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))}
178+
@test zero_tangent(Demo(1.2)) isa Tangent{Demo}
179+
@test zero_tangent(Demo(1.2)).x === 0.0
179180

180-
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
181-
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
181+
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
182+
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
183+
184+
@test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0)
185+
end
186+
187+
@testset "Weird types" begin
188+
@test iszero(zero_tangent(typeof(Int))) # primative type
189+
@test iszero(zero_tangent(typeof(Base.RefValue))) # struct
190+
@test iszero(zero_tangent(Vector)) # UnionAll
191+
@test iszero(zero_tangent(Union{Int, Float64})) # Union
192+
@test iszero(zero_tangent(:abc))
193+
@test iszero(zero_tangent("abc"))
194+
@test iszero(zero_tangent(sin))
195+
end
182196

183-
@test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0)
184197
@testset "undef elements Vector" begin
185198
x = Vector{Vector{Float64}}(undef, 3)
186199
x[2] = [1.0, 2.0]

0 commit comments

Comments
 (0)