Skip to content

Commit 00e17d2

Browse files
authored
Handle Base.tail & friends (#567)
* handle Base.tail & friens * empty cases * tail on NamedTuples too
1 parent f2e3ac5 commit 00e17d2

File tree

7 files changed

+53
-2
lines changed

7 files changed

+53
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.15.2"
3+
version = "1.15.3"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/abstract_zero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ Base.iszero(::AbstractZero) = true
1616
Base.iterate(x::AbstractZero) = (x, nothing)
1717
Base.iterate(::AbstractZero, ::Any) = nothing
1818

19+
Base.first(x::AbstractZero) = x
20+
Base.tail(x::AbstractZero) = x
21+
Base.last(x::AbstractZero) = x
22+
1923
Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x)
2024
Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T()
2125

src/tangent_types/tangent.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ function Base.show(io::IO, tangent::Tangent{P}) where {P}
9696
end
9797
end
9898

99+
Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent)))
100+
Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent)))
101+
102+
Base.tail(t::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(canonicalize(t)))...)
103+
@generated _tailtype(::Type{P}) where {P<:Tuple} = Tuple{P.parameters[2:end]...}
104+
Base.tail(t::Tangent{<:Tuple{Any}}) = NoTangent()
105+
Base.tail(t::Tangent{<:Tuple{}}) = NoTangent()
106+
107+
Base.tail(t::Tangent{P}) where {P<:NamedTuple} = Tangent{_tailtype(P)}(; Base.tail(backing(canonicalize(t)))...)
108+
_tailtype(::Type{NamedTuple{S,P}}) where {S,P} = NamedTuple{Base.tail(S), _tailtype(P)}
109+
Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{Any}}}) = NoTangent()
110+
Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{}}}) = NoTangent()
111+
99112
function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}}
100113
back = backing(canonicalize(tangent))
101114
return unthunk(getfield(back, idx))
@@ -127,6 +140,7 @@ end
127140

128141
Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...)
129142
Base.length(tangent::Tangent) = length(backing(tangent))
143+
130144
Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T)
131145
function Base.reverse(tangent::Tangent)
132146
rev_backing = reverse(backing(tangent))

src/tangent_types/thunks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ end
2424
return element, (underlying_object, new_state)
2525
end
2626

27+
Base.first(x::AbstractThunk) = first(unthunk(x))
28+
Base.last(x::AbstractThunk) = last(unthunk(x))
29+
Base.tail(x::AbstractThunk) = Base.tail(unthunk(x))
30+
2731
Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
2832

2933
Base.:(-)(a::AbstractThunk) = -unthunk(a)

test/tangent_types/abstract_zero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@
8686
@test z[1:3] === z
8787
@test z[1, 2] === z
8888
@test getindex(z) === z
89+
90+
@test first(z) === z
91+
@test last(z) === z
92+
@test Base.tail(z) === z
8993
end
9094

9195
@testset "NoTangent" begin

test/tangent_types/tangent.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,22 @@ end
7272
end
7373
@test Tangent{Foo}(; x=2.5).x == 2.5
7474

75-
@test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1)
75+
tang1 = Tangent{Tuple{Float64}}(2.0)
76+
@test keys(tang1) == Base.OneTo(1)
7677
@test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,)
7778
@test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0
7879
@test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0
7980
@test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0
8081
@test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0
82+
@test NoTangent() === @inferred Base.tail(tang1)
83+
@test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}())
84+
85+
tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4)
86+
@test @inferred(first(tang3)) === tang3[1] === 1.0
87+
@test @inferred(last(tang3)) isa Thunk
88+
@test unthunk(last(tang3)) == [7.0]
89+
@test Tuple(@inferred Base.tail(tang3))[1] === NoTangent()
90+
@test Tuple(Base.tail(tang3))[end] isa Thunk
8191

8292
NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}}
8393
@test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0
@@ -89,6 +99,14 @@ end
8999
@test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent()
90100
@test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent()
91101
@test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0
102+
103+
@test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk
104+
@test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0
105+
@test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent
106+
107+
ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2)))
108+
@test ntang1 isa Tangent{<:NamedTuple{(:b,)}}
109+
@test NoTangent() === @inferred Base.tail(ntang1)
92110

93111
# TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516
94112
if VERSION >= v"1.8-"

test/tangent_types/thunks.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
@test nothing === iterate(@thunk ()) == iterate(())
1818
end
19+
20+
@testset "first, last, tail" begin
21+
@test first(@thunk (1,2,3) .+ 4) === 5
22+
@test last(@thunk (1,2,3) .+ 4) === 7
23+
@test Base.tail(@thunk (1,2,3) .+ 4) === (6, 7)
24+
@test Base.tail(@thunk NoTangent() * 5) === NoTangent()
25+
end
1926

2027
@testset "show" begin
2128
rep = repr(Thunk(rand))

0 commit comments

Comments
 (0)