Skip to content

Commit be2bcb9

Browse files
committed
reconstruct Complex from Tangent, do we need this?
1 parent 9a17580 commit be2bcb9

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

src/differentials/abstract_zero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Base.transpose(z::AbstractZero) = z
2525
Base.:/(z::AbstractZero, ::Any) = z
2626

2727
Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T)
28+
(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T)
29+
30+
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
31+
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)
2832

2933
Base.getindex(z::AbstractZero, k) = z
3034

src/projection.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ end
173173
(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
174174
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))
175175

176+
# Tangents: we prefer to reconstruct numbers, but only safe to try when their constructor
177+
# understands, including a mix of Zeros & reals. Other cases, we just let through:
178+
(project::ProjectTo{<:Complex})(dx::Tangent{<:Number}) = project(Complex(dx.re, dx.im))
179+
(::ProjectTo{<:Number})(dx::Tangent{<:Number}) = dx
180+
176181
# Arrays
177182
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
178183
# no structure worth re-imposing. Then any array is acceptable as a gradient.

test/differentials/abstract_zero.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,15 @@
6464
@test complex(z, z) === z
6565
@test complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
6666
@test complex(1.5, z) === Complex{Float64}(1.5, 0.0)
67+
@test Complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
68+
@test Complex(1.5, z) === Complex{Float64}(1.5, 0.0)
69+
@test ComplexF64(z, 2.0) === Complex{Float64}(0.0, 2.0)
70+
@test ComplexF64(1.5, z) === Complex{Float64}(1.5, 0.0)
6771

68-
@test convert(Int64, ZeroTangent()) == 0
69-
@test convert(Float64, ZeroTangent()) == 0.0
72+
@test convert(Bool, ZeroTangent()) === false
73+
@test convert(Int64, ZeroTangent()) === 0
74+
@test convert(Float32, ZeroTangent()) === 0.0f0
75+
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
7076
end
7177

7278
@testset "NoTangent" begin

test/projection.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
3333
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
3434
@test ProjectTo(big(1.0))(2) === 2
3535
@test ProjectTo(1.0)(2) === 2.0
36+
37+
# Tangents
38+
ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im
3639
end
3740

3841
@testset "Dual" begin # some weird Real subtype that we should basically leave alone

0 commit comments

Comments
 (0)