Skip to content

Commit 9502d6a

Browse files
authored
Merge pull request #158 from JuliaDiff/ox/nicemessage
Avoid nesting testsets in test_rule
2 parents 27a1933 + 525ffa4 commit 9502d6a

File tree

5 files changed

+127
-45
lines changed

5 files changed

+127
-45
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ export TestIterator
1717
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1818
export
1919

20+
2021
include("generate_tangent.jl")
2122
include("data_generation.jl")
2223
include("iterator.jl")
24+
25+
include("output_control.jl")
2326
include("check_result.jl")
2427

2528
include("finite_difference_calls.jl")

src/check_result.jl

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,55 @@
44
# Note that this must work well both on Differential types and Primal types
55

66
"""
7-
check_equal(actual, expected; kwargs...)
7+
check_equal(actual, expected, [msg]; kwargs...)
88
99
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
1010
are shown on failures.
1111
Understands things like `unthunk`ing `ChainRulesCore.Thunk`s, etc.
12+
13+
If provided, `msg` is printed on a failure. Often additional items are appended to `msg` to
14+
give bread-crumbs into nested structures.
15+
1216
All keyword arguments are passed to `isapprox`.
1317
"""
1418
function check_equal(
1519
actual::Union{AbstractArray{<:Number},Number},
16-
expected::Union{AbstractArray{<:Number},Number};
20+
expected::Union{AbstractArray{<:Number},Number},
21+
msg="";
1722
kwargs...,
1823
)
19-
@test isapprox(actual, expected; kwargs...)
24+
@test_msg msg isapprox(actual, expected; kwargs...)
2025
end
2126

2227
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
23-
@eval function check_equal(actual::$T1, expected::$T2; kwargs...)
24-
return check_equal(unthunk(actual), unthunk(expected); kwargs...)
28+
@eval function check_equal(actual::$T1, expected::$T2, msg=""; kwargs...)
29+
return check_equal(unthunk(actual), unthunk(expected), msg; kwargs...)
2530
end
2631
end
2732

