@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
59
59
fields_nt:: NamedTuple = backing (x)
60
60
fields_proj = map (_maybe_projector, fields_nt)
61
61
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62
- # `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63
- # but if it doesn't `construct` will give a good error message.
62
+ # `Foo{Diagonal{E}}` etc. Is there an official API for this? See https://github.com/JuliaLang/julia/issues/35543
64
63
wrapT = T. name. wrapper
65
- # Official API for this? https://github.com/JuliaLang/julia/issues/35543
66
64
return ProjectTo {wrapT} (; fields_proj... , kw... )
67
65
end
68
66
@@ -72,12 +70,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
72
70
return construct (T, map (_maybe_call, sub_projects, sub_dxs))
73
71
end
74
72
75
- function (project:: ProjectTo{T} )(dx:: Tangent ) where {T}
76
- sub_projects = backing (project)
77
- sub_dxs = backing (canonicalize (dx))
78
- return construct (T, map (_maybe_call, sub_projects, sub_dxs))
79
- end
80
-
81
73
# Used for encoding fields, leaves alone non-diff types:
82
74
# Arrays, numbers and `Ref`s are wrapped in a projector built from the value itself.
function _maybe_projector(x::Union{AbstractArray,Number,Ref})
    return ProjectTo(x)
end

# Every other value is handed back exactly as it was given.
function _maybe_projector(x)
    return x
end
@@ -123,7 +115,6 @@ ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),
123
115
ProjectTo (:: Any ) # just to attach docstring
124
116
125
117
# Generic
126
- (:: ProjectTo{T} )(dx:: T ) where {T} = dx # not always correct but we have special cases for when it isn't
127
118
# A zero tangent is returned untouched, whatever type the projector targets.
function (::ProjectTo{T})(dx::AbstractZero) where {T}
    return dx
end

# A `NotImplemented` marker is likewise returned untouched.
function (::ProjectTo{T})(dx::NotImplemented) where {T}
    return dx
end
129
120
@@ -133,7 +124,17 @@ ProjectTo(::Any) # just to attach docstring
133
124
# Zero
134
125
ProjectTo (:: AbstractZero ) = ProjectTo {NoTangent} () # Any x::Zero in forward pass makes this one projector,
135
126
(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
136
- (:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
127
+ (:: ProjectTo{NoTangent} )(dx:: AbstractZero ) = dx # and this one solves an ambiguity.
128
+
129
+ # Also, any explicit construction with fields, where all fields project to zero, itself
130
+ # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131
# `_PZ` matches every projector whose target is some zero-tangent type.
const _PZ = ProjectTo{<:AbstractZero}

# Keyword-style construction of any `ProjectTo{P}` from a NamedTuple with one or more
# fields, all of which are `_PZ`, collapses to the `NoTangent` projector.
# `_PZ` is already an abstract (UnionAll) type and tuples are covariant, so plain
# `Vararg{_PZ}` accepts any mix of zero projectors. The previous `Vararg{<:_PZ}` spelling
# wrapped `Vararg` directly in a `UnionAll`, a deprecated form that newer Julia versions
# warn about.
function ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{_PZ}}}) where {P,T}
    return ProjectTo{NoTangent}()
end
133
+
134
# Tangent
# It is not yet settled when a `Tangent` should be converted to a "natural" representation
# such as dx::AbstractArray (when both are possible), or the reverse. Until it is, a
# `Tangent` whose primal type is a subtype of the projector's target is passed through:
function (::ProjectTo{T})(dx::Tangent{<:T}) where {T}
    return dx
end
137
138
138
139
# ####
139
140
# #### `Base`
@@ -165,27 +166,29 @@ end
165
166
(:: ProjectTo{T} )(dx:: Integer ) where {T<: Complex{<:AbstractFloat} } = convert (T, dx)
166
167
167
168
# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
168
- # We assume (lacking evidence to the contrary) that it is the right subspace of numebers
169
- # The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
170
- # Number type that might not be a subtype of the `project_type`.
169
+ # We assume (lacking evidence to the contrary) that it is the right subspace of numbers.
171
170
# Fallback for numeric projectors: the gradient is returned unchanged.
function (::ProjectTo{<:Number})(dx::Number)
    return dx
