Skip to content

Commit 6925da1

Browse files
authored
Fix test_approx buglets (#259)
* fix nested arrays * fix Tangent AbstractZero * v1.9.3
1 parent 3add381 commit 6925da1

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.9.2"
3+
version = "1.9.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/check_result.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,17 @@ for (T1, T2) in
4040
end
4141

4242
test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
43-
test_approx(::AbstractZero, x::AbstractArray{<:AbstractArray}, msg=""; kwargs...) = test_approx(map(zero, x), x, msg; kwargs...)
4443
test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
45-
test_approx(x::AbstractArray{<:AbstractArray}, ::AbstractZero, msg=""; kwargs...) = test_approx(x, map(zero, x), msg; kwargs...)
4644
test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
4745
test_approx(x::NoTangent, y::NoTangent, msg=""; kwargs...) = @test true
4846

47+
function test_approx(z::AbstractZero, x::AbstractArray{<:AbstractArray}, msg=""; kwargs...)
48+
for el in x
49+
test_approx(el, z, msg; kwargs...)
50+
end
51+
end
52+
test_approx(x::AbstractArray{<:AbstractArray}, z::AbstractZero, msg=""; kwargs...) = test_approx(z, x, msg; kwargs...)
53+
4954
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
5055
test_approx(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
5156
test_approx(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true
@@ -134,8 +139,8 @@ function test_approx(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T
134139
end
135140
test_approx(x, y::Tangent, msg=""; kwargs...) = test_approx(y, x, msg; kwargs...)
136141

137-
test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = all(==(NoTangent()), t)
138-
test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = all(==(NoTangent()), t)
142+
test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = @test all(==(NoTangent()), t)
143+
test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = @test all(==(NoTangent()), t)
139144

140145
# This catches comparisons of Tangents and Tuples/NamedTuple
141146
# and gives an error message complaining about that. the `@test` will definitely fail

test/check_result.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ end
3838
test_approx([[1.0], [2.0]], [[1.0], [2.0]])
3939
test_approx([[0.0], [0.0]], ZeroTangent())
4040
test_approx(ZeroTangent(), [[0.0], [0.0]])
41+
test_approx(ZeroTangent(), [[0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]])
42+
test_approx([[0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]], NoTangent())
4143
test_approx(Broadcast.broadcasted(identity, [1.0 2.0; 3.0 4.0]), [1.0 2.0; 3.0 4.0])
4244

4345
test_approx(@thunk(10 * 0.1 * [[1.0], [2.0]]), [[1.0], [2.0]])
@@ -112,6 +114,17 @@ end
112114
@test fails(() -> test_approx([[1.0], [2.0]], [[1.1], [2.0]]))
113115
@test fails(() -> test_approx([[0.0], [0.1]], ZeroTangent()))
114116
@test fails(() -> test_approx(ZeroTangent(), [[0.1], [0.0]]))
117+
@test fails(() -> test_approx([[0.0], [0.0], [[0.0, 0.1], [0.0]]], ZeroTangent()))
118+
@test fails(() -> test_approx(ZeroTangent(), [[0.0], [0.0], [[0.0, 0.1], [0.0]]]))
119+
120+
@test fails(() -> test_approx(
121+
Tangent{Tuple{Float64,Float64}}(NoTangent(), 0.1),
122+
NoTangent(),
123+
))
124+
@test fails(() -> test_approx(
125+
NoTangent(),
126+
Tangent{Tuple{Float64,Float64}}(NoTangent(), 0.1),
127+
))
115128

116129
@test fails(() -> test_approx(@thunk(10 * [[1.0], [2.0]]), [[1.0], [2.0]]))
117130

0 commit comments

Comments
 (0)