4
4
# Note that this must work well both on Differential types and Primal types
5
5
6
6
"""
7
- check_equal (actual, expected, [msg]; kwargs...)
7
+ test_approx (actual, expected, [msg]; kwargs...)
8
8
9
9
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
10
10
are shown on failures.
@@ -15,7 +15,7 @@ give bread-crumbs into nested structures.
15
15
16
16
All keyword arguments are passed to `isapprox`.
17
17
"""
18
- function check_equal (
18
+ function test_approx (
19
19
actual:: Union{AbstractArray{<:Number},Number} ,
20
20
expected:: Union{AbstractArray{<:Number},Number} ,
21
21
msg= " " ;
@@ -25,26 +25,26 @@ function check_equal(
25
25
end
26
26
27
27
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
28
- @eval function check_equal (actual:: $T1 , expected:: $T2 , msg= " " ; kwargs... )
29
- return check_equal (unthunk (actual), unthunk (expected), msg; kwargs... )
28
+ @eval function test_approx (actual:: $T1 , expected:: $T2 , msg= " " ; kwargs... )
29
+ return test_approx (unthunk (actual), unthunk (expected), msg; kwargs... )
30
30
end
31
31
end
32
32
33
- check_equal (:: ZeroTangent , x, msg= " " ; kwargs... ) = check_equal (zero (x), x, msg; kwargs... )
34
- check_equal (x, :: ZeroTangent , msg= " " ; kwargs... ) = check_equal (x, zero (x), msg; kwargs... )
35
- check_equal (x:: ZeroTangent , y:: ZeroTangent , msg= " " ; kwargs... ) = @test true
33
+ test_approx (:: ZeroTangent , x, msg= " " ; kwargs... ) = test_approx (zero (x), x, msg; kwargs... )
34
+ test_approx (x, :: ZeroTangent , msg= " " ; kwargs... ) = test_approx (x, zero (x), msg; kwargs... )
35
+ test_approx (x:: ZeroTangent , y:: ZeroTangent , msg= " " ; kwargs... ) = @test true
36
36
37
37
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
38
- check_equal (x:: NoTangent , y:: Nothing , msg= " " ; kwargs... ) = @test true
39
- check_equal (x:: Nothing , y:: NoTangent , msg= " " ; kwargs... ) = @test true
38
+ test_approx (x:: NoTangent , y:: Nothing , msg= " " ; kwargs... ) = @test true
39
+ test_approx (x:: Nothing , y:: NoTangent , msg= " " ; kwargs... ) = @test true
40
40
41
41
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
42
42
# not yet been implemented
43
43
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
44
- check_equal (x:: ChainRulesCore.NotImplemented , y, msg= " " ; kwargs... ) = @test_broken x == y
45
- check_equal (x, y:: ChainRulesCore.NotImplemented , msg= " " ; kwargs... ) = @test_broken x == y
44
+ test_approx (x:: ChainRulesCore.NotImplemented , y, msg= " " ; kwargs... ) = @test_broken x == y
45
+ test_approx (x, y:: ChainRulesCore.NotImplemented , msg= " " ; kwargs... ) = @test_broken x == y
46
46
# In this case we check for equality (messages etc. have to be equal)
47
- function check_equal (
47
+ function test_approx (
48
48
x:: ChainRulesCore.NotImplemented , y:: ChainRulesCore.NotImplemented , msg= " " ; kwargs...
49
49
)
50
50
return @test_msg msg x == y
53
53
"""
54
54
_can_pass_early(actual, expected; kwargs...)
55
55
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
56
- and can just report `check_equal ` as passing.
56
+ and can just report `test_approx ` as passing.
57
57
58
58
If either `==` or `≈` return true then so does this.
59
59
The `kwargs` are passed on to `isapprox`
@@ -69,33 +69,33 @@ function _can_pass_early(actual, expected; kwargs...)
69
69
return false
70
70
end
71
71
72
- function check_equal (actual:: AbstractArray , expected:: AbstractArray , msg= " " ; kwargs... )
72
+ function test_approx (actual:: AbstractArray , expected:: AbstractArray , msg= " " ; kwargs... )
73
73
if _can_pass_early (actual, expected)
74
74
@test true
75
75
else
76
76
@test_msg " $msg : indices must match" eachindex (actual) == eachindex (expected)
77
77
for ii in eachindex (actual)
78
78
new_msg = " $msg $(typeof (actual)) [$ii ]"
79
- check_equal (actual[ii], expected[ii], new_msg; kwargs... )
79
+ test_approx (actual[ii], expected[ii], new_msg; kwargs... )
80
80
end
81
81
end
82
82
end
83
83
84
- function check_equal (actual:: Tangent{P} , expected:: Tangent{P} , msg= " " ; kwargs... ) where {P}
84
+ function test_approx (actual:: Tangent{P} , expected:: Tangent{P} , msg= " " ; kwargs... ) where {P}
85
85
if _can_pass_early (actual, expected)
86
86
@test true
87
87
else
88
88
all_keys = union (keys (actual), keys (expected))
89
89
for ii in all_keys
90
90
new_msg = " $msg $P .$ii "
91
- check_equal (
91
+ test_approx (
92
92
getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
93
93
)
94
94
end
95
95
end
96
96
end
97
97
98
- function check_equal (
98
+ function test_approx (
99
99
:: Tangent{ActualPrimal} , expected:: Tangent{ExpectedPrimal} , msg= " " ; kwargs...
100
100
) where {ActualPrimal,ExpectedPrimal}
101
101
# this will certainly fail as we have another dispatch for that, but this will give as
@@ -104,7 +104,7 @@ function check_equal(
104
104
end
105
105
106
106
# Some structual differential and a natural differential
107
- function check_equal (actual:: Tangent{P,T} , expected, msg= " " ; kwargs... ) where {T,P}
107
+ function test_approx (actual:: Tangent{P,T} , expected, msg= " " ; kwargs... ) where {T,P}
108
108
if _can_pass_early (actual, expected)
109
109
@test true
110
110
else
@@ -114,42 +114,42 @@ function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T
114
114
# the natural differential is allowed to have other properties that we ignore
115
115
for ii in propertynames (actual)
116
116
new_msg = " $msg $P .$ii "
117
- check_equal (
117
+ test_approx (
118
118
getproperty (actual, ii), getproperty (expected, ii), new_msg; kwargs...
119
119
)
120
120
end
121
121
end
122
122
end
123
- check_equal (x, y:: Tangent , msg= " " ; kwargs... ) = check_equal (y, x, msg; kwargs... )
123
+ test_approx (x, y:: Tangent , msg= " " ; kwargs... ) = test_approx (y, x, msg; kwargs... )
124
124
125
125
# This catches comparisons of Tangents and Tuples/NamedTuple
126
126
# and gives an error message complaining about that. the `@test` will definitely fail
127
127
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
128
- function check_equal (x:: Tangent , y:: LegacyZygoteCompTypes , msg= " " ; kwargs... )
128
+ function test_approx (x:: Tangent , y:: LegacyZygoteCompTypes , msg= " " ; kwargs... )
129
129
@test_msg " $msg : for structural differentials use `Tangent`" typeof (x) === typeof (y)
130
130
end
131
- function check_equal (x:: LegacyZygoteCompTypes , y:: Tangent , msg= " " ; kwargs... )
132
- return check_equal (y, x, msg; kwargs... )
131
+ function test_approx (x:: LegacyZygoteCompTypes , y:: Tangent , msg= " " ; kwargs... )
132
+ return test_approx (y, x, msg; kwargs... )
133
133
end
134
134
135
135
# Generic fallback, probably a tuple or something
136
- function check_equal (actual:: A , expected:: E , msg= " " ; kwargs... ) where {A,E}
136
+ function test_approx (actual:: A , expected:: E , msg= " " ; kwargs... ) where {A,E}
137
137
if _can_pass_early (actual, expected)
138
138
@test true
139
139
else
140
140
c_actual = collect (actual)
141
141
c_expected = collect (expected)
142
142
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
143
- throw (MethodError, check_equal , (actual, expected))
143
+ throw (MethodError, test_approx , (actual, expected))
144
144
end
145
- check_equal (c_actual, c_expected; kwargs... )
145
+ test_approx (c_actual, c_expected; kwargs... )
146
146
end
147
147
end
148
148
149
149
# ##########################################################################################
150
150
151
151
"""
152
- _check_add !!_behaviour(acc, val)
152
+ _test_add !!_behaviour(acc, val)
153
153
154
154
This checks that `acc + val` is the same as `add!!(acc, val)`.
155
155
It matters primarily for types that overload `add!!` such as `InplaceableThunk`s.
@@ -159,25 +159,25 @@ It matters primarily for types that overload `add!!` such as `InplaceableThunk`s
159
159
160
160
`kwargs` are all passed on to isapprox
161
161
"""
162
- function _check_add !!_behaviour (acc, val; kwargs... )
162
+ function _test_add !!_behaviour (acc, val; kwargs... )
163
163
# Note, we don't test that `acc` is actually mutated because it doesn't have to be
164
164
# e.g. if it is immutable. We do test the `add!!` return value.
165
165
# That is what people should rely on. The mutation is just to save allocations.
166
166
acc_mutated = deepcopy (acc) # prevent this test changing others
167
- return check_equal (add!! (acc_mutated, val), acc + val, " in add!!" ; kwargs... )
167
+ return test_approx (add!! (acc_mutated, val), acc + val, " in add!!" ; kwargs... )
168
168
end
169
169
170
170
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
171
171
# intentionally not yet been implemented
172
172
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
173
- function _check_add !!_behaviour (acc_mutated, acc:: ChainRulesCore.NotImplemented ; kwargs... )
173
+ function _test_add !!_behaviour (acc_mutated, acc:: ChainRulesCore.NotImplemented ; kwargs... )
174
174
return @test_broken acc_mutated == acc
175
175
end
176
- function _check_add !!_behaviour (acc_mutated:: ChainRulesCore.NotImplemented , acc; kwargs... )
176
+ function _test_add !!_behaviour (acc_mutated:: ChainRulesCore.NotImplemented , acc; kwargs... )
177
177
return @test_broken acc_mutated == acc
178
178
end
179
179
# In this case we check for equality (not implemented messages etc. have to be equal)
180
- function _check_add !!_behaviour (
180
+ function _test_add !!_behaviour (
181
181
acc_mutated:: ChainRulesCore.NotImplemented ,
182
182
acc:: ChainRulesCore.NotImplemented ;
183
183
kwargs... ,
0 commit comments