@@ -72,12 +72,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
72
72
return construct (T, map (_maybe_call, sub_projects, sub_dxs))
73
73
end
74
74
75
- function (project:: ProjectTo{T} )(dx:: Tangent ) where {T}
76
- sub_projects = backing (project)
77
- sub_dxs = backing (canonicalize (dx))
78
- return construct (T, map (_maybe_call, sub_projects, sub_dxs))
79
- end
80
-
81
75
# Used for encoding fields, leaves alone non-diff types:
82
76
_maybe_projector (x:: Union{AbstractArray,Number,Ref} ) = ProjectTo (x)
83
77
_maybe_projector (x) = x
@@ -135,6 +129,14 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
135
129
(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
136
130
(:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
137
131
132
+ # Tangent
133
+ # This may be produced from e.g. x=range(1,2,length=3). There need not be any
134
+ # AbstractArray representation of such a tangent, so we just pass it along,
135
+ # and trust that projection on fields before the constructor will act if necessary.
136
+ (:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
137
+
138
+ # (project::ProjectTo{<:AbstractArray})(dx::Tangent{<:AbstractArray}) = dx
139
+
138
140
# ####
139
141
# #### `Base`
140
142
# ####
@@ -241,18 +243,21 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241
243
return fill (project. element (dx))
242
244
end
243
245
244
- # Ref -- works like a zero-array, also allows restoration from a number:
245
- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
246
- (project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
247
- (project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
248
-
249
246
function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
250
247
size_x = map (length, axes_x)
251
248
return DimensionMismatch (
252
249
" variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx "
253
250
)
254
251
end
255
252
253
+ # Ref
254
+ # This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255
+ ProjectTo (x:: Ref ) = ProjectTo {Ref} (; reftype= typeof (x), x= ProjectTo (x[]))
256
+ (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.reftype} (; x= project. x (dx[]))
257
+ (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.reftype} (; x= project. x (dx. x))
258
+ # Since this works like a zero-array in broadcasting, it should also accept a number:
259
+ (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.reftype} (; x= project. x (dx))
260
+
256
261
# ####
257
262
# #### `LinearAlgebra`
258
263
# ####
0 commit comments