@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
59
59
fields_nt:: NamedTuple = backing (x)
60
60
fields_proj = map (_maybe_projector, fields_nt)
61
61
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62
- # `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63
- # but if it doesn't `construct` will give a good error message.
62
+ # `Foo{Diagaonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
64
63
wrapT = T. name. wrapper
65
- # Official API for this? https://github.com/JuliaLang/julia/issues/35543
66
64
return ProjectTo {wrapT} (; fields_proj... , kw... )
67
65
end
68
66
@@ -252,11 +250,18 @@ end
252
250
253
251
# Ref
254
252
# This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255
- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; reftype= typeof (x), x= ProjectTo (x[]))
256
- (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.reftype} (; x= project. x (dx[]))
257
- (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.reftype} (; x= project. x (dx. x))
253
+ function ProjectTo (x:: Ref )
254
+ sub = ProjectTo (x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
255
+ if sub isa ProjectTo{<: AbstractZero }
256
+ return ProjectTo {NoTangent} ()
257
+ else
258
+ return ProjectTo {Ref} (; type= typeof (x), x= sub)
259
+ end
260
+ end
261
+ (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.type} (; x= project. x (dx[]))
262
+ (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.type} (; x= project. x (dx. x))
258
263
# Since this works like a zero-array in broadcasting, it should also accept a number:
259
- (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.reftype } (; x= project. x (dx))
264
+ (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.type } (; x= project. x (dx))
260
265
261
266
# ####
262
267
# #### `LinearAlgebra`
0 commit comments