4
4
# Note that this must work well both on Differential types and Primal types
5
5
6
6
"""
7
- check_equal(actual, expected; kwargs...)
7
+ check_equal(actual, expected, [msg] ; kwargs...)
8
8
9
9
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
10
10
are shown on failures.
11
11
Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc.
12
+
13
+ If provided `msg` is printed on a failure. Often additional items are appended to `msg` to
14
+ give bread-crumbs into nested structures.
15
+
12
16
All keyword arguments are passed to `isapprox`.
13
17
"""
14
18
function check_equal (
15
19
actual:: Union{AbstractArray{<:Number},Number} ,
16
- expected:: Union{AbstractArray{<:Number},Number} ;
20
+ expected:: Union{AbstractArray{<:Number},Number} ,
21
+ msg= " " ,
22
+ ;
17
23
kwargs... ,
18
24
)
19
- @test isapprox (actual, expected; kwargs... )
25
+ @test_msg msg isapprox (actual, expected; kwargs... )
20
26
end
21
27
22
28
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
23
- @eval function check_equal (actual:: $T1 , expected:: $T2 ; kwargs... )
24
- return check_equal (unthunk (actual), unthunk (expected); kwargs... )
29
+ @eval function check_equal (actual:: $T1 , expected:: $T2 , msg = " " ; kwargs... )
30
+ return check_equal (unthunk (actual), unthunk (expected), msg ; kwargs... )
25
31
end
26
32
end
27
33
28
- check_equal (:: ZeroTangent , x; kwargs... ) = check_equal (zero (x), x; kwargs... )
29
- check_equal (x, :: ZeroTangent ; kwargs... ) = check_equal (x, zero (x); kwargs... )
30
- check_equal (x:: ZeroTangent , y:: ZeroTangent ; kwargs... ) = @test true
34
+ check_equal (:: ZeroTangent , x, msg = " " ; kwargs... ) = check_equal (zero (x), x, msg ; kwargs... )
35
+ check_equal (x, :: ZeroTangent , msg = " " ; kwargs... ) = check_equal (x, zero (x), msg ; kwargs... )
36
+ check_equal (x:: ZeroTangent , y:: ZeroTangent , msg = " " ; kwargs... ) = @test true
31
37
32
38
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33
- check_equal (x:: NoTangent , y:: Nothing ; kwargs... ) = @test true
34
- check_equal (x:: Nothing , y:: NoTangent ; kwargs... ) = @test true
39
+ check_equal (x:: NoTangent , y:: Nothing , msg = " " ; kwargs... ) = @test true
40
+ check_equal (x:: Nothing , y:: NoTangent , msg = " " ; kwargs... ) = @test true
35
41
36
42
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
37
43
# not yet been implemented
38
44
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39
- check_equal (x:: ChainRulesCore.NotImplemented , y; kwargs... ) = @test_broken x == y
40
- check_equal (x, y:: ChainRulesCore.NotImplemented ; kwargs... ) = @test_broken x == y
45
+ check_equal (x:: ChainRulesCore.NotImplemented , y, msg = " " ; kwargs... ) = @test_broken x == y
46
+ check_equal (x, y:: ChainRulesCore.NotImplemented , msg = " " ; kwargs... ) = @test_broken x == y
41
47
# In this case we check for equality (messages etc. have to be equal)
42
48
function check_equal (
43
- x:: ChainRulesCore.NotImplemented , y:: ChainRulesCore.NotImplemented ; kwargs...
49
+ x:: ChainRulesCore.NotImplemented , y:: ChainRulesCore.NotImplemented , msg = " " ; kwargs...
44
50
)
45
- return @test x == y
51
+ return @test_msg msg x == y
46
52
end
47
53
48
54
"""
49
55
_can_pass_early(actual, expected; kwargs...)
50
- Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
56
+ Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
51
57
and can just report `check_equal` as passing.
52
58
53
59
If either `==` or `≈` return true then so does this.
@@ -64,60 +70,71 @@ function _can_pass_early(actual, expected; kwargs...)
64
70
return false
65
71
end
66
72
67
- function check_equal (actual:: AbstractArray , expected:: AbstractArray ; kwargs... )
73
+ function check_equal (actual:: AbstractArray , expected:: AbstractArray , msg = " " ; kwargs... )
68
74
if _can_pass_early (actual, expected)
69
75
@test true
70
76
else
71
- @test eachindex (actual) == eachindex (expected)
72
- @testset " $(typeof (actual)) [$ii ]" for ii in eachindex (actual)
73
- check_equal (actual[ii], expected[ii]; kwargs... )
77
+ @test_msg " $msg : indexes must match" eachindex (actual) == eachindex (expected)
78
+ for ii in eachindex (actual)
79
+ new_msg = " $msg $(typeof (actual)) [$ii ]"
80
+ check_equal (actual[ii], expected[ii], new_msg; kwargs... )
74
81
end
75
82
end
76
83
end
77
84
78
- function check_equal (actual:: Tangent{P} , expected:: Tangent{P} ; kwargs... ) where {P}
85
+ function check_equal (actual:: Tangent{P} , expected:: Tangent{P} , msg = " " ; kwargs... ) where {P}
79
86
if _can_pass_early (actual, expected)
80
87
@test true
81
88
else
82
89
all_keys = union (keys (actual), keys (expected))
83
- @testset " $P .$ii " for ii in all_keys
84
- check_equal (getproperty (actual, ii), getproperty (expected, ii); kwargs... )
90
+ for ii in all_keys
91
+ new_msg = " $msg $P .$ii "
92
+ check_equal (
93
+ getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
94
+ )
85
95
end
86
96
end
87
97
end
88
98
89
99
function check_equal (
90
- :: Tangent{ActualPrimal} , expected:: Tangent{ExpectedPrimal} ; kwargs...
100
+ :: Tangent{ActualPrimal} , expected:: Tangent{ExpectedPrimal} , msg = " " ; kwargs...
91
101
) where {ActualPrimal,ExpectedPrimal}
92
102
# this will certainly fail as we have another dispatch for that, but this will give as
93
103
# good error message
94
104
@test ActualPrimal === ExpectedPrimal
95
105
end
96
106
97
107
# Some structual differential and a natural differential
98
- function check_equal (actual:: Tangent{P,T} , expected; kwargs... ) where {T,P}
108
+ function check_equal (actual:: Tangent{P,T} , expected, msg = " " ; kwargs... ) where {T,P}
99
109
if _can_pass_early (actual, expected)
100
110
@test true
101
111
else
102
112
@assert (T <: NamedTuple ) # it should be a structual differential if we hit this
103
113
104
114
# We are only checking the properties that are in the Tangent
105
115
# the natural differential is allowed to have other properties that we ignore
106
- @testset " $P .$ii " for ii in propertynames (actual)
107
- check_equal (getproperty (actual, ii), getproperty (expected, ii); kwargs... )
116
+ for ii in propertynames (actual)
117
+ new_msg = " $msg $P .$ii "
118
+ check_equal (
119
+ getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
120
+ )
108
121
end
109
122
end
110
123
end
111
- check_equal (x, y:: Tangent ; kwargs... ) = check_equal (y, x; kwargs... )
124
+ check_equal (x, y:: Tangent , msg = " " ; kwargs... ) = check_equal (y, x, msg ; kwargs... )
112
125
113
126
# This catches comparisons of Tangents and Tuples/NamedTuple
114
- # and gives an error message complaining about that
127
+ # and gives an error message complaining about that. the `@test` will definately fail
115
128
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
116
- check_equal (:: C , :: T ; kwargs... ) where {C<: Tangent ,T<: LegacyZygoteCompTypes } = @test C === T
117
- check_equal (:: T , :: C ; kwargs... ) where {C<: Tangent ,T<: LegacyZygoteCompTypes } = @test T === C
129
+ function check_equal (x:: Tangent , y:: LegacyZygoteCompTypes , msg= " " ; kwargs... )
130
+ @test_msg " $msg : for structural differentials use `Tangent`" typeof (x) === typeof (y)
131
+ end
132
+ function check_equal (x:: LegacyZygoteCompTypes , y:: Tangent , msg= " " ; kwargs... )
133
+ return check_equal (y, x, msg; kwargs... )
134
+ end
118
135
119
136
# Generic fallback, probably a tuple or something
120
- function check_equal (actual:: A , expected:: E ; kwargs... ) where {A,E}
137
+ function check_equal (actual:: A , expected:: E , msg = " " ; kwargs... ) where {A,E}
121
138
if _can_pass_early (actual, expected)
122
139
@test true
123
140
else
@@ -130,6 +147,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E}
130
147
end
131
148
end
132
149
150
+ # ##########################################################################################
151
+
133
152
"""
134
153
_check_add!!_behaviour(acc, val)
135
154
@@ -146,19 +165,19 @@ function _check_add!!_behaviour(acc, val; kwargs...)
146
165
# e.g. if it is immutable. We do test the `add!!` return value.
147
166
# That is what people should rely on. The mutation is just to save allocations.
148
167
acc_mutated = deepcopy (acc) # prevent this test changing others
149
- return check_equal (add!! (acc_mutated, val), acc + val; kwargs... )
168
+ return check_equal (add!! (acc_mutated, val), acc + val, " in add!! " ; kwargs... )
150
169
end
151
170
152
- # Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
153
- # not yet been implemented
171
+ # Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
172
+ # intentionally not yet been implemented
154
173
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
155
174
function _check_add!!_behaviour (acc_mutated, acc:: ChainRulesCore.NotImplemented ; kwargs... )
156
175
return @test_broken acc_mutated == acc
157
176
end
158
177
function _check_add!!_behaviour (acc_mutated:: ChainRulesCore.NotImplemented , acc; kwargs... )
159
178
return @test_broken acc_mutated == acc
160
179
end
161
- # In this case we check for equality (messages etc. have to be equal)
180
+ # In this case we check for equality (not implemented messages etc. have to be equal)
162
181
function _check_add!!_behaviour (
163
182
acc_mutated:: ChainRulesCore.NotImplemented ,
164
183
acc:: ChainRulesCore.NotImplemented ;
0 commit comments