Skip to content

Commit 439d036

Browse files
committed
Avoid nesting testsets in test_rule
1 parent d72b146 commit 439d036

File tree

5 files changed

+125
-45
lines changed

5 files changed

+125
-45
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ export TestIterator
1717
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1818
export
1919

20+
2021
include("generate_tangent.jl")
2122
include("data_generation.jl")
2223
include("iterator.jl")
24+
25+
include("output_control.jl")
2326
include("check_result.jl")
2427

2528
include("finite_difference_calls.jl")

src/check_result.jl

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,56 @@
44
# Note that this must work well both on Differential types and Primal types
55

66
"""
7-
check_equal(actual, expected; kwargs...)
7+
check_equal(actual, expected, [msg]; kwargs...)
88
99
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
1010
are shown on failures.
1111
Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc.
12+
13+
If provided `msg` is printed on a failure. Often additional items are appended to `msg` to
14+
give bread-crumbs into nested structures.
15+
1216
All keyword arguments are passed to `isapprox`.
1317
"""
1418
function check_equal(
1519
actual::Union{AbstractArray{<:Number},Number},
16-
expected::Union{AbstractArray{<:Number},Number};
20+
expected::Union{AbstractArray{<:Number},Number},
21+
msg="",
22+
;
1723
kwargs...,
1824
)
19-
@test isapprox(actual, expected; kwargs...)
25+
@test_msg msg isapprox(actual, expected; kwargs...)
2026
end
2127

2228
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
23-
@eval function check_equal(actual::$T1, expected::$T2; kwargs...)
24-
return check_equal(unthunk(actual), unthunk(expected); kwargs...)
29+
@eval function check_equal(actual::$T1, expected::$T2, msg=""; kwargs...)
30+
return check_equal(unthunk(actual), unthunk(expected), msg; kwargs...)
2531
end
2632
end
2733

