Skip to content

Commit 9ad5f44

Browse files
committed
tweak
1 parent 7d8fcdb commit 9ad5f44

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/projection.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
5959
fields_nt::NamedTuple = backing(x)
6060
fields_proj = map(_maybe_projector, fields_nt)
6161
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62-
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63-
# but if it doesn't `construct` will give a good error message.
62+
# `Foo{Diagaonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
6463
wrapT = T.name.wrapper
65-
# Official API for this? https://github.com/JuliaLang/julia/issues/35543
6664
return ProjectTo{wrapT}(; fields_proj..., kw...)
6765
end
6866

@@ -252,11 +250,18 @@ end
252250

253251
# Ref
254252
# This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255-
ProjectTo(x::Ref) = ProjectTo{Ref}(; reftype=typeof(x), x=ProjectTo(x[]))
256-
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.reftype}(; x=project.x(dx[]))
257-
(project::ProjectTo{Ref})(dx::Tangent) = Tangent{project.reftype}(; x=project.x(dx.x))
253+
function ProjectTo(x::Ref)
254+
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
255+
if sub isa ProjectTo{<:AbstractZero}
256+
return ProjectTo{NoTangent}()
257+
else
258+
return ProjectTo{Ref}(; type=typeof(x), x=sub)
259+
end
260+
end
261+
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
262+
(project::ProjectTo{Ref})(dx::Tangent) = Tangent{project.type}(; x=project.x(dx.x))
258263
# Since this works like a zero-array in broadcasting, it should also accept a number:
259-
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.reftype}(; x=project.x(dx))
264+
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))
260265

261266
#####
262267
##### `LinearAlgebra`

test/projection.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
123123
prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
124124
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
125125
@test_throws DimensionMismatch prefvec(Ref{Any}(1:5))
126+
127+
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
128+
@test_broken ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
126129
end
127130

128131
#####

0 commit comments

Comments
 (0)