Skip to content

Commit e99e747

Browse files
committed
never construct a ProjectTo all with no non-(Zero-projector) fields
1 parent 7e5ae8e commit e99e747

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

src/projection.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
126126
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
127127
(::ProjectTo{NoTangent})(::NoTangent) = NoTangent() # and this one solves an ambiguity.
128128

129+
# Also, any explicit construction with fields, where all fields project to zero, itself
130+
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131+
const _PZ = ProjectTo{<:AbstractZero}
132+
ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}()
133+
129134
# Tangent
130135
# This may be produced from e.g. x=range(1,2,length=3). There need not be any
131136
# AbstractArray representation of such a tangent, so we just pass it along,
@@ -265,12 +270,10 @@ end
265270
##### `LinearAlgebra`
266271
#####
267272

273+
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
274+
268275
# Row vectors
269-
function ProjectTo(x::LinearAlgebra.AdjointAbsVec)
270-
sub = ProjectTo(parent(x))
271-
sub isa ProjectTo{<:AbstractZero} && return sub
272-
return ProjectTo{Adjoint}(; parent=sub)
273-
end
276+
ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x)))
274277
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
275278
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
276279
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -285,11 +288,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
285288
return adjoint(project.parent(dy))
286289
end
287290

288-
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
289-
sub = ProjectTo(parent(x))
290-
sub isa ProjectTo{<:AbstractZero} && return sub
291-
return ProjectTo{Transpose}(; parent=sub)
292-
end
291+
ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
293292
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
294293
return transpose(project.parent(transpose(dx)))
295294
end
@@ -302,11 +301,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
302301
end
303302

304303
# Diagonal
305-
function ProjectTo(x::Diagonal)
306-
sub = ProjectTo(x.diag)
307-
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Diagonal(NoTangent()) worked
308-
return ProjectTo{Diagonal}(; diag=sub)
309-
end
304+
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
310305
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
311306
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
312307

@@ -318,7 +313,8 @@ for (SymHerm, chk, fun) in (
318313
@eval begin
319314
function ProjectTo(x::$SymHerm)
320315
sub = ProjectTo(parent(x))
321-
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
316+
# Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
317+
sub isa ProjectTo{<:AbstractZero} && return sub
322318
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub)
323319
end
324320
function (project::ProjectTo{$SymHerm})(dx::AbstractArray)
@@ -343,12 +339,7 @@ end
343339
# Triangular
344340
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
345341
@eval begin
346-
function ProjectTo(x::$UL)
347-
sub = ProjectTo(parent(x))
348-
# TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
349-
sub isa ProjectTo{<:AbstractZero} && return sub
350-
return ProjectTo{$UL}(; parent=sub)
351-
end
342+
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x)))
352343
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
353344
function (project::ProjectTo{$UL})(dx::Diagonal)
354345
sub = project.parent

0 commit comments

Comments
 (0)