Skip to content

Commit 38c4089

Browse files
mcabbottoxinabox
andauthored
ProjectTo{<:Tangent} for tuples & Ref (#488)
* tuple -> tangent, take 1 * Apply 4 suggestions Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * test + version * rm "only" for 1.0 * version + comments Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent d0c3599 commit 38c4089

File tree

3 files changed

+62
-8
lines changed

3 files changed

+62
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.8.0"
3+
version = "1.9.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,18 +272,55 @@ end
272272
#####
273273

274274
# Ref
275+
# Note that Ref is mutable. This causes Zygote to represent its structural tangent not as a NamedTuple,
276+
# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
277+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
275278
function ProjectTo(x::Ref)
276279
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
277-
if sub isa ProjectTo{<:AbstractZero}
280+
return ProjectTo{Tangent{typeof(x)}}(; x=sub)
281+
end
282+
# A structural Tangent is unwrapped to its first backing field, re-wrapped in a `Ref`,
# and then handed to the `dx::Ref` method of the same projector.
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent)
    inner = first(backing(dx))
    return project(Ref(inner))
end
283+
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
284+
dy = project.x(dx[])
285+
return project_type(project)(; x=dy)
286+
end
287+
# Since this works like a zero-array in broadcasting, it should also accept a number:
288+
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number)
    # Wrap the bare number in a `Ref` so that the `dx::Ref` method does the actual work.
    wrapped = Ref(dx)
    return project(wrapped)
end
289+
290+
# Tuple
291+
function ProjectTo(x::Tuple)
292+
elements = map(ProjectTo, x)
293+
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
278294
return ProjectTo{NoTangent}()
279295
else
280-
return ProjectTo{Ref}(; type=typeof(x), x=sub)
296+
return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
281297
end
282298
end
283-
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x))
284-
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
285-
# Since this works like a zero-array in broadcasting, it should also accept a number:
286-
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))
299+
# This method means that projection is re-applied to the contents of a Tangent.
300+
# We're not entirely sure whether this is ever necessary; but it should be safe,
301+
# and should often compile away:
302+
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent)
    # Strip the Tangent wrapper and re-project its backing through the same projector.
    contents = backing(dx)
    return project(contents)
end
303+
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
304+
len = length(project.elements)
305+
if length(dx) != len
306+
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))"
307+
throw(DimensionMismatch(str))
308+
end
309+
# Here map will fail if the lengths don't match, but gives a much less helpful error:
310+
dy = map((f, x) -> f(x), project.elements, dx)
311+
return project_type(project)(dy...)
312+
end
313+
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
314+
for d in 1:ndims(dx)
315+
if size(dx, d) != get(length(project.elements), d, 1)
316+
throw(_projection_mismatch(axes(project.elements), size(dx)))
317+
end
318+
end
319+
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
320+
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
321+
return project_type(project)(dz...)
322+
end
323+
287324

288325
#####
289326
##### `LinearAlgebra`

test/projection.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,29 @@ struct NoSuperType end
137137
prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
138138
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
139139
@test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64}
140-
@test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5))
140+
@test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5))
141141

142142
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
143143
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
144144
end
145145

146+
@testset "Base: Tuple" begin
147+
pt1 = ProjectTo((1.0,))
148+
@test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
149+
# Re-projecting an already-projected Tangent must match projecting the raw tuple directly.
# (Comparing the expression with itself would pass regardless of the value returned.)
@test pt1(pt1((1,))) == pt1((1,)) # accepts correct Tangent
150+
@test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
151+
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
152+
@test pt1(NoTangent()) === NoTangent()
153+
@test pt1(ZeroTangent()) === ZeroTangent()
154+
155+
@test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
156+
@test_throws Exception pt1([])
157+
158+
pt3 = ProjectTo(([1, 2, 3], false, :gamma)) # partly non-differentiable
159+
@test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int}, Bool, Symbol}}([1.0, 2.0, 3.0], NoTangent(), NoTangent())
160+
@test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
161+
end
162+
146163
@testset "Base: non-diff" begin
147164
@test ProjectTo(:a)(1) == NoTangent()
148165
@test ProjectTo('b')(2) == NoTangent()

0 commit comments

Comments
 (0)