Skip to content

Commit 8218c2c

Browse files
authored
Merge pull request #427 from mcabbott/projecttangents
Fix #426 -- gradient of Ref is a Tangent
2 parents 2208660 + eb872cd commit 8218c2c

File tree

5 files changed

+104
-56
lines changed

5 files changed

+104
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.3.0"
3+
version = "1.3.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/differentials/abstract_zero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Base.transpose(z::AbstractZero) = z
2525
Base.:/(z::AbstractZero, ::Any) = z
2626

2727
Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T)
28+
# Construct a number from one or more zero tangents, e.g. `Float64(ZeroTangent())`.
# At least one argument is required: a bare `xs::AbstractZero...` would also match the
# zero-argument call `T()` for every `T <: Number`, which is type piracy on Base types.
(::Type{T})(x::AbstractZero, xs::AbstractZero...) where {T<:Number} = zero(T)
29+
30+
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
31+
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)
2832

2933
Base.getindex(z::AbstractZero, k) = z
3034

src/projection.jl

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
5959
fields_nt::NamedTuple = backing(x)
6060
fields_proj = map(_maybe_projector, fields_nt)
6161
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62-
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63-
# but if it doesn't `construct` will give a good error message.
62+
# `Foo{Diagonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
6463
wrapT = T.name.wrapper
65-
# Official API for this? https://github.com/JuliaLang/julia/issues/35543
6664
return ProjectTo{wrapT}(; fields_proj..., kw...)
6765
end
6866

@@ -72,12 +70,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
7270
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
7371
end
7472

75-
function (project::ProjectTo{T})(dx::Tangent) where {T}
76-
sub_projects = backing(project)
77-
sub_dxs = backing(canonicalize(dx))
78-
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
79-
end
80-
8173
# Used for encoding fields, leaves alone non-diff types:
8274
_maybe_projector(x::Union{AbstractArray,Number,Ref}) = ProjectTo(x)
8375
_maybe_projector(x) = x
@@ -123,7 +115,6 @@ ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),
123115
ProjectTo(::Any) # just to attach docstring
124116

125117
# Generic
126-
(::ProjectTo{T})(dx::T) where {T} = dx # not always correct but we have special cases for when it isn't
127118
(::ProjectTo{T})(dx::AbstractZero) where {T} = dx
128119
(::ProjectTo{T})(dx::NotImplemented) where {T} = dx
129120

@@ -133,7 +124,17 @@ ProjectTo(::Any) # just to attach docstring
133124
# Zero
134125
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,
135126
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
136-
(::ProjectTo{NoTangent})(::NoTangent) = NoTangent() # and this one solves an ambiguity.
127+
(::ProjectTo{NoTangent})(dx::AbstractZero) = dx # and this one solves an ambiguity.
128+
129+
# Also, any explicit construction with fields, where all fields project to zero, itself
130+
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131+
const _PZ = ProjectTo{<:AbstractZero}
132+
# `Vararg{_PZ}` rather than `Vararg{<:_PZ}`: tuple types are covariant, so inside
# `<:Tuple{...}` both match the same arguments, while wrapping `Vararg` in a UnionAll
# is deprecated on Julia 1.7 and later.
ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{_PZ}}}) where {P,T} = ProjectTo{NoTangent}()
133+
134+
# Tangent
135+
# We haven't entirely figured out when to convert Tangents to "natural" representations such as
136+
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
137+
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
137138

138139
#####
139140
##### `Base`
@@ -165,27 +166,29 @@ end
165166
(::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
166167

167168
# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
168-
# We assume (lacking evidence to the contrary) that it is the right subspace of numebers
169-
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
170-
# Number type that might not be a subtype of the `project_type`.
169+
# We assume (lacking evidence to the contrary) that it is the right subspace of numbers.
171170
(::ProjectTo{<:Number})(dx::Number) = dx
172171

173172
(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
174173
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))
175174

175+
# Tangents: we prefer to reconstruct numbers, but only safe to try when their constructor
176+
# understands, including a mix of Zeros & reals. Other cases, we just let through:
177+
(project::ProjectTo{<:Complex})(dx::Tangent{<:Complex}) = project(Complex(dx.re, dx.im))
178+
(::ProjectTo{<:Number})(dx::Tangent{<:Number}) = dx
179+
176180
# Arrays
177181
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
178182
# no structure worth re-imposing. Then any array is acceptable as a gradient.
179183

180184
# For arrays of numbers, just store one projector:
181185
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
182-
element = T <: Irrational ? ProjectTo{Real}() : ProjectTo(zero(T))
183-
if element isa ProjectTo{<:AbstractZero}
184-
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
185-
else
186-
return ProjectTo{AbstractArray}(; element=element, axes=axes(x))
187-
end
186+
return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x))
188187
end
188+
# Arrays of Bool get the trivial zero projector directly, rather than an array projector:
ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}()
189+
190+
# Element projector stored by array projectors for numeric eltypes: built from `zero(T)`.
_eltype_projectto(::Type{T}) where {T<:Number} = ProjectTo(zero(T))
191+
_eltype_projectto(::Type{<:Irrational}) = ProjectTo{Real}()
189192

