Skip to content

Commit 5369090

Browse files
committed
Add promotion rules for ZeroTangent and NoTangent
1 parent 9627bd6 commit 5369090

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Base.:/(z::AbstractZero, ::Any) = z
3232

3333
Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
3434
# (::Type{T})(::AbstractZero, ::AbstractZero...) where {T<:Number} = zero(T)
35+
Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T
3536

3637
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
3738
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)

test/tangent_types/abstract_zero.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@
8282
@test convert(Float32, ZeroTangent()) === 0.0f0
8383
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
8484

85+
@test promote_type(ZeroTangent, Bool) == Bool
86+
@test promote_type(Bool, ZeroTangent) == Bool
87+
@test promote_type(ZeroTangent, Int64) == Int64
88+
@test promote_type(Int64, ZeroTangent) == Int64
89+
@test promote_type(ZeroTangent, Float32) == Float32
90+
@test promote_type(Float32, ZeroTangent) == Float32
91+
@test promote_type(ZeroTangent, ComplexF64) == ComplexF64
92+
@test promote_type(ComplexF64, ZeroTangent) == ComplexF64
93+
8594
@test z[1] === z
8695
@test z[1:3] === z
8796
@test z[1, 2] === z
@@ -110,6 +119,15 @@
110119
@test dot(dne, 17.2) == dne
111120
@test dot(11.9, dne) == dne
112121

122+
@test promote_type(NoTangent, Bool) == Bool
123+
@test promote_type(Bool, NoTangent) == Bool
124+
@test promote_type(NoTangent, Int64) == Int64
125+
@test promote_type(Int64, NoTangent) == Int64
126+
@test promote_type(NoTangent, Float32) == Float32
127+
@test promote_type(Float32, NoTangent) == Float32
128+
@test promote_type(NoTangent, ComplexF64) == ComplexF64
129+
@test promote_type(ComplexF64, NoTangent) == ComplexF64
130+
113131
@test ZeroTangent() + dne == dne
114132
@test dne + ZeroTangent() == dne
115133
@test ZeroTangent() - dne == dne

0 commit comments

Comments
 (0)