4
4
# Note that this must work well both on Differential types and Primal types
5
5
6
6
"""
7
- check_equal(actual, expected; kwargs...)
7
+ check_equal(actual, expected, [msg] ; kwargs...)
8
8
9
9
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
10
10
are shown on failures.
11
11
Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc.
12
+
13
+ If provided `msg` is printed on a failure. Often additional items are appended to `msg` to
14
+ give bread-crumbs into nested structures.
15
+
12
16
All keyword arguments are passed to `isapprox`.
13
17
"""
14
18
function check_equal (
15
19
actual:: Union{AbstractArray{<:Number},Number} ,
16
- expected:: Union{AbstractArray{<:Number},Number} ;
20
+ expected:: Union{AbstractArray{<:Number},Number} ,
21
+ msg= " " ;
17
22
kwargs... ,
18
23
)
19
- @test isapprox (actual, expected; kwargs... )
24
+ @test_msg msg isapprox (actual, expected; kwargs... )
20
25
end
21
26
22
27
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
23
- @eval function check_equal (actual:: $T1 , expected:: $T2 ; kwargs... )
24
- return check_equal (unthunk (actual), unthunk (expected); kwargs... )
28
+ @eval function check_equal (actual:: $T1 , expected:: $T2 , msg = " " ; kwargs... )
29
+ return check_equal (unthunk (actual), unthunk (expected), msg ; kwargs... )
25
30
end
26
31
end
27
32
28
- check_equal (:: ZeroTangent , x; kwargs... ) = check_equal (zero (x), x; kwargs... )
29
- check_equal (x, :: ZeroTangent ; kwargs... ) = check_equal (x, zero (x); kwargs... )
30
- check_equal (x:: ZeroTangent , y:: ZeroTangent ; kwargs... ) = @test true
33
+ check_equal (:: ZeroTangent , x, msg = " " ; kwargs... ) = check_equal (zero (x), x, msg ; kwargs... )
34
+ check_equal (x, :: ZeroTangent , msg = " " ; kwargs... ) = check_equal (x, zero (x), msg ; kwargs... )
35
+ check_equal (x:: ZeroTangent , y:: ZeroTangent , msg = " " ; kwargs... ) = @test true
31
36
32
37
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33
- check_equal (x:: NoTangent , y:: Nothing ; kwargs... ) = @test true
34
- check_equal (x:: Nothing , y:: NoTangent ; kwargs... ) = @test true
38
+ check_equal (x:: NoTangent , y:: Nothing , msg = " " ; kwargs... ) = @test true
39
+ check_equal (x:: Nothing , y:: NoTangent , msg = " " ; kwargs... ) = @test true
35
40
36
41
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
37
42
# not yet been implemented
38
43
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39
- check_equal (x:: ChainRulesCore.NotImplemented , y; kwargs... ) = @test_broken x == y
40
- check_equal (x, y:: ChainRulesCore.NotImplemented ; kwargs... ) = @test_broken x == y
44
+ check_equal (x:: ChainRulesCore.NotImplemented , y, msg = " " ; kwargs... ) = @test_broken x == y
45
+ check_equal (x, y:: ChainRulesCore.NotImplemented , msg = " " ; kwargs... ) = @test_broken x == y
41
46
# In this case we check for equality (messages etc. have to be equal)
42
47
function check_equal (
43
- x:: ChainRulesCore.NotImplemented , y:: ChainRulesCore.NotImplemented ; kwargs...
48
+ x:: ChainRulesCore.NotImplemented , y:: ChainRulesCore.NotImplemented , msg = " " ; kwargs...
44
49
)
45
- return @test x == y
50
+ return @test_msg msg x == y
46
51
end
47
52
48
53
"""
49
54
_can_pass_early(actual, expected; kwargs...)
50
- Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
55
+ Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
51
56
and can just report `check_equal` as passing.
52
57
53
58
If either `==` or `≈` return true then so does this.
@@ -64,60 +69,71 @@ function _can_pass_early(actual, expected; kwargs...)
64
69
return false
65
70
end
66
71
67
- function check_equal (actual:: AbstractArray , expected:: AbstractArray ; kwargs... )
72
+ function check_equal (actual:: AbstractArray , expected:: AbstractArray , msg = " " ; kwargs... )
68
73
if _can_pass_early (actual, expected)
69
74
@test true
70
75
else
71
- @test eachindex (actual) == eachindex (expected)
72
- @testset " $(typeof (actual)) [$ii ]" for ii in eachindex (actual)
73
- check_equal (actual[ii], expected[ii]; kwargs... )
76
+ @test_msg " $msg : indices must match" eachindex (actual) == eachindex (expected)
77
+ for ii in eachindex (actual)
78
+ new_msg = " $msg $(typeof (actual)) [$ii ]"
79
+ check_equal (actual[ii], expected[ii], new_msg; kwargs... )
74
80
end
75
81
end
76
82
end
77
83
78
- function check_equal (actual:: Tangent{P} , expected:: Tangent{P} ; kwargs... ) where {P}
84
+ function check_equal (actual:: Tangent{P} , expected:: Tangent{P} , msg = " " ; kwargs... ) where {P}
79
85
if _can_pass_early (actual, expected)
80
86
@test true
81
87
else
82
88
all_keys = union (keys (actual), keys (expected))
83
- @testset " $P .$ii " for ii in all_keys
84
- check_equal (getproperty (actual, ii), getproperty (expected, ii); kwargs... )
89
+ for ii in all_keys
90
+ new_msg = " $msg $P .$ii "
91
+ check_equal (
92
+ getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
93
+ )
85
94
end
86
95
end
87
96
end
88
97
89
98
function check_equal (
90
- :: Tangent{ActualPrimal} , expected:: Tangent{ExpectedPrimal} ; kwargs...
99
+ :: Tangent{ActualPrimal} , expected:: Tangent{ExpectedPrimal} , msg = " " ; kwargs...
91
100
) where {ActualPrimal,ExpectedPrimal}
92
101
# this will certainly fail as we have another dispatch for that, but this will give as
93
102
# good error message
94
103
@test ActualPrimal === ExpectedPrimal
95
104
end
96
105
97
106
# Some structual differential and a natural differential
98
- function check_equal (actual:: Tangent{P,T} , expected; kwargs... ) where {T,P}
107
+ function check_equal (actual:: Tangent{P,T} , expected, msg = " " ; kwargs... ) where {T,P}
99
108
if _can_pass_early (actual, expected)
100
109
@test true
101
110
else
102
111
@assert (T <: NamedTuple ) # it should be a structual differential if we hit this
103
112
104
113
# We are only checking the properties that are in the Tangent
105
114
# the natural differential is allowed to have other properties that we ignore
106
- @testset " $P .$ii " for ii in propertynames (actual)
107
- check_equal (getproperty (actual, ii), getproperty (expected, ii); kwargs... )
115
+ for ii in propertynames (actual)
116
+ new_msg = " $msg $P .$ii "
117
+ check_equal (
118
+ getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
119
+ )
108
120
end
109
121
end
110
122
end
111
- check_equal (x, y:: Tangent ; kwargs... ) = check_equal (y, x; kwargs... )
123
+ check_equal (x, y:: Tangent , msg = " " ; kwargs... ) = check_equal (y, x, msg ; kwargs... )
112
124
113
125
# This catches comparisons of Tangents and Tuples/NamedTuple
114
- # and gives an error message complaining about that
126
+ # and gives an error message complaining about that. the `@test` will definitely fail
115
127
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
116
- check_equal (:: C , :: T ; kwargs... ) where {C<: Tangent ,T<: LegacyZygoteCompTypes } = @test C === T
117
- check_equal (:: T , :: C ; kwargs... ) where {C<: Tangent ,T<: LegacyZygoteCompTypes } = @test T === C
128
+ function check_equal (x:: Tangent , y:: LegacyZygoteCompTypes , msg= " " ; kwargs... )
129
+ @test_msg " $msg : for structural differentials use `Tangent`" typeof (x) === typeof (y)
130
+ end
131
+ function check_equal (x:: LegacyZygoteCompTypes , y:: Tangent , msg= " " ; kwargs... )
132
+ return check_equal (y, x, msg; kwargs... )
133
+ end
118
134
119
135
# Generic fallback, probably a tuple or something
120
- function check_equal (actual:: A , expected:: E ; kwargs... ) where {A,E}
136
+ function check_equal (actual:: A , expected:: E , msg = " " ; kwargs... ) where {A,E}
121
137
if _can_pass_early (actual, expected)
122
138
@test true
123
139
else
@@ -130,6 +146,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E}
130
146
end
131
147
end
132
148
149
+ # ##########################################################################################
150
+
133
151
"""
134
152
_check_add!!_behaviour(acc, val)
135
153
@@ -146,19 +164,19 @@ function _check_add!!_behaviour(acc, val; kwargs...)
146
164
# e.g. if it is immutable. We do test the `add!!` return value.
147
165
# That is what people should rely on. The mutation is just to save allocations.
148
166
acc_mutated = deepcopy (acc) # prevent this test changing others
149
- return check_equal (add!! (acc_mutated, val), acc + val; kwargs... )
167
+ return check_equal (add!! (acc_mutated, val), acc + val, " in add!! " ; kwargs... )
150
168
end
151
169
152
- # Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
153
- # not yet been implemented
170
+ # Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
171
+ # intentionally not yet been implemented
154
172
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
155
173
function _check_add!!_behaviour (acc_mutated, acc:: ChainRulesCore.NotImplemented ; kwargs... )
156
174
return @test_broken acc_mutated == acc
157
175
end
158
176
function _check_add!!_behaviour (acc_mutated:: ChainRulesCore.NotImplemented , acc; kwargs... )
159
177
return @test_broken acc_mutated == acc
160
178
end
161
- # In this case we check for equality (messages etc. have to be equal)
179
+ # In this case we check for equality (not implemented messages etc. have to be equal)
162
180
function _check_add!!_behaviour (
163
181
acc_mutated:: ChainRulesCore.NotImplemented ,
164
182
acc:: ChainRulesCore.NotImplemented ;
0 commit comments