Skip to content

Commit 1893e82

Browse files
authored
fix getproperty(::Tangent, ::Int) for NamedTuple (#460)
also fixes some issues with `getindex`, mainly that it didn't canonicalize tangents with NamedTuple backing, which I'd argue is a bug. It's also more consistent with `getproperty` now, so `tangent[:a]` returns `ZeroTangent()` instead of erroring if `tangent`'s backing doesn't have a field `a`. ref JuliaDiff/Diffractor.jl#39
1 parent 1db6c47 commit 1893e82

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.4.0"
3+
version = "1.5.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/differentials/composite.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,25 @@ function Base.show(io::IO, comp::Tangent{P}) where P
6868
end
6969
end
7070

71-
Base.getindex(comp::Tangent, idx) = getindex(backing(comp), idx)
71+
function Base.getindex(comp::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}}
72+
back = backing(canonicalize(comp))
73+
return unthunk(getfield(back, idx))
74+
end
75+
function Base.getindex(comp::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
76+
hasfield(T, idx) || return ZeroTangent()
77+
return unthunk(getfield(backing(comp), idx))
78+
end
79+
function Base.getindex(comp::Tangent, idx) where {P, T<:AbstractDict}
80+
return unthunk(getindex(backing(comp), idx))
81+
end
7282

73-
# for Tuple
74-
Base.getproperty(comp::Tangent, idx::Int) = unthunk(getproperty(backing(comp), idx))
75-
function Base.getproperty(
76-
comp::Tangent{P, T}, idx::Symbol
77-
) where {P, T<:NamedTuple}
83+
function Base.getproperty(comp::Tangent, idx::Int)
84+
back = backing(canonicalize(comp))
85+
return unthunk(getfield(back, idx))
86+
end
87+
function Base.getproperty(comp::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
7888
hasfield(T, idx) || return ZeroTangent()
79-
return unthunk(getproperty(backing(comp), idx))
89+
return unthunk(getfield(backing(comp), idx))
8090
end
8191

8292
Base.keys(comp::Tangent) = keys(backing(comp))

test/differentials/composite.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,21 @@ end
5454

5555
@test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1)
5656
@test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,)
57+
@test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0
58+
@test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0
5759
@test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0
5860
@test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0
59-
@test getproperty(Tangent{Tuple{Float64,}}(a=(@thunk 2.0^2),), :a) == 4.0
61+
62+
NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}}
63+
@test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0
64+
@test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent()
65+
@test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent()
66+
@test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0
67+
68+
@test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0
69+
@test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent()
70+
@test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent()
71+
@test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0
6072

6173
# TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516
6274
@test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true

0 commit comments

Comments
 (0)