190193
# In other cases, store a projector per element:
191194
function ProjectTo(xs::AbstractArray)
@@ -241,27 +244,39 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241244
return fill(project.element(dx))
242245
end
243246

244-
# Ref -- works like a zero-array, also allows restoration from a number:
245-
ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[]))
246-
(project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))
247-
(project::ProjectTo{Ref})(dx::Number) = Ref(project.x(dx))
248-
249247
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
250248
size_x = map(length, axes_x)
251249
return DimensionMismatch(
252250
"variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx"
253251
)
254252
end
255253

254+
#####
255+
##### `Base`, part II: return of the Tangent
256+
#####
257+
258+
# Ref
259+
function ProjectTo(x::Ref)
260+
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
261+
if sub isa ProjectTo{<:AbstractZero}
262+
return ProjectTo{NoTangent}()
263+
else
264+
return ProjectTo{Ref}(; type=typeof(x), x=sub)
265+
end
266+
end
267+
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x))
268+
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
269+
# Since this works like a zero-array in broadcasting, it should also accept a number:
270+
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))
271+
256272
#####
257273
##### `LinearAlgebra`
258274
#####
259275

276+
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
277+
260278
# Row vectors
261-
function ProjectTo(x::LinearAlgebra.AdjointAbsVec)
262-
sub = ProjectTo(parent(x))
263-
return ProjectTo{Adjoint}(; parent=sub)
264-
end
279+
ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x)))
265280
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
266281
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
267282
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -276,10 +291,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
276291
return adjoint(project.parent(dy))
277292
end
278293

279-
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
280-
sub = ProjectTo(parent(x))
281-
return ProjectTo{Transpose}(; parent=sub)
282-
end
294+
# Mirrors the `AdjointAbsVec` method: `TransposeAbsVec` is imported by name in this file,
# so use the unqualified name here as well.
ProjectTo(x::TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
283295
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
284296
return transpose(project.parent(transpose(dx)))
285297
end
@@ -292,11 +304,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
292304
end
293305

294306
# Diagonal
295-
function ProjectTo(x::Diagonal)
296-
sub = ProjectTo(x.diag)
297-
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Diagonal(NoTangent()) worked
298-
return ProjectTo{Diagonal}(; diag=sub)
299-
end
307+
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
300308
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
301309
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
302310

@@ -308,7 +316,8 @@ for (SymHerm, chk, fun) in (
308316
@eval begin
309317
function ProjectTo(x::$SymHerm)
310318
sub = ProjectTo(parent(x))
311-
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
319+
# Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
320+
sub isa ProjectTo{<:AbstractZero} && return sub
312321
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub)
313322
end
314323
function (project::ProjectTo{$SymHerm})(dx::AbstractArray)
@@ -333,12 +342,7 @@ end
333342
# Triangular
334343
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
335344
@eval begin
336-
function ProjectTo(x::$UL)
337-
sub = ProjectTo(parent(x))
338-
# TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
339-
sub isa ProjectTo{<:AbstractZero} && return sub
340-
return ProjectTo{$UL}(; parent=sub)
341-
end
345+
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x)))
342346
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
343347
function (project::ProjectTo{$UL})(dx::Diagonal)
344348
sub = project.parent

test/differentials/abstract_zero.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,15 @@
6464
@test complex(z, z) === z
6565
@test complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
6666
@test complex(1.5, z) === Complex{Float64}(1.5, 0.0)
67+
@test Complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
68+
@test Complex(1.5, z) === Complex{Float64}(1.5, 0.0)
69+
@test ComplexF64(z, 2.0) === Complex{Float64}(0.0, 2.0)
70+
@test ComplexF64(1.5, z) === Complex{Float64}(1.5, 0.0)
6771

68-
@test convert(Int64, ZeroTangent()) == 0
69-
@test convert(Float64, ZeroTangent()) == 0.0
72+
@test convert(Bool, ZeroTangent()) === false
73+
@test convert(Int64, ZeroTangent()) === Int64(0)
74+
@test convert(Float32, ZeroTangent()) === 0.0f0
75+
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
7076
end
7177

7278
@testset "NoTangent" begin

test/projection.jl

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,24 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
3333
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
3434
@test ProjectTo(big(1.0))(2) === 2
3535
@test ProjectTo(1.0)(2) === 2.0
36+
37+
# Tangents
38+
# Without `@test` this comparison is evaluated and discarded, so it could never fail.
@test ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im
3639
end
3740

