Skip to content

Commit 27a1933

Browse files
authored
Merge pull request #156 from JuliaDiff/ox/clean
format the code
2 parents 0095fb2 + d72b146 commit 27a1933

12 files changed

+236
-226
lines changed

docs/make.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
using ChainRulesTestUtils
22
using Documenter
33

4-
5-
makedocs(
4+
makedocs(;
65
modules=[ChainRulesTestUtils],
7-
format=Documenter.HTML(prettyurls=false, assets=["assets/chainrules.css"]),
6+
format=Documenter.HTML(; prettyurls=false, assets=["assets/chainrules.css"]),
87
sitename="ChainRulesTestUtils",
98
authors="JuliaDiff contributors",
109
strict=true,
1110
checkdocs=:exports,
1211
)
1312

1413
const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
15-
deploydocs(
16-
repo=repo,
17-
push_preview=true,
18-
)
14+
deploydocs(; repo=repo, push_preview=true)

src/ChainRulesTestUtils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ export TestIterator
1717
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1818
export
1919

20-
2120
include("generate_tangent.jl")
2221
include("data_generation.jl")
2322
include("iterator.jl")

src/check_result.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc.
1212
All keyword arguments are passed to `isapprox`.
1313
"""
1414
function check_equal(
15-
actual::Union{AbstractArray{<:Number}, Number},
16-
expected::Union{AbstractArray{<:Number}, Number};
17-
kwargs...
15+
actual::Union{AbstractArray{<:Number},Number},
16+
expected::Union{AbstractArray{<:Number},Number};
17+
kwargs...,
1818
)
1919
@test isapprox(actual, expected; kwargs...)
2020
end
2121

2222
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
2323
@eval function check_equal(actual::$T1, expected::$T2; kwargs...)
24-
check_equal(unthunk(actual), unthunk(expected); kwargs...)
24+
return check_equal(unthunk(actual), unthunk(expected); kwargs...)
2525
end
2626
end
2727

@@ -75,7 +75,7 @@ function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
7575
end
7676
end
7777

78-
function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where P
78+
function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where {P}
7979
if _can_pass_early(actual, expected)
8080
@test true
8181
else
@@ -88,15 +88,14 @@ end
8888

8989
function check_equal(
9090
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs...
91-
) where {ActualPrimal, ExpectedPrimal}
91+
) where {ActualPrimal,ExpectedPrimal}
9292
# this will certainly fail as we have another dispatch for that, but this will give as
9393
# good error message
9494
@test ActualPrimal === ExpectedPrimal
9595
end
9696

97-
9897
# Some structual differential and a natural differential
99-
function check_equal(actual::Tangent{P, T}, expected; kwargs...) where {T, P}
98+
function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P}
10099
if _can_pass_early(actual, expected)
101100
@test true
102101
else
@@ -118,7 +117,7 @@ check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @
118117
check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C
119118

120119
# Generic fallback, probably a tuple or something
121-
function check_equal(actual::A, expected::E; kwargs...) where {A, E}
120+
function check_equal(actual::A, expected::E; kwargs...) where {A,E}
122121
if _can_pass_early(actual, expected)
123122
@test true
124123
else
@@ -147,7 +146,7 @@ function _check_add!!_behaviour(acc, val; kwargs...)
147146
# e.g. if it is immutable. We do test the `add!!` return value.
148147
# That is what people should rely on. The mutation is just to save allocations.
149148
acc_mutated = deepcopy(acc) # prevent this test changing others
150-
check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
149+
return check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
151150
end
152151

153152
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
@@ -161,7 +160,8 @@ function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc;
161160
end
162161
# In this case we check for equality (messages etc. have to be equal)
163162
function _check_add!!_behaviour(
164-
acc_mutated::ChainRulesCore.NotImplemented, acc::ChainRulesCore.NotImplemented;
163+
acc_mutated::ChainRulesCore.NotImplemented,
164+
acc::ChainRulesCore.NotImplemented;
165165
kwargs...,
166166
)
167167
return @test acc_mutated == acc

src/deprecated.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,39 @@
11
# TODO remove these in version 0.6
22
# We are silently deprecating them as there is no alternative we are providing
33

4-
Base.isapprox(a, b::Union{AbstractZero, AbstractThunk}; kwargs...) = isapprox(b, a; kwargs...)
5-
Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
6-
Base.isapprox(d_ad::NoTangent, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `NoTangent`")
4+
function Base.isapprox(a, b::Union{AbstractZero,AbstractThunk}; kwargs...)
5+
return isapprox(b, a; kwargs...)
6+
end
7+
function Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...)
8+
return isapprox(extern(d_ad), d_fd; kwargs...)
9+
end
10+
function Base.isapprox(d_ad::NoTangent, d_fd; kwargs...)
11+
return error("Tried to differentiate w.r.t. a `NoTangent`")
12+
end
713
# Call `all` to handle the case where `ZeroTangent` is standing in for a non-scalar zero
8-
Base.isapprox(d_ad::ZeroTangent, d_fd; kwargs...) = all(isapprox.(extern(d_ad), d_fd; kwargs...))
14+
function Base.isapprox(d_ad::ZeroTangent, d_fd; kwargs...)
15+
return all(isapprox.(extern(d_ad), d_fd; kwargs...))
16+
end
917

1018
isapprox_vec(a, b; kwargs...) = isapprox(first(to_vec(a)), first(to_vec(b)); kwargs...)
1119
Base.isapprox(a, b::Tangent; kwargs...) = isapprox(b, a; kwargs...)
1220
function Base.isapprox(d_ad::Tangent{<:Tuple}, d_fd::Tuple; kwargs...)
1321
return isapprox_vec(d_ad, d_fd; kwargs...)
1422
end
1523
function Base.isapprox(
16-
d_ad::Tangent{P, <:Tuple}, d_fd::Tangent{P, <:Tuple}; kwargs...
17-
) where {P <: Tuple}
24+
d_ad::Tangent{P,<:Tuple}, d_fd::Tangent{P,<:Tuple}; kwargs...
25+
) where {P<:Tuple}
1826
return isapprox_vec(d_ad, d_fd; kwargs...)
1927
end
2028

2129
function Base.isapprox(
22-
d_ad::Tangent{P, <:NamedTuple{T}}, d_fd::Tangent{P, <:NamedTuple{T}}; kwargs...,
23-
) where {P, T}
30+
d_ad::Tangent{P,<:NamedTuple{T}}, d_fd::Tangent{P,<:NamedTuple{T}}; kwargs...
31+
) where {P,T}
2432
return isapprox_vec(d_ad, d_fd; kwargs...)
2533
end
2634

27-
2835
# Must be for same primal
29-
Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P, Q} = false
30-
36+
Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P,Q} = false
3137

3238
# From when primal and tangent was passed as a tuple
3339
@deprecate(

src/finite_difference_calls.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ function _make_jvp_call(fdm, f, y, xs, ẋs, ignores)
1919
f2 = _wrap_function(f, xs, ignores)
2020

2121
ignores = collect(ignores)
22-
all(ignores) && return ntuple(_->nothing, length(xs))
22+
all(ignores) && return ntuple(_ -> nothing, length(xs))
2323
sigargs = zip(xs[.!ignores], ẋs[.!ignores])
2424
return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...))
2525
end
2626

27-
2827
"""
2928
_make_j′vp_call(fdm, f, ȳ, xs, ignores) -> Tuple
3029
@@ -89,10 +88,9 @@ function _wrap_function(f, xs, ignores)
8988
return fnew
9089
end
9190

92-
9391
# TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97
9492
# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple
9593
# is not a natural differential, because it doesn't overload +, so make it a Tangent.
9694
_maybe_fix_to_composite(::P, x::Tuple) where {P} = Tangent{P}(x...)
97-
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Tangent{P}(;x...)
95+
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Tangent{P}(; x...)
9896
_maybe_fix_to_composite(::Any, x) = x

src/testers.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,12 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
6060
Δv = one(Ω) * im
6161
@testset "with cotangent $Δv" begin
6262
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
63-
test_rrule(f, z Δx; output_tangent=Δv, rule_test_kwargs...)
63+
test_rrule(f, z Δx; output_tangent=Δv, rule_test_kwargs...)
6464
end
6565
end
6666
end # top-level testset
6767
end
6868

69-
7069
"""
7170
test_frule(f, inputs...; kwargs...)
7271
@@ -87,12 +86,15 @@ end
8786
- All remaining keyword arguments are passed to `isapprox`.
8887
"""
8988
function test_frule(
90-
f, inputs...;
89+
f,
90+
inputs...;
9191
output_tangent=Auto(),
9292
fdm=_fdm,
9393
check_inferred::Bool=true,
9494
fkwargs::NamedTuple=NamedTuple(),
95-
rtol::Real=1e-9, atol::Real=1e-9, kwargs...
95+
rtol::Real=1e-9,
96+
atol::Real=1e-9,
97+
kwargs...,
9698
)
9799
# To simplify some of the calls we make later lets group the kwargs for reuse
98100
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
@@ -114,12 +116,12 @@ function test_frule(
114116
check_equal(Ω_ad, Ω; isapprox_kwargs...)
115117

116118
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
117-
ẋs_is_ignored = isa.(ẋs, Union{Nothing, NoTangent})
119+
ẋs_is_ignored = isa.(ẋs, Union{Nothing,NoTangent})
118120
if any(ẋs .== nothing)
119121
Base.depwarn(
120122
"test_frule(f, k ⊢ nothing) is deprecated, use " *
121123
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
122-
:test_frule
124+
:test_frule,
123125
)
124126
end
125127

@@ -132,8 +134,6 @@ function test_frule(
132134
end # top-level testset
133135
end
134136

135-
136-
137137
"""
138138
test_rrule(f, inputs...; kwargs...)
139139
@@ -154,12 +154,15 @@ end
154154
- All remaining keyword arguments are passed to `isapprox`.
155155
"""
156156
function test_rrule(
157-
f, inputs...;
157+
f,
158+
inputs...;
158159
output_tangent=Auto(),
159160
fdm=_fdm,
160161
check_inferred::Bool=true,
161162
fkwargs::NamedTuple=NamedTuple(),
162-
rtol::Real=1e-9, atol::Real=1e-9, kwargs...
163+
rtol::Real=1e-9,
164+
atol::Real=1e-9,
165+
kwargs...,
163166
)
164167
# To simplify some of the calls we make later lets group the kwargs for reuse
165168
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
@@ -191,22 +194,22 @@ function test_rrule(
191194

192195
# Correctness testing via finite differencing.
193196
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
194-
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, NoTangent})
197+
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing,NoTangent})
195198
if any(accumulated_x̄ .== nothing)
196199
Base.depwarn(
197200
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
198201
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
199-
:test_rrule
202+
:test_rrule,
200203
)
201204
end
202205

