@@ -272,18 +272,55 @@ end
272
272
# ####
273
273
274
274
# Ref
275
+ # Note that Ref is mutable. This causes Zygote to represent its structural tangent not as a NamedTuple,
276
+ # but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
277
+ # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
275
278
# Build the projector for a `Ref`: project the referenced value, then store that
# inner projector under the field name `x` of a `Tangent`-typed `ProjectTo`.
function ProjectTo(x::Ref)
    # Open question kept from the original: `x[]` is read unconditionally --
    # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
    inner = ProjectTo(x[])
    return ProjectTo{Tangent{typeof(x)}}(; x=inner)
end
282
# A `Tangent` gradient is unwrapped to the first entry of its backing, re-wrapped in
# a `Ref`, and sent back through `project` so the `Ref` method does the actual work.
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent)
    contents = first(backing(dx))
    return project(Ref(contents))
end
283
# Project the value held by `dx` with the stored inner projector `project.x`, and
# wrap the result (as field `x`) in the type reported by `project_type(project)`.
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
    projected = project.x(dx[])
    wrapper = project_type(project)
    return wrapper(; x=projected)
end
287
+ # Since this works like a zero-array in broadcasting, it should also accept a number:
288
# A bare number is accepted by boxing it in a `Ref` and delegating to the `Ref` method.
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number)
    boxed = Ref(dx)
    return project(boxed)
end
289
+
290
+ # Tuple
291
# Build the projector for a tuple from one projector per element. When every
# element projector is a `ProjectTo{<:AbstractZero}`, the whole tuple collapses to
# `ProjectTo{NoTangent}()`; otherwise the per-element projectors are stored under
# the `elements` field.
function ProjectTo(x::Tuple)
    subprojectors = map(ProjectTo, x)
    all_zero = subprojectors isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
    all_zero && return ProjectTo{NoTangent}()
    return ProjectTo{Tangent{typeof(x)}}(; elements=subprojectors)
end
283
- (project:: ProjectTo{Ref} )(dx:: Tangent{<:Ref} ) = Tangent {project.type} (; x= project. x (dx. x))
284
- (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.type} (; x= project. x (dx[]))
285
- # Since this works like a zero-array in broadcasting, it should also accept a number:
286
- (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.type} (; x= project. x (dx))
299
+ # This method means that projection is re-applied to the contents of a Tangent.
300
+ # We're not entirely sure whether this is ever necessary; but it should be safe,
301
+ # and should often compile away:
302
# Re-apply the projection to the contents of a `Tangent`: take its backing and
# dispatch again on whatever type that backing has.
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent)
    contents = backing(dx)
    return project(contents)
end
303
# Project a tuple gradient element by element.
#
# Throws `DimensionMismatch` when `dx` does not have exactly one entry per stored
# element projector. Otherwise each projector is applied to its matching entry and
# the results are splatted into the type given by `project_type(project)`.
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
    len = length(project.elements)
    # Checked up front: `map` would also fail on a length mismatch, but with a much
    # less helpful error.
    length(dx) == len || throw(
        DimensionMismatch(
            " tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx)) ",
        ),
    )
    projected = map((p, d) -> p(d), project.elements, dx)
    return project_type(project)(projected...)
end
313
# Project an array gradient onto a tuple.
#
# The array's first dimension must have one entry per stored element projector and
# every further dimension must have size 1; otherwise the error built by
# `_projection_mismatch` is thrown. The array is then reshaped to the axes of the
# projector tuple and projected entry by entry.
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
    subprojectors = project.elements
    n = length(subprojectors)
    # Expected extent: `n` along dimension 1, and 1 along any other dimension.
    mismatched = any(1:ndims(dx)) do d
        size(dx, d) != (d == 1 ? n : 1)
    end
    mismatched && throw(_projection_mismatch(axes(subprojectors), size(dx)))
    flat = reshape(dx, axes(subprojectors))  # allows for dx::OffsetArray
    projected = ntuple(i -> subprojectors[i](flat[i]), n)
    return project_type(project)(projected...)
end
323
+
287
324
288
325
# ####
289
326
# #### `LinearAlgebra`
0 commit comments