28-
check_equal(::ZeroTangent, x; kwargs...) = check_equal(zero(x), x; kwargs...)
29-
check_equal(x, ::ZeroTangent; kwargs...) = check_equal(x, zero(x); kwargs...)
30-
check_equal(x::ZeroTangent, y::ZeroTangent; kwargs...) = @test true
34+
check_equal(::ZeroTangent, x, msg=""; kwargs...) = check_equal(zero(x), x, msg; kwargs...)
35+
check_equal(x, ::ZeroTangent, msg=""; kwargs...) = check_equal(x, zero(x), msg; kwargs...)
36+
check_equal(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
3137

3238
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33-
check_equal(x::NoTangent, y::Nothing; kwargs...) = @test true
34-
check_equal(x::Nothing, y::NoTangent; kwargs...) = @test true
39+
check_equal(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
40+
check_equal(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true
3541

3642
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
3743
# not yet been implemented
3844
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39-
check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y
40-
check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y
45+
check_equal(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y
46+
check_equal(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y
4147
# In this case we check for equality (messages etc. have to be equal)
4248
function check_equal(
43-
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs...
49+
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented, msg=""; kwargs...
4450
)
45-
return @test x == y
51+
return @test_msg msg x == y
4652
end
4753

4854
"""
4955
_can_pass_early(actual, expected; kwargs...)
50-
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
56+
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
5157
and can just report `check_equal` as passing.
5258
5359
If either `==` or `≈` return true then so does this.
@@ -64,60 +70,71 @@ function _can_pass_early(actual, expected; kwargs...)
6470
return false
6571
end
6672

67-
function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
73+
function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...)
6874
if _can_pass_early(actual, expected)
6975
@test true
7076
else
71-
@test eachindex(actual) == eachindex(expected)
72-
@testset "$(typeof(actual))[$ii]" for ii in eachindex(actual)
73-
check_equal(actual[ii], expected[ii]; kwargs...)
77+
@test_msg "$msg: indexes must match" eachindex(actual) == eachindex(expected)
78+
for ii in eachindex(actual)
79+
new_msg = "$msg $(typeof(actual))[$ii]"
80+
check_equal(actual[ii], expected[ii], new_msg; kwargs...)
7481
end
7582
end
7683
end
7784

78-
function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where {P}
85+
function check_equal(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P}
7986
if _can_pass_early(actual, expected)
8087
@test true
8188
else
8289
all_keys = union(keys(actual), keys(expected))
83-
@testset "$P.$ii" for ii in all_keys
84-
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
90+
for ii in all_keys
91+
new_msg = "$msg $P.$ii"
92+
check_equal(
93+
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
94+
)
8595
end
8696
end
8797
end
8898

8999
function check_equal(
90-
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs...
100+
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}, msg=""; kwargs...
91101
) where {ActualPrimal,ExpectedPrimal}
92102
# this will certainly fail as we have another dispatch for that, but this will give as
93103
# good error message
94104
@test ActualPrimal === ExpectedPrimal
95105
end
96106

97107
# Some structual differential and a natural differential
98-
function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P}
108+
function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P}
99109
if _can_pass_early(actual, expected)
100110
@test true
101111
else
102112
@assert (T <: NamedTuple) # it should be a structual differential if we hit this
103113

104114
# We are only checking the properties that are in the Tangent
105115
# the natural differential is allowed to have other properties that we ignore
106-
@testset "$P.$ii" for ii in propertynames(actual)
107-
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
116+
for ii in propertynames(actual)
117+
new_msg = "$msg $P.$ii"
118+
check_equal(
119+
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
120+
)
108121
end
109122
end
110123
end
111-
check_equal(x, y::Tangent; kwargs...) = check_equal(y, x; kwargs...)
124+
check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...)
112125

113126
# This catches comparisons of Tangents and Tuples/NamedTuple
114-
# and gives an error message complaining about that
127+
# and gives an error message complaining about that. the `@test` will definately fail
115128
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
116-
check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test C === T
117-
check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C
129+
function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...)
130+
@test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y)
131+
end
132+
function check_equal(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...)
133+
return check_equal(y, x, msg; kwargs...)
134+
end
118135

119136
# Generic fallback, probably a tuple or something
120-
function check_equal(actual::A, expected::E; kwargs...) where {A,E}
137+
function check_equal(actual::A, expected::E, msg=""; kwargs...) where {A,E}
121138
if _can_pass_early(actual, expected)
122139
@test true
123140
else
@@ -130,6 +147,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E}
130147
end
131148
end
132149

150+
###########################################################################################
151+
133152
"""
134153
_check_add!!_behaviour(acc, val)
135154
@@ -146,19 +165,19 @@ function _check_add!!_behaviour(acc, val; kwargs...)
146165
# e.g. if it is immutable. We do test the `add!!` return value.
147166
# That is what people should rely on. The mutation is just to save allocations.
148167
acc_mutated = deepcopy(acc) # prevent this test changing others
149-
return check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
168+
return check_equal(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...)
150169
end
151170

152-
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
153-
# not yet been implemented
171+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
172+
# intentionally not yet been implemented
154173
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
155174
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
156175
return @test_broken acc_mutated == acc
157176
end
158177
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
159178
return @test_broken acc_mutated == acc
160179
end
161-
# In this case we check for equality (messages etc. have to be equal)
180+
# In this case we check for equality (not implemented messages etc. have to be equal)
162181
function _check_add!!_behaviour(
163182
acc_mutated::ChainRulesCore.NotImplemented,
164183
acc::ChainRulesCore.NotImplemented;

src/output_control.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Test.get_test_result generates code that uses the following so we must import them
2+
using Test: Returned, Threw, eval_test
3+
4+
"A cunning hack to carry extra message along with the original expression in a test"
5+
struct ExprAndMsg
6+
ex
7+
msg
8+
end
9+
10+
"""
11+
@test_msg msg condion kws...
12+
13+
This is per `Test.@test condion kws...` except that if it fails it also prints the `msg`.
14+
If `msg==""` then this is just like `@test`, nothing is printed
15+
16+
### Examles
17+
```julia
18+
julia> @test_msg "It is required that the total is under 10" sum(1:1000) < 10;
19+
Test Failed at REPL[1]:1
20+
Expression: sum(1:1000) < 10
21+
Problem: It is required that the total is under 10
22+
Evaluated: 500500 < 10
23+
ERROR: There was an error during testing
24+
25+
26+
julia> @test_msg "It is required that the total is under 10" error("not working at all");
27+
Error During Test at REPL[2]:1
28+
Test threw exception
29+
Expression: error("not working at all")
30+
Problem: It is required that the total is under 10
31+
"not working at all"
32+
Stacktrace:
33+
34+
julia> a = "";
35+
36+
julia> @test_msg a sum(1:1000) < 10;
37+
Test Failed at REPL[153]:1
38+
Expression: sum(1:1000) < 10
39+
Evaluated: 500500 < 10
40+
ERROR: There was an error during testing
41+
```
42+
"""
43+
macro test_msg(msg, ex, kws...)
44+
Test.test_expr!("@test_msg msg", ex, kws...)
45+
46+
result = Test.get_test_result(ex, __source__)
47+
return :(Test.do_test($result, $ExprAndMsg($(string(ex)), $(esc(msg)))))
48+
end
49+
50+
function Base.print(io::IO, x::ExprAndMsg)
51+
print(io, x.ex)
52+
!isempty(x.msg) && print(io, "\n Problem: ", x.msg)
53+
end
54+
55+
56+
### helpers for printing in log messages etc
57+
_string_typeof(x) = string(typeof(x))
58+
_string_typeof(xs::Tuple) = join(_string_typeof.(xs), ",")
59+
_string_typeof(x::PrimalAndTangent) = _string_typeof(primal(x)) # only show primal

src/testers.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function test_frule(
9999
# To simplify some of the calls we make later lets group the kwargs for reuse
100100
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
101101

102-
@testset "test_frule: $f on $(join(typeof.(inputs), ","))" begin
102+
@testset "test_frule: $f on $(_string_typeof(inputs))" begin
103103
_ensure_not_running_on_functor(f, "test_frule")
104104

105105
xẋs = auto_primal_and_tangent.(inputs)
@@ -167,7 +167,7 @@ function test_rrule(
167167
# To simplify some of the calls we make later lets group the kwargs for reuse
168168
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
169169

170-
@testset "test_rrule: $f on $(join(typeof.(inputs), ","))" begin
170+
@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
171171
_ensure_not_running_on_functor(f, "test_rrule")
172172

173173
# Check correctness of evaluation.
@@ -226,12 +226,11 @@ function test_rrule(
226226
end
227227

228228
function check_thunking_is_appropriate(x̄s)
229-
@testset "Don't thunk only non_zero argument" begin
230-
num_zeros = count(x -> x isa AbstractZero, x̄s)
231-
num_thunks = count(x -> x isa Thunk, x̄s)
232-
if num_zeros + num_thunks == length(x̄s)
233-
@test num_thunks !== 1
234-
end
229+
num_zeros = count(x -> x isa AbstractZero, x̄s)
230+
num_thunks = count(x -> x isa Thunk, x̄s)
231+
if num_zeros + num_thunks == length(x̄s)
232+
# num_thunks can be either 0, or greater than 1.
233+
@test_msg "Should not thunk only non_zero argument" num_thunks != 1
235234
end
236235
end
237236

test/testers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# For some reason if these aren't defined here, then they are interpreted as closures
1+
# Defining test functions here as if they are defined where used it is too easy to
2+
# mistakenly create closures over variables that only share names by coincidence.
23
futestkws(x; err=true) = err ? error("futestkws_err") : x
34

45
fbtestkws(x, y; err=true) = err ? error("fbtestkws_err") : x
@@ -268,7 +269,6 @@ end
268269
return first(x), first_pullback
269270
end
270271

271-
#CTuple{N} = Tangent{NTuple{N, Float64}} # shorter for testing
272272
@testset "test_frule" begin
273273
test_frule(first, (2.0, 3.0))
274274
test_frule(first, Tuple(randn(4)))

0 commit comments

Comments
 (0)