203206
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
204207
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
205-
if accumulated_x̄ isa Union{Nothing, NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
208+
if accumulated_x̄ isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
206209
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
207210
x̄_ad isa ZeroTangent && error(
208211
"The pullback in the rrule for $f function should use NoTangent()" *
209-
" rather than ZeroTangent() for non-perturbable arguments."
212+
" rather than ZeroTangent() for non-perturbable arguments.",
210213
)
211214
@test x̄_ad isa NoTangent # we said it wasn't differentiable.
212215
else
@@ -224,8 +227,8 @@ end
224227

225228
function check_thunking_is_appropriate(x̄s)
226229
@testset "Don't thunk only non_zero argument" begin
227-
num_zeros = count(x->x isa AbstractZero, x̄s)
228-
num_thunks = count(x->x isa Thunk, x̄s)
230+
num_zeros = count(x -> x isa AbstractZero, x̄s)
231+
num_thunks = count(x -> x isa Thunk, x̄s)
229232
if num_zeros + num_thunks == length(x̄s)
230233
@test num_thunks !== 1
231234
end
@@ -235,12 +238,10 @@ end
235238
function _ensure_not_running_on_functor(f, name)
236239
# if x itself is a Type, then it is a constructor, thus not a functor.
237240
# This also catchs UnionAll constructors which have a `:var` and `:body` fields
238-
f isa Type && return
241+
f isa Type && return nothing
239242

240243
if fieldcount(typeof(f)) > 0
241-
throw(ArgumentError(
242-
"$name cannot be used on closures/functors (such as $f)"
243-
))
244+
throw(ArgumentError("$name cannot be used on closures/functors (such as $f)"))
244245
end
245246
end
246247

0 commit comments

Comments
 (0)