Skip to content

Commit f099941

Browse files
authored
Implement reverse and getindex for AbstractZero and thunks (#522)
1 parent a78f2d3 commit f099941

File tree

4 files changed

+25
-1
lines changed

4 files changed

+25
-1
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
3030
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
3131
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)
3232

33-
Base.getindex(z::AbstractZero, k) = z
33+
Base.getindex(z::AbstractZero, args...) = z
3434

3535
Base.view(z::AbstractZero, ind...) = z
3636
Base.sum(z::AbstractZero; dims=:) = z
3737
Base.reshape(z::AbstractZero, size...) = z
38+
Base.reverse(z::AbstractZero, args...; kwargs...) = z
3839

3940
"""
4041
ZeroTangent() <: AbstractZero

src/tangent_types/thunks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Base.sum!(r, A::AbstractThunk; kws...) = sum!(r, unthunk(A); kws...)
4747
Base.fill(a::AbstractThunk, b::Integer) = fill(unthunk(a), b)
4848
Base.vec(a::AbstractThunk) = vec(unthunk(a))
4949
Base.reshape(a::AbstractThunk, args...) = reshape(unthunk(a), args...)
50+
Base.reverse(a::AbstractThunk, args...; kwargs...) = reverse(unthunk(a), args...; kwargs...)
5051
Base.getindex(a::AbstractThunk, args...) = getindex(unthunk(a), args...)
5152
Base.setindex!(a::AbstractThunk, value, key...) = throw(MutateThunkException())
5253
Base.selectdim(a::AbstractThunk, args...) = selectdim(unthunk(a), args...)

test/tangent_types/abstract_zero.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313

1414
@test reshape(NoTangent(), (1, :)) === NoTangent()
1515
@test reshape(ZeroTangent(), 2, 3, 4, 5) === ZeroTangent()
16+
17+
@test reverse(NoTangent()) === NoTangent()
18+
@test reverse(ZeroTangent()) === ZeroTangent()
19+
@test reverse(ZeroTangent(); dims=1) === ZeroTangent()
20+
@test reverse(ZeroTangent(), 2, 5) === ZeroTangent()
1621
end
1722

1823
@testset "ZeroTangent" begin
@@ -76,6 +81,11 @@
7681
@test convert(Int64, ZeroTangent()) === Int64(0)
7782
@test convert(Float32, ZeroTangent()) === 0.0f0
7883
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
84+
85+
@test z[1] === z
86+
@test z[1:3] === z
87+
@test z[1, 2] === z
88+
@test getindex(z) === z
7989
end
8090

8191
@testset "NoTangent" begin
@@ -115,6 +125,11 @@
115125

116126
@test convert(Int64, NoTangent()) == 0
117127
@test convert(Float64, NoTangent()) == 0.0
128+
129+
@test dne[1] === dne
130+
@test dne[1:3] === dne
131+
@test dne[1, 2] === dne
132+
@test getindex(dne) === dne
118133
end
119134

120135
@testset "ambiguities" begin

test/tangent_types/thunks.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@
107107
v = [1, 2, 3]
108108
t = @thunk(v)
109109

110+
m = rand(3, 3)
111+
tm = @thunk(m)
112+
110113
if VERSION >= v"1.2"
111114
@test 3 == mapreduce(_ -> 1, +, t)
112115
@test 3 == mapreduce((_, _) -> 1, +, v, t)
@@ -120,6 +123,10 @@
120123
@test 1 == getindex(t, 1)
121124
@test_throws MutateThunkException setindex!(t, 0.0, 1)
122125
@test [4; 5; 6] == selectdim([1 2 3; 4 5 6], 1, 2)
126+
127+
@test reverse(t) == reverse(v)
128+
@test reverse(t, 2) == reverse(v, 2)
129+
@test reverse(tm; dims=2) == reverse(m; dims=2)
123130
end
124131

125132
@testset "LinearAlgebra" begin

0 commit comments

Comments
 (0)