end

# A real-valued target takes the real part of a complex gradient, then projects that.
function (project::ProjectTo{<:Real})(dx::Complex)
    re_part = real(dx)
    return project(re_part)
end

# A complex-valued target converts a real gradient with `complex`, then projects that.
function (project::ProjectTo{<:Complex})(dx::Real)
    as_complex = complex(dx)
    return project(as_complex)
end
175
174
175
# Tangents: rebuilding an actual number is preferred, but it is only safe to attempt when
# the number's constructor copes with the fields it is given, including a mix of zeros
# and reals. For `Complex` the number is rebuilt from `re` and `im` and projected again:
function (project::ProjectTo{<:Complex})(dx::Tangent{<:Complex})
    rebuilt = Complex(dx.re, dx.im)
    return project(rebuilt)
end

# Any other `Tangent` of a number is simply let through.
function (::ProjectTo{<:Number})(dx::Tangent{<:Number})
    return dx
end
179
+
176
180
# Arrays
177
181
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
178
182
# no structure worth re-imposing. Then any array is acceptable as a gradient.
179
183
180
184
# For arrays of numbers, just store one projector:
181
185
function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
182
- element = T <: Irrational ? ProjectTo {Real} () : ProjectTo (zero (T))
183
- if element isa ProjectTo{<: AbstractZero }
184
- return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
185
- else
186
- return ProjectTo {AbstractArray} (; element= element, axes= axes (x))
187
- end
186
+ return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
188
187
end
188
# Arrays of `Bool` get the `NoTangent` projector directly.
function ProjectTo(x::AbstractArray{Bool})
    return ProjectTo{NoTangent}()
end
189
+
190
# Element projector for a numeric element type `T`: built from the value `zero(T)`.
function _eltype_projectto(::Type{T}) where {T<:Number}
    return ProjectTo(zero(T))
end

# `Irrational` element types do not go through `zero(T)`; they project to `Real`.
function _eltype_projectto(::Type{<:Irrational})
    return ProjectTo{Real}()
