Skip to content

Commit 8f6041c

Browse files
committed
use Tangent for Ref's gradient, and pass them along
1 parent 1014715 commit 8f6041c

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

src/projection.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
7272
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
7373
end
7474

75-
function (project::ProjectTo{T})(dx::Tangent) where {T}
76-
sub_projects = backing(project)
77-
sub_dxs = backing(canonicalize(dx))
78-
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
79-
end
80-
8175
# Used for encoding fields, leaves alone non-diff types:
8276
_maybe_projector(x::Union{AbstractArray,Number,Ref}) = ProjectTo(x)
8377
_maybe_projector(x) = x
@@ -135,6 +129,14 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
135129
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
136130
(::ProjectTo{NoTangent})(::NoTangent) = NoTangent() # and this one solves an ambiguity.
137131

132+
# Tangent
133+
# This may be produced from e.g. x=range(1,2,length=3). There need not be any
134+
# AbstractArray representation of such a tangent, so we just pass it along,
135+
# and trust that projection on fields before the constructor will act if necessary.
136+
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
137+
138+
# (project::ProjectTo{<:AbstractArray})(dx::Tangent{<:AbstractArray}) = dx
139+
138140
#####
139141
##### `Base`
140142
#####
@@ -241,18 +243,21 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241243
return fill(project.element(dx))
242244
end
243245

244-
# Ref -- works like a zero-array, also allows restoration from a number:
245-
ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[]))
246-
(project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))
247-
(project::ProjectTo{Ref})(dx::Number) = Ref(project.x(dx))
248-
249246
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
250247
size_x = map(length, axes_x)
251248
return DimensionMismatch(
252249
"variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx"
253250
)
254251
end
255252

253+
# Ref
254+
# This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255+
ProjectTo(x::Ref) = ProjectTo{Ref}(; reftype=typeof(x), x=ProjectTo(x[]))
256+
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.reftype}(; x=project.x(dx[]))
257+
(project::ProjectTo{Ref})(dx::Tangent) = Tangent{project.reftype}(; x=project.x(dx.x))
258+
# Since this works like a zero-array in broadcasting, it should also accept a number:
259+
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.reftype}(; x=project.x(dx))
260+
256261
#####
257262
##### `LinearAlgebra`
258263
#####

0 commit comments

Comments
 (0)