Skip to content

Commit c405d75

Browse files
authored
Merge pull request #161 from JuliaDiff/ox/check_equal
Rename check_equals to test_approx
2 parents e061a82 + f3e2f7d commit c405d75

File tree

8 files changed

+90
-81
lines changed

8 files changed

+90
-81
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.14"
3+
version = "0.6.15"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ Inserting inappropriate zeros can thus hide errors.
124124

125125
## Custom finite differencing
126126

127-
If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `check_equal` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
127+
If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `test_approx` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
128128

129129
It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `ZeroTangent()`, `NoTangent()`, and `Tangent`, such that the error messages are helpful.
130130

131131
For example,
132132
```julia
133-
check_equal((@thunk 2*2.0), 4.1)
133+
test_approx((@thunk 2*2.0), 4.1)
134134
```
135135
shows both the expression and the evaluated `thunk`s
136136
```julia

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import FiniteDifferences: rand_tangent
1414
const _fdm = central_fdm(5, 1; max_range=1e-2)
1515

1616
export TestIterator
17-
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
17+
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1818
export
1919

2020

src/check_result.jl

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Note that this must work well both on Differential types and Primal types
55

66
"""
7-
check_equal(actual, expected, [msg]; kwargs...)
7+
test_approx(actual, expected, [msg]; kwargs...)
88
99
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
1010
are shown on failures.
@@ -15,7 +15,7 @@ give bread-crumbs into nested structures.
1515
1616
All keyword arguments are passed to `isapprox`.
1717
"""
18-
function check_equal(
18+
function test_approx(
1919
actual::Union{AbstractArray{<:Number},Number},
2020
expected::Union{AbstractArray{<:Number},Number},
2121
msg="";
@@ -25,26 +25,26 @@ function check_equal(
2525
end
2626

2727
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
28-
@eval function check_equal(actual::$T1, expected::$T2, msg=""; kwargs...)
29-
return check_equal(unthunk(actual), unthunk(expected), msg; kwargs...)
28+
@eval function test_approx(actual::$T1, expected::$T2, msg=""; kwargs...)
29+
return test_approx(unthunk(actual), unthunk(expected), msg; kwargs...)
3030
end
3131
end
3232

33-
check_equal(::ZeroTangent, x, msg=""; kwargs...) = check_equal(zero(x), x, msg; kwargs...)
34-
check_equal(x, ::ZeroTangent, msg=""; kwargs...) = check_equal(x, zero(x), msg; kwargs...)
35-
check_equal(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
33+
test_approx(::ZeroTangent, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
34+
test_approx(x, ::ZeroTangent, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
35+
test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
3636

3737
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
38-
check_equal(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
39-
check_equal(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true
38+
test_approx(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
39+
test_approx(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true
4040

4141
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
4242
# not yet been implemented
4343
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
44-
check_equal(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y
45-
check_equal(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y
44+
test_approx(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y
45+
test_approx(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y
4646
# In this case we check for equality (messages etc. have to be equal)
47-
function check_equal(
47+
function test_approx(
4848
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented, msg=""; kwargs...
4949
)
5050
return @test_msg msg x == y
@@ -53,7 +53,7 @@ end
5353
"""
5454
_can_pass_early(actual, expected; kwargs...)
5555
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
56-
and can just report `check_equal` as passing.
56+
and can just report `test_approx` as passing.
5757
5858
If either `==` or `≈` return true then so does this.
5959
The `kwargs` are passed on to `isapprox`
@@ -69,33 +69,33 @@ function _can_pass_early(actual, expected; kwargs...)
6969
return false
7070
end
7171

72-
function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...)
72+
function test_approx(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...)
7373
if _can_pass_early(actual, expected)
7474
@test true
7575
else
7676
@test_msg "$msg: indices must match" eachindex(actual) == eachindex(expected)
7777
for ii in eachindex(actual)
7878
new_msg = "$msg $(typeof(actual))[$ii]"
79-
check_equal(actual[ii], expected[ii], new_msg; kwargs...)
79+
test_approx(actual[ii], expected[ii], new_msg; kwargs...)
8080
end
8181
end
8282
end
8383

84-
function check_equal(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P}
84+
function test_approx(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P}
8585
if _can_pass_early(actual, expected)
8686
@test true
8787
else
8888
all_keys = union(keys(actual), keys(expected))
8989
for ii in all_keys
9090
new_msg = "$msg $P.$ii"
91-
check_equal(
91+
test_approx(
9292
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
9393
)
9494
end
9595
end
9696
end
9797

98-
function check_equal(
98+
function test_approx(
9999
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}, msg=""; kwargs...
100100
) where {ActualPrimal,ExpectedPrimal}
101101
# this will certainly fail as we have another dispatch for that, but this will give as
@@ -104,7 +104,7 @@ function check_equal(
104104
end
105105

106106
# Some structual differential and a natural differential
107-
function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P}
107+
function test_approx(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P}
108108
if _can_pass_early(actual, expected)
109109
@test true
110110
else
@@ -114,42 +114,42 @@ function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T
114114
# the natural differential is allowed to have other properties that we ignore
115115
for ii in propertynames(actual)
116116
new_msg = "$msg $P.$ii"
117-
check_equal(
117+
test_approx(
118118
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
119119
)
120120
end
121121
end
122122
end
123-
check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...)
123+
test_approx(x, y::Tangent, msg=""; kwargs...) = test_approx(y, x, msg; kwargs...)
124124

125125
# This catches comparisons of Tangents and Tuples/NamedTuple
126126
# and gives an error message complaining about that. the `@test` will definitely fail
127127
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
128-
function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...)
128+
function test_approx(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...)
129129
@test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y)
130130
end
131-
function check_equal(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...)
132-
return check_equal(y, x, msg; kwargs...)
131+
function test_approx(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...)
132+
return test_approx(y, x, msg; kwargs...)
133133
end
134134

135135
# Generic fallback, probably a tuple or something
136-
function check_equal(actual::A, expected::E, msg=""; kwargs...) where {A,E}
136+
function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
137137
if _can_pass_early(actual, expected)
138138
@test true
139139
else
140140
c_actual = collect(actual)
141141
c_expected = collect(expected)
142142
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
143-
throw(MethodError, check_equal, (actual, expected))
143+
throw(MethodError, test_approx, (actual, expected))
144144
end
145-
check_equal(c_actual, c_expected; kwargs...)
145+
test_approx(c_actual, c_expected; kwargs...)
146146
end
147147
end
148148

149149
###########################################################################################
150150

151151
"""
152-
_check_add!!_behaviour(acc, val)
152+
_test_add!!_behaviour(acc, val)
153153
154154
This checks that `acc + val` is the same as `add!!(acc, val)`.
155155
It matters primarily for types that overload `add!!` such as `InplaceableThunk`s.
@@ -159,25 +159,25 @@ It matters primarily for types that overload `add!!` such as `InplaceableThunk`s
159159
160160
`kwargs` are all passed on to isapprox
161161
"""
162-
function _check_add!!_behaviour(acc, val; kwargs...)
162+
function _test_add!!_behaviour(acc, val; kwargs...)
163163
# Note, we don't test that `acc` is actually mutated because it doesn't have to be
164164
# e.g. if it is immutable. We do test the `add!!` return value.
165165
# That is what people should rely on. The mutation is just to save allocations.
166166
acc_mutated = deepcopy(acc) # prevent this test changing others
167-
return check_equal(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...)
167+
return test_approx(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...)
168168
end
169169

170170
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
171171
# intentionally not yet been implemented
172172
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
173-
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
173+
function _test_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
174174
return @test_broken acc_mutated == acc
175175
end
176-
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
176+
function _test_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
177177
return @test_broken acc_mutated == acc
178178
end
179179
# In this case we check for equality (not implemented messages etc. have to be equal)
180-
function _check_add!!_behaviour(
180+
function _test_add!!_behaviour(
181181
acc_mutated::ChainRulesCore.NotImplemented,
182182
acc::ChainRulesCore.NotImplemented;
183183
kwargs...,

src/deprecated.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# TODO remove these in version 0.6
1+
# TODO remove these in version 0.7
22
# We are silently deprecating them as there is no alternative we are providing
33

44
function Base.isapprox(a, b::Union{AbstractZero,AbstractThunk}; kwargs...)
@@ -35,6 +35,8 @@ end
3535
# Must be for same primal
3636
Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P,Q} = false
3737

38+
###############################################
39+
3840
# From when primal and tangent was passed as a tuple
3941
@deprecate(
4042
rrule_test(f, ȳ, inputs::Tuple{Any,Any}...; kwargs...),
@@ -45,3 +47,6 @@ Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P,Q} = false
4547
frule_test(f, inputs::Tuple{Any,Any}...; kwargs...),
4648
test_frule(f, ((x dx) for (x, dx) in inputs)...; kwargs...)
4749
)
50+
51+
# renamed
52+
Base.@deprecate_binding check_equal test_approx

src/testers.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
3232
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
3333
_, real_tangent = frule((ZeroTangent(), real(Δx)), f, z; fkwargs...)
3434
_, embedded_tangent = frule((ZeroTangent(), Δx), f, z; fkwargs...)
35-
check_equal(real_tangent, embedded_tangent; isapprox_kwargs...)
35+
test_approx(real_tangent, embedded_tangent; isapprox_kwargs...)
3636
end
3737
end
3838
if z isa Complex
@@ -53,7 +53,7 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
5353
_, back = rrule(f, z)
5454
_, real_cotangent = back(real(Δu))
5555
_, embedded_cotangent = back(Δu)
56-
check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...)
56+
test_approx(real_cotangent, embedded_cotangent; isapprox_kwargs...)
5757
end
5858
end
5959
if Ω isa Complex
@@ -113,7 +113,7 @@ function test_frule(
113113
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
114114
Ω_ad, dΩ_ad = res
115115
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
116-
check_equal(Ω_ad, Ω; isapprox_kwargs...)
116+
test_approx(Ω_ad, Ω; isapprox_kwargs...)
117117

118118
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
119119
ẋs_is_ignored = isa.(ẋs, Union{Nothing,NoTangent})
@@ -127,10 +127,10 @@ function test_frule(
127127

128128
# Correctness testing via finite differencing.
129129
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
130-
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
130+
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)
131131

132132
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
133-
_check_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
133+
_test_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
134134
end # top-level testset
135135
end
136136

@@ -181,7 +181,7 @@ function test_rrule(
181181
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
182182
y_ad, pullback = res
183183
y = f(xs...; fkwargs...)
184-
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct
184+
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
185185

186186
= output_tangent isa Auto ? rand_tangent(y) : output_tangent
187187

@@ -216,8 +216,8 @@ function test_rrule(
216216
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)
217217

218218
# The main test of the actual deriviative being correct:
219-
check_equal(x̄_ad, x̄_fd; isapprox_kwargs...)
220-
_check_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
219+
test_approx(x̄_ad, x̄_fd; isapprox_kwargs...)
220+
_test_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
221221
end
222222
end
223223

0 commit comments

Comments
 (0)