Skip to content

Commit addf6d9

Browse files
devmotionmcabbott
andauthored
Add ProjectTo(::NamedTuple) (#515)
* Add `ProjectTo(::NamedTuple)` Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> * Allow different order and subset of named tuples Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent 2dcd44b commit addf6d9

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.11.1"
3+
version = "1.11.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ end
287287
# Since this works like a zero-array in broadcasting, it should also accept a number:
288288
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
289289

290-
# Tuple
290+
# Tuple and NamedTuple
291291
function ProjectTo(x::Tuple)
292292
elements = map(ProjectTo, x)
293293
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
@@ -296,10 +296,22 @@ function ProjectTo(x::Tuple)
296296
return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
297297
end
298298
end
299+
function ProjectTo(x::NamedTuple)
300+
elements = map(ProjectTo, x)
301+
if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
302+
return ProjectTo{NoTangent}()
303+
else
304+
return ProjectTo{Tangent{typeof(x)}}(; elements...)
305+
end
306+
end
307+
299308
# This method means that projection is re-applied to the contents of a Tangent.
300309
# We're not entirely sure whether this is every necessary; but it should be safe,
301310
# and should often compile away:
302-
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
311+
function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::Tangent)
312+
return project(backing(dx))
313+
end
314+
303315
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
304316
len = length(project.elements)
305317
if length(dx) != len
@@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
310322
dy = map((f, x) -> f(x), project.elements, dx)
311323
return project_type(project)(dy...)
312324
end
325+
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
326+
dy = _project_namedtuple(backing(project), dx)
327+
return project_type(project)(; dy...)
328+
end
329+
330+
# Diffractor returns not necessarily a named tuple with all keys and of the same order as
331+
# the projector
332+
# Thus we can't use `map`
333+
function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
334+
if @generated
335+
vals = Any[
336+
if xn[i] in fn
337+
:(getfield(f, $(QuoteNode(xn[i])))(getfield(x, $(QuoteNode(xn[i])))))
338+
else
339+
throw(
340+
ArgumentError(
341+
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
342+
),
343+
)
344+
end for i in 1:length(xn)
345+
]
346+
:(NamedTuple{$xn}(($(vals...),)))
347+
else
348+
vals = ntuple(Val(length(xn))) do i
349+
name = xn[i]
350+
if name in fn
351+
getfield(f, name)(getfield(x, name))
352+
else
353+
throw(
354+
ArgumentError(
355+
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
356+
),
357+
)
358+
end
359+
end
360+
NamedTuple{xn}(vals)
361+
end
362+
end
363+
313364
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
314365
for d in 1:ndims(dx)
315366
if size(dx, d) != get(length(project.elements), d, 1)

test/projection.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,42 @@ struct NoSuperType end
160160
@test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
161161
end
162162

163+
@testset "Base: NamedTuple" begin
164+
pt1 = @inferred(ProjectTo((a=1.0,)))
165+
@test @inferred(pt1((a=1 + im,))) ==
166+
Tangent{NamedTuple{(:a,),Tuple{Float64}}}(; a=1.0)
167+
@test @inferred(pt1(pt1((a=1,)))) == @inferred(pt1(pt1((a=1,)))) # accepts correct Tangent
168+
@test @inferred(pt1(Tangent{Any}(; a=1))) == pt1((a=1,)) # accepts Tangent{Any}
169+
@test @inferred(pt1(NoTangent())) === NoTangent()
170+
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()
171+
172+
@test_throws Exception pt1((a=1, b=2)) # no projector for `b`
173+
@test_throws Exception pt1((b=1,)) # no projector for `b`
174+
175+
# subset is allowed (required for Diffractor)
176+
@test @inferred(pt1(NamedTuple())) === Tangent{NamedTuple{(:a,),Tuple{Float64}}}()
177+
178+
pt3 = @inferred(ProjectTo((a=[1, 2, 3], b=false, c=:gamma))) # partly non-differentiable
179+
@test @inferred(pt3((a=1:3, b=4, c=5))) ==
180+
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
181+
a=[1.0, 2.0, 3.0], b=NoTangent(), c=NoTangent()
182+
)
183+
184+
# different order
185+
@test @inferred(pt3((b=4, a=1:3, c=5))) ==
186+
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
187+
b=NoTangent(), a=[1.0, 2.0, 3.0], c=NoTangent()
188+
)
189+
190+
# only a subset
191+
@test @inferred(pt3((c=5,))) ==
192+
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
193+
c=NoTangent()
194+
)
195+
196+
@test @inferred(ProjectTo((a=true, b=[false]))) isa ProjectTo{NoTangent}
197+
end
198+
163199
@testset "Base: non-diff" begin
164200
@test ProjectTo(:a)(1) == NoTangent()
165201
@test ProjectTo('b')(2) == NoTangent()

0 commit comments

Comments
 (0)