@@ -126,6 +126,11 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
126
126
(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
127
127
(:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
128
128
129
+ # Also, any explicit construction with fields, where all fields project to zero, itself
130
+ # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131
+ const _PZ = ProjectTo{<: AbstractZero }
132
+ ProjectTo {P} (:: NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}} ) where {P,T} = ProjectTo {NoTangent} ()
133
+
129
134
# Tangent
130
135
# This may be produced from e.g. x=range(1,2,length=3). There need not be any
131
136
# AbstractArray representation of such a tangent, so we just pass it along,
@@ -265,12 +270,10 @@ end
265
270
# #### `LinearAlgebra`
266
271
# ####
267
272
273
+ using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
274
+
268
275
# Row vectors
269
- function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
270
- sub = ProjectTo (parent (x))
271
- sub isa ProjectTo{<: AbstractZero } && return sub
272
- return ProjectTo {Adjoint} (; parent= sub)
273
- end
276
+ ProjectTo (x:: AdjointAbsVec ) = ProjectTo {Adjoint} (; parent= ProjectTo (parent (x)))
274
277
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
275
278
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
276
279
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -285,11 +288,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
285
288
return adjoint (project. parent (dy))
286
289
end
287
290
288
- function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
289
- sub = ProjectTo (parent (x))
290
- sub isa ProjectTo{<: AbstractZero } && return sub
291
- return ProjectTo {Transpose} (; parent= sub)
292
- end
291
+ ProjectTo (x:: LinearAlgebra.TransposeAbsVec ) = ProjectTo {Transpose} (; parent= ProjectTo (parent (x)))
293
292
function (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
294
293
return transpose (project. parent (transpose (dx)))
295
294
end
@@ -302,11 +301,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
302
301
end
303
302
304
303
# Diagonal
305
- function ProjectTo (x:: Diagonal )
306
- sub = ProjectTo (x. diag)
307
- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Diagonal(NoTangent()) worked
308
- return ProjectTo {Diagonal} (; diag= sub)
309
- end
304
+ ProjectTo (x:: Diagonal ) = ProjectTo {Diagonal} (; diag= ProjectTo (x. diag))
310
305
(project:: ProjectTo{Diagonal} )(dx:: AbstractMatrix ) = Diagonal (project. diag (diag (dx)))
311
306
(project:: ProjectTo{Diagonal} )(dx:: Diagonal ) = Diagonal (project. diag (dx. diag))
312
307
@@ -318,7 +313,8 @@ for (SymHerm, chk, fun) in (
318
313
@eval begin
319
314
function ProjectTo (x:: $SymHerm )
320
315
sub = ProjectTo (parent (x))
321
- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
316
+ # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
317
+ sub isa ProjectTo{<: AbstractZero } && return sub
322
318
return ProjectTo {$SymHerm} (; uplo= LinearAlgebra. sym_uplo (x. uplo), parent= sub)
323
319
end
324
320
function (project:: ProjectTo{$SymHerm} )(dx:: AbstractArray )
343
339
# Triangular
344
340
for UL in (:UpperTriangular , :LowerTriangular , :UnitUpperTriangular , :UnitLowerTriangular ) # UpperHessenberg
345
341
@eval begin
346
- function ProjectTo (x:: $UL )
347
- sub = ProjectTo (parent (x))
348
- # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
349
- sub isa ProjectTo{<: AbstractZero } && return sub
350
- return ProjectTo {$UL} (; parent= sub)
351
- end
342
+ ProjectTo (x:: $UL ) = ProjectTo {$UL} (; parent= ProjectTo (parent (x)))
352
343
(project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
353
344
function (project:: ProjectTo{$UL} )(dx:: Diagonal )
354
345
sub = project. parent
0 commit comments