28-
check_equal(::ZeroTangent, x; kwargs...) = check_equal(zero(x), x; kwargs...)
29-
check_equal(x, ::ZeroTangent; kwargs...) = check_equal(x, zero(x); kwargs...)
30-
check_equal(x::ZeroTangent, y::ZeroTangent; kwargs...) = @test true
33+
check_equal(::ZeroTangent, x, msg=""; kwargs...) = check_equal(zero(x), x, msg; kwargs...)
34+
check_equal(x, ::ZeroTangent, msg=""; kwargs...) = check_equal(x, zero(x), msg; kwargs...)
35+
check_equal(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
3136

3237
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33-
check_equal(x::NoTangent, y::Nothing; kwargs...) = @test true
34-
check_equal(x::Nothing, y::NoTangent; kwargs...) = @test true
38+
check_equal(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
39+
check_equal(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true
3540

3641
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
3742
# not yet been implemented
3843
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39-
check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y
40-
check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y
44+
check_equal(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y
45+
check_equal(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y
4146
# In this case we check for equality (messages etc. have to be equal)
4247
function check_equal(
43-
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs...
48+
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented, msg=""; kwargs...
4449
)
45-
return @test x == y
50+
return @test_msg msg x == y
4651
end
4752

4853
"""
4954
_can_pass_early(actual, expected; kwargs...)
50-
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
55+
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
5156
and can just report `check_equal` as passing.
5257
5358
If either `==` or `≈` return true then so does this.
@@ -64,60 +69,71 @@ function _can_pass_early(actual, expected; kwargs...)
6469
return false
6570
end
6671

67-
function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
72+
function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...)
6873
if _can_pass_early(actual, expected)
6974
@test true
7075
else
71-
@test eachindex(actual) == eachindex(expected)
72-
@testset "$(typeof(actual))[$ii]" for ii in eachindex(actual)
73-
check_equal(actual[ii], expected[ii]; kwargs...)
76+
@test_msg "$msg: indices must match" eachindex(actual) == eachindex(expected)
77+
for ii in eachindex(actual)
78+
new_msg = "$msg $(typeof(actual))[$ii]"
79+
check_equal(actual[ii], expected[ii], new_msg; kwargs...)
7480
end
7581
end
7682
end
7783

78-
function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where {P}
84+
function check_equal(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P}
7985
if _can_pass_early(actual, expected)
8086
@test true
8187
else
8288
all_keys = union(keys(actual), keys(expected))
83-
@testset "$P.$ii" for ii in all_keys
84-
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
89+
for ii in all_keys
90+
new_msg = "$msg $P.$ii"
91+
check_equal(
92+
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
93+
)
8594
end
8695
end
8796
end
8897

8998
function check_equal(
90-
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs...
99+
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}, msg=""; kwargs...
91100
) where {ActualPrimal,ExpectedPrimal}
92101
# this will certainly fail as we have another dispatch for that, but this will give a
93102
# good error message
94103
@test ActualPrimal === ExpectedPrimal
95104
end
96105

97106
# Some structural differential and a natural differential
98-
function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P}
107+
function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P}
99108
if _can_pass_early(actual, expected)
100109
@test true
101110
else
102111
@assert (T <: NamedTuple) # it should be a structural differential if we hit this
103112

104113
# We are only checking the properties that are in the Tangent
105114
# the natural differential is allowed to have other properties that we ignore
106-
@testset "$P.$ii" for ii in propertynames(actual)
107-
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
115+
for ii in propertynames(actual)
116+
new_msg = "$msg $P.$ii"
117+
check_equal(
118+
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
119+
)
108120
end
109121
end
110122
end
111-
check_equal(x, y::Tangent; kwargs...) = check_equal(y, x; kwargs...)
123+
check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...)
112124

113125
# This catches comparisons of Tangents and Tuples/NamedTuple
114-
# and gives an error message complaining about that
126+
# and gives an error message complaining about that. The `@test` will definitely fail
115127
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
116-
check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test C === T
117-
check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C
128+
function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...)
129+
@test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y)
130+
end
131+
function check_equal(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...)
132+
return check_equal(y, x, msg; kwargs...)
133+
end
118134

119135
# Generic fallback, probably a tuple or something
120-
function check_equal(actual::A, expected::E; kwargs...) where {A,E}
136+
function check_equal(actual::A, expected::E, msg=""; kwargs...) where {A,E}
121137
if _can_pass_early(actual, expected)
122138
@test true
123139
else
@@ -130,6 +146,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E}
130146
end
131147
end
132148

149+
###########################################################################################
150+
133151
"""
134152
_check_add!!_behaviour(acc, val)
135153
@@ -146,19 +164,19 @@ function _check_add!!_behaviour(acc, val; kwargs...)
146164
# e.g. if it is immutable. We do test the `add!!` return value.
147165
# That is what people should rely on. The mutation is just to save allocations.
148166
acc_mutated = deepcopy(acc) # prevent this test changing others
149-
return check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
167+
return check_equal(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...)
150168
end
151169

152-
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
153-
# not yet been implemented
170+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
171+
# intentionally not yet been implemented
154172
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
155173
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
156174
return @test_broken acc_mutated == acc
157175
end
158176
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
159177
return @test_broken acc_mutated == acc
160178
end
161-
# In this case we check for equality (messages etc. have to be equal)
179+
# In this case we check for equality (not implemented messages etc. have to be equal)
162180
function _check_add!!_behaviour(
163181
acc_mutated::ChainRulesCore.NotImplemented,
164182
acc::ChainRulesCore.NotImplemented;

src/output_control.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Test.get_test_result generates code that uses the following so we must import them
2+
using Test: Returned, Threw, eval_test
3+
4+
"A cunning hack to carry extra message along with the original expression in a test"
5+
struct ExprAndMsg
6+
ex
7+
msg
8+
end
9+
10+
"""
11+
@test_msg msg condition kws...
12+
13+
This is per `Test.@test condition kws...` except that if it fails it also prints the `msg`.
14+
If `msg==""` then this is just like `@test`: nothing extra is printed.
15+
16+
### Examples
17+
```julia
18+
julia> @test_msg "It is required that the total is under 10" sum(1:1000) < 10;
19+
Test Failed at REPL[1]:1
20+
Expression: sum(1:1000) < 10
21+
Problem: It is required that the total is under 10
22+
Evaluated: 500500 < 10
23+
ERROR: There was an error during testing
24+
25+
26+
julia> @test_msg "It is required that the total is under 10" error("not working at all");
27+
Error During Test at REPL[2]:1
28+
Test threw exception
29+
Expression: error("not working at all")
30+
Problem: It is required that the total is under 10
31+
"not working at all"
32+
Stacktrace:
33+
34+
julia> a = "";
35+
36+
julia> @test_msg a sum(1:1000) < 10;
37+
Test Failed at REPL[153]:1
38+
Expression: sum(1:1000) < 10
39+
Evaluated: 500500 < 10
40+
ERROR: There was an error during testing
41+
```
42+
"""
43+
macro test_msg(msg, ex, kws...)
44+
# This code is basically an evil hack that accesses the internals of the Test stdlib.
45+
# Code below is based on the `@test` macro definition as it was in Julia 1.6.
46+
# https://github.com/JuliaLang/julia/blob/v1.6.1/stdlib/Test/src/Test.jl#L371-L376
47+
Test.test_expr!("@test_msg msg", ex, kws...)
48+
49+
result = Test.get_test_result(ex, __source__)
50+
return :(Test.do_test($result, $ExprAndMsg($(string(ex)), $(esc(msg)))))
51+
end
52+
53+
function Base.print(io::IO, x::ExprAndMsg)
54+
print(io, x.ex)
55+
!isempty(x.msg) && print(io, "\n Problem: ", x.msg)
56+
end
57+
58+
59+
### helpers for printing in log messages etc
60+
# Generic case: render the runtime type of `x` as a `String`.
function _string_typeof(x)
    T = typeof(x)
    return string(T)
end
61+
# Tuple case: describe each element in turn and join the pieces with commas.
function _string_typeof(xs::Tuple)
    parts = map(_string_typeof, xs)
    return join(parts, ",")
end
62+
# A `PrimalAndTangent` is described by the type of its `primal(x)` alone.
_string_typeof(x::PrimalAndTangent) = _string_typeof(primal(x)) # only show primal

src/testers.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function test_frule(
9999
# To simplify some of the calls we make later lets group the kwargs for reuse
100100
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
101101

102-
@testset "test_frule: $f on $(join(typeof.(inputs), ","))" begin
102+
@testset "test_frule: $f on $(_string_typeof(inputs))" begin
103103
_ensure_not_running_on_functor(f, "test_frule")
104104

105105
xẋs = auto_primal_and_tangent.(inputs)
@@ -167,7 +167,7 @@ function test_rrule(
167167
# To simplify some of the calls we make later lets group the kwargs for reuse
168168
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
169169

170-
@testset "test_rrule: $f on $(join(typeof.(inputs), ","))" begin
170+
@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
171171
_ensure_not_running_on_functor(f, "test_rrule")
172172

173173
# Check correctness of evaluation.
@@ -226,12 +226,11 @@ function test_rrule(
226226
end
227227

228228
function check_thunking_is_appropriate(x̄s)
229-
@testset "Don't thunk only non_zero argument" begin
230-
num_zeros = count(x -> x isa AbstractZero, x̄s)
231-
num_thunks = count(x -> x isa Thunk, x̄s)
232-
if num_zeros + num_thunks == length(x̄s)
233-
@test num_thunks !== 1
234-
end
229+
num_zeros = count(x -> x isa AbstractZero, x̄s)
230+
num_thunks = count(x -> x isa Thunk, x̄s)
231+
if num_zeros + num_thunks == length(x̄s)
232+
# num_thunks can be either 0, or greater than 1.
233+
@test_msg "Should not thunk only non_zero argument" num_thunks != 1
235234
end
236235
end
237236

test/testers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# For some reason if these aren't defined here, then they are interpreted as closures
1+
# Test functions are defined here because, when defined where they are used, it is too easy to
2+
# mistakenly create closures over variables that only share names by coincidence.
23
# Test helper: returns `x` unchanged, unless `err` is true (the default), in which case it
# throws an error with the message "futestkws_err".
function futestkws(x; err=true)
    if err
        error("futestkws_err")
    end
    return x
end
34

45
# Test helper: returns `x` (ignoring `y`), unless `err` is true (the default), in which case
# it throws an error with the message "fbtestkws_err".
function fbtestkws(x, y; err=true)
    if err
        error("fbtestkws_err")
    end
    return x
end
@@ -268,7 +269,6 @@ end
268269
return first(x), first_pullback
269270
end
270271

271-
#CTuple{N} = Tangent{NTuple{N, Float64}} # shorter for testing
272272
@testset "test_frule" begin
273273
test_frule(first, (2.0, 3.0))
274274
test_frule(first, Tuple(randn(4)))

0 commit comments

Comments
 (0)