3841
@testset "Dual" begin # some weird Real subtype that we should basically leave alone
3942
@test ProjectTo(1.0)(Dual(1.0, 2.0)) isa Dual
4043
@test ProjectTo(1.0)(Dual(1, 2)) isa Dual
44+
45+
# real & complex
4146
@test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual}
4247
@test ProjectTo(1.0 + 1im)(
4348
Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))
4449
) isa Complex{<:Dual}
4550
@test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual
51+
52+
# Tangent
53+
@test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent
4654
end
4755

4856
@testset "Base: arrays of numbers" begin
@@ -100,7 +108,7 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
100108
@test ProjectTo(Bool[]) isa ProjectTo{NoTangent}
101109
end
102110

103-
@testset "Base: zero-arrays & Ref" begin
111+
@testset "Base: zero-arrays" begin
104112
pzed = ProjectTo(fill(1.0))
105113
@test pzed(fill(3.14)) == fill(3.14) # easy
106114
@test pzed(fill(3)) == fill(3.0) # broadcast type change must not produce number
@@ -110,17 +118,26 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
110118
@test_throws DimensionMismatch ProjectTo([1])(3.14 + im) # other array projectors don't accept numbers
111119
@test_throws DimensionMismatch ProjectTo(hcat([1, 2]))(3.14)
112120
@test pzed isa ProjectTo{AbstractArray}
121+
end
113122

123+
@testset "Base: Ref" begin
114124
pref = ProjectTo(Ref(2.0))
115-
@test pref(Ref(3 + im))[] === 3.0
116-
@test pref(4)[] === 4.0 # also re-wraps scalars
117-
@test pref(Ref{Any}(5.0)) isa Base.RefValue{Float64}
125+
@test pref(Ref(3 + im)).x === 3.0
126+
@test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0
127+
@test pref(4).x === 4.0 # also re-wraps scalars
128+
@test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue}
129+
118130
pref2 = ProjectTo(Ref{Any}(6 + 7im))
119-
@test pref2(Ref(8))[] === 8.0 + 0.0im
131+
@test pref2(Ref(8)).x === 8.0 + 0.0im
132+
@test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im
120133

121134
prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
122-
@test prefvec(Ref(1:3)) isa Base.RefValue{Vector{ComplexF64}}
123-
@test_throws DimensionMismatch prefvec(Ref{Any}(1:5))
135+
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
136+
@test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64}
137+
@test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5))
138+
139+
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
140+
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
124141
end
125142

126143
#####
@@ -167,6 +184,9 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
167184

168185
# issue #410
169186
@test padj([NoTangent() NoTangent() NoTangent()]) === NoTangent()
187+
188+
@test ProjectTo(adj([true, false]))([1 2]) isa AbstractZero
189+
@test ProjectTo(adj([[true], [false]])) isa ProjectTo{<:AbstractZero}
170190
end
171191

172192
@testset "LinearAlgebra: dense structured matrices" begin
@@ -284,11 +304,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
284304
@testset "AbstractZero" begin
285305
pz = ProjectTo(ZeroTangent())
286306
# Previously a bare expression whose result was dropped; assert it like its neighbours.
@test pz(0) === NoTangent()
287-
@test_broken pz(ZeroTangent()) === ZeroTangent() # not sure how NB this is to preserve
307+
@test pz(ZeroTangent()) === ZeroTangent() # not sure how NB this is to preserve
288308
@test pz(NoTangent()) === NoTangent()
289309

290310
pb = ProjectTo(true) # Bool is categorical
291311
@test pb(2) === NoTangent()
312+
@test pb(ZeroTangent()) isa AbstractZero # was a method ambiguity!
292313

293314
# all projectors preserve Zero, and specific type, via one fallback method:
294315
@test ProjectTo(pi)(ZeroTangent()) === ZeroTangent()
@@ -305,6 +326,19 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
305326
@test unthunk(pth) === 6.0 + 0.0im
306327
end
307328

329+
@testset "Tangent" begin
330+
x = 1:3.0
331+
dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent());
332+
@test ProjectTo(x)(dx) isa Tangent
333+
@test ProjectTo(x)(dx).step === 0.1
334+
@test ProjectTo(x)(dx).offset isa AbstractZero
335+
336+
pref = ProjectTo(Ref(2.0))
337+
dy = Tangent{typeof(Ref(2.0))}(x = 3+4im)
338+
@test pref(dy) isa Tangent{<:Base.RefValue}
339+
@test pref(dy).x === 3.0
340+
end
341+
308342
@testset "display" begin
309343
@test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()"
310344
@test occursin("ProjectTo{AbstractArray}(element", repr(ProjectTo([1, 2, 3])))

0 commit comments

Comments
 (0)