end
189
192
190
193
# In other cases, store a projector per element:
191
194
function ProjectTo (xs:: AbstractArray )
@@ -241,27 +244,39 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241
244
return fill (project. element (dx))
242
245
end
243
246
244
- # Ref -- works like a zero-array, also allows restoration from a number:
245
- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
246
- (project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
247
- (project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
248
-
249
247
# Build (but do not throw) the `DimensionMismatch` reported when a gradient's size does
# not match the primal's.
#
# - `axes_x`: tuple of the primal's axes; each entry's `length` gives one entry of size(x).
# - `size_dx`: tuple holding the size of the offending gradient.
#
# Returns the exception object; the caller decides whether to `throw` it.
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
    size_x = map(length, axes_x)
    # The message carries no leading or trailing padding, so it prints cleanly.
    return DimensionMismatch(
        "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx"
    )
end
255
253
254
+ # ####
255
+ # #### `Base`, part II: return of the Tangent
256
+ # ####
257
+
258
+ # Ref
259
# Projector for a `Ref`, built from the projector of its contents. When the contents
# project to a zero tangent the whole `Ref` does too; otherwise the concrete `Ref` type
# is recorded next to the contents' projector.
function ProjectTo(x::Ref)
    # Open question: should we worry about isdefined(Ref{Vector{Int}}(), :x)?
    inner = ProjectTo(x[])
    inner isa ProjectTo{<:AbstractZero} && return ProjectTo{NoTangent}()
    return ProjectTo{Ref}(; type=typeof(x), x=inner)
end
267
# A structural tangent: project its `x` field and re-wrap for the recorded `Ref` type.
function (project::ProjectTo{Ref})(dx::Tangent{<:Ref})
    inner = project.x(dx.x)
    return Tangent{project.type}(; x=inner)
end

# A `Ref` gradient: unwrap it with `dx[]`, project the contents, return a `Tangent`.
function (project::ProjectTo{Ref})(dx::Ref)
    inner = project.x(dx[])
    return Tangent{project.type}(; x=inner)
end

# Since this works like a zero-array in broadcasting, it should also accept a number:
function (project::ProjectTo{Ref})(dx::Number)
    inner = project.x(dx)
    return Tangent{project.type}(; x=inner)
end
271
+
256
272
# ####
257
273
# #### `LinearAlgebra`
258
274
# ####
259
275
276
+ using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
277
+
260
278
# Row vectors
261
- function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
262
- sub = ProjectTo (parent (x))
263
- return ProjectTo {Adjoint} (; parent= sub)
264
- end
279
+ ProjectTo (x:: AdjointAbsVec ) = ProjectTo {Adjoint} (; parent= ProjectTo (parent (x)))
265
280
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
266
281
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
267
282
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -276,10 +291,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
276
291
return adjoint (project. parent (dy))
277
292
end
278
293
279
- function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
280
- sub = ProjectTo (parent (x))
281
- return ProjectTo {Transpose} (; parent= sub)
282
- end
294
+ ProjectTo (x:: LinearAlgebra.TransposeAbsVec ) = ProjectTo {Transpose} (; parent= ProjectTo (parent (x)))
283
295
# Adjoint- or transpose-wrapped vector gradient: transpose it, run the stored parent
# projector on the result, then transpose back.
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
    flipped = transpose(dx)
    projected = project.parent(flipped)
    return transpose(projected)
end
@@ -292,11 +304,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
292
304
end
293
305
294
306
# Diagonal
295
- function ProjectTo (x:: Diagonal )
296
- sub = ProjectTo (x. diag)
297
- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Diagonal(NoTangent()) worked
298
- return ProjectTo {Diagonal} (; diag= sub)
299
- end
307
+ ProjectTo (x:: Diagonal ) = ProjectTo {Diagonal} (; diag= ProjectTo (x. diag))
300
308
# General matrix gradient: extract its diagonal, project it, re-wrap as `Diagonal`.
function (project::ProjectTo{Diagonal})(dx::AbstractMatrix)
    d = diag(dx)
    return Diagonal(project.diag(d))
end

# `Diagonal` gradient: project the stored `diag` field directly, re-wrap as `Diagonal`.
function (project::ProjectTo{Diagonal})(dx::Diagonal)
    d = dx.diag
    return Diagonal(project.diag(d))
end
302
310
@@ -308,7 +316,8 @@ for (SymHerm, chk, fun) in (
308
316
@eval begin
309
317
function ProjectTo (x:: $SymHerm )
310
318
sub = ProjectTo (parent (x))
311
- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
319
+ # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
320
+ sub isa ProjectTo{<: AbstractZero } && return sub
312
321
return ProjectTo {$SymHerm} (; uplo= LinearAlgebra. sym_uplo (x. uplo), parent= sub)
313
322
end
314
323
function (project:: ProjectTo{$SymHerm} )(dx:: AbstractArray )
333
342
# Triangular
334
343
for UL in (:UpperTriangular , :LowerTriangular , :UnitUpperTriangular , :UnitLowerTriangular ) # UpperHessenberg
335
344
@eval begin
336
- function ProjectTo (x:: $UL )
337
- sub = ProjectTo (parent (x))
338
- # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
339
- sub isa ProjectTo{<: AbstractZero } && return sub
340
- return ProjectTo {$UL} (; parent= sub)
341
- end
345
+ ProjectTo (x:: $UL ) = ProjectTo {$UL} (; parent= ProjectTo (parent (x)))
342
346
(project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
343
347
function (project:: ProjectTo{$UL} )(dx:: Diagonal )
344
348
sub = project. parent
0 commit comments