Skip to content

Commit 1ae20c3

Browse files
committed
accept Ref again, and fix up tests
1 parent be2bcb9 commit 1ae20c3

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/projection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ function ProjectTo(x::Ref)
266266
end
267267
end
268268
(project::ProjectTo{Ref})(dx::Tangent) = Tangent{project.type}(; x=project.x(dx.x))
269+
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
269270
# Since this works like a zero-array in broadcasting, it should also accept a number:
270271
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))
271272

test/projection.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,19 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
117117

118118
@testset "Base: Ref" begin
119119
pref = ProjectTo(Ref(2.0))
120-
@test_skip pref(Ref(3 + im)).x === 3.0
120+
@test pref(Ref(3 + im)).x === 3.0
121+
@test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0
121122
@test pref(4).x === 4.0 # also re-wraps scalars
122-
@test_skip pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue}
123+
@test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue}
124+
123125
pref2 = ProjectTo(Ref{Any}(6 + 7im))
124-
@test_skip pref2(Ref(8)).x === 8.0 + 0.0im
126+
@test pref2(Ref(8)).x === 8.0 + 0.0im
127+
@test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im
125128

126129
prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
127-
@test_skip prefvec(Ref(1:3)).x isa Vector{ComplexF64}
128-
@test_skip @test_throws DimensionMismatch prefvec(Ref{Any}(1:5))
130+
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
131+
@test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64}
132+
@test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5))
129133

130134
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
131135
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}

0 commit comments

Comments
 (0)