Skip to content

Commit 8c5da46

Browse files
authored
test arbitrary functions with f/rrule-like API (#166)
1 parent 83f6ad3 commit 8c5da46

File tree

10 files changed

+183
-30
lines changed

10 files changed

+183
-30
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.7.8"
3+
version = "0.7.9"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.10"
14+
ChainRulesCore = "0.10.4"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

docs/Manifest.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
14+
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.2"
16+
version = "0.10.4"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2020
path = ".."
2121
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.7.5"
22+
version = "0.7.9"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -40,10 +40,10 @@ deps = ["Random", "Serialization", "Sockets"]
4040
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
4141

4242
[[DocStringExtensions]]
43-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
44-
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
43+
deps = ["LibGit2"]
44+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
4545
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
46-
version = "0.8.4"
46+
version = "0.8.5"
4747

4848
[[Documenter]]
4949
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
@@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
5757

5858
[[FiniteDifferences]]
5959
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60-
git-tree-sha1 = "5d448db3b862fb331d20144c2e59c54db69720e0"
60+
git-tree-sha1 = "bdc9fb1d27a1ccecd2fe8f39c6211524cbe642cb"
6161
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62-
version = "0.12.12"
62+
version = "0.12.13"
6363

6464
[[IOCapture]]
6565
deps = ["Logging"]

docs/make.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ makedocs(;
66
format=Documenter.HTML(; prettyurls=false, assets=["assets/chainrules.css"]),
77
sitename="ChainRulesTestUtils",
88
authors="JuliaDiff contributors",
9+
pages=[
10+
"ChainRulesTestUtils" => "index.md",
11+
"API" => "api.md",
12+
],
913
strict=true,
1014
checkdocs=:exports,
1115
)

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# API Documentation
2+
3+
```@autodocs
4+
Modules = [ChainRulesTestUtils]
5+
Private = false
6+
```

docs/src/index.md

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,67 @@ In particular, when specifying the input tangents to [`test_frule`](@ref) and th
145145
As these tangents are used to seed the derivative computation.
146146
Inserting inappropriate zeros can thus hide errors.
147147

148+
## Testing higher order functions
149+
150+
Higher order functions, such as `map`, take a function (or a functor) `f` as an argument.
151+
`f/rrule`s for these functions call back into AD to compute the `f/rrule` of `f`.
152+
To test these functions, we use a dummy AD system, which simply calls the appropriate rule for `f` directly.
153+
For this reason, when testing `map(f, collection)`, the rules for `f` need to be defined.
154+
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
155+
```julia
156+
test_rrule(map, x->2x [1, 2, 3.]) # fails, because there is no rrule for x->2x
157+
158+
mydouble(x) = 2x
159+
function ChainRulesCore.rrule(::typeof(mydouble), x)
160+
mydouble_pullback(ȳ) = (NoTangent(), ȳ)
161+
return mydouble(x), mydouble_pullback
162+
end
163+
test_rrule(map, mydouble, [1, 2, 3.]) # works
164+
```
165+
166+
## Testing AD systems
167+
168+
The gradients computed by AD systems can be also be tested using `test_rrule`.
169+
To do that, one needs to provide an `rrule_f`/`frule_f` keyword argument, as well as the `RuleConfig` used by the AD system.
170+
`rrule_f` is a function that wraps the gradient computation by an AD system in the same API as the `rrule`.
171+
`RuleConfig` is an object that determines which sets of rules are defined for an AD system.
172+
For example, let's say we have a complicated function
173+
174+
```julia
175+
function complicated(x, y)
176+
return do(x + y) + some(x) * hard(y) + maths(x * y)
177+
end
178+
```
179+
180+
that we do not know an `rrule` for, and we want to check whether the gradients provided by the AD system are correct.
181+
182+
Firstly, we need to define an `rrule`-like function which wraps the gradients computed by AD.
183+
184+
Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself.
185+
In order to make the pullback compatible with the `rrule` API we need to add a `NoTangent()` to represent the differential w.r.t. the function itself.
186+
We also need to transform the `ChainRules` differential types to the custom types (`cr2custom`) before feeding the `Δ` to the AD-generated pullback, and back to `ChainRules` differential types when returning from the `rrule` (`custom2cr`).
187+
188+
```julia
189+
function ad_rrule(f::Function, args...)
190+
y, ad_pullback = ADSystem.pullback(f, args...)
191+
function rrulelike_pullback(Δ)
192+
diffs = custom2cr(ad_pullback(cr2custom(Δ)))
193+
return NoTangent(), diffs...
194+
end
195+
196+
return y, rrulelike_pullback
197+
end
198+
199+
custom2cr(differential) = ...
200+
cr2custom(differential) = ...
201+
```
202+
Secondly, we use the `test_rrule` function to test the gradients using the config used by the AD system
203+
```julia
204+
config = MyAD.CustomRuleConfig()
205+
test_rrule(config, complicated, 2.3, 6.1; rrule_f=ad_rrule)
206+
```
207+
by specifying the `ad_rrule` as the `rrule_f` keyword argument.
208+
148209
## Custom finite differencing
149210

150211
If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `test_approx` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
@@ -199,10 +260,3 @@ Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)
199260

200261
This behavior can also be overridden globally by setting the environment variable `CHAINRULES_TEST_INFERRED` before ChainRulesTestUtils is loaded or by changing `ChainRulesTestUtils.TEST_INFERRED[]` from inside Julia.
201262
ChainRulesTestUtils can detect whether a test is run as part of [PkgEval](https://github.com/JuliaCI/PkgEval.jl)and in this case disables inference tests automatically. Packages can use [`@maybe_inferred`](@ref) to get the same behavior for other inference tests.
202-
203-
# API Documentation
204-
205-
```@autodocs
206-
Modules = [ChainRulesTestUtils]
207-
Private = false
208-
```

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("iterator.jl")
3636
include("output_control.jl")
3737
include("check_result.jl")
3838

39+
include("rule_config.jl")
3940
include("finite_difference_calls.jl")
4041
include("testers.jl")
4142

src/check_result.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
3030
end
3131
end
3232

33-
test_approx(::ZeroTangent, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
34-
test_approx(x, ::ZeroTangent, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
33+
test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
34+
test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
3535
test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true
3636

3737
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113

src/rule_config.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# For testing this config re-dispatches Xrule_via_ad to Xrule without config argument
2+
struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
3+
4+
function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...)
5+
ret = frule(config, ȧrgs, f, args...; kws...)
6+
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
7+
ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...)))
8+
return ret
9+
end
10+
11+
function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...)
12+
ret = rrule(config, f, args...; kws...)
13+
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
14+
ret === nothing && throw(MethodError(rrule, (f, args...)))
15+
return ret
16+
end

src/testers.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,29 +67,39 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
6767
end
6868

6969
"""
70-
test_frule(f, args..; kwargs...)
70+
test_frule([config::RuleConfig,] f, args..; kwargs...)
7171
7272
# Arguments
73+
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
7374
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
7475
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
7576
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
76-
- ``: differential w.r.t. `x`, will be generated automatically if not provided
77-
Non-differentiable arguments, such as indices, should have `` set as `NoTangent()`.
77+
- ``: differential w.r.t. `x`, will be generated automatically if not provided
78+
Non-differentiable arguments, such as indices, should have `` set as `NoTangent()`.
7879
7980
# Keyword Arguments
8081
- `output_tangent` tangent to test accumulation of derivatives against
8182
should be a differential for the output of `f`. Is set automatically if not provided.
8283
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
84+
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
85+
`frule`). Used for testing gradients from AD systems.
8386
- If `check_inferred=true`, then the inferrability of the `frule` is checked,
8487
as long as `f` is itself inferrable.
8588
- `fkwargs` are passed to `f` as keyword arguments.
8689
- All remaining keyword arguments are passed to `isapprox`.
8790
"""
91+
function test_frule(args...; kwargs...)
92+
config = ChainRulesTestUtils.ADviaRuleConfig()
93+
test_frule(config, args...; kwargs...)
94+
end
95+
8896
function test_frule(
97+
config::RuleConfig,
8998
f,
9099
args...;
91100
output_tangent=Auto(),
92101
fdm=_fdm,
102+
frule_f=ChainRulesCore.frule,
93103
check_inferred::Bool=true,
94104
fkwargs::NamedTuple=NamedTuple(),
95105
rtol::Real=1e-9,
@@ -109,11 +119,11 @@ function test_frule(
109119
tangents = tangent.(primals_and_tangents)
110120

111121
if check_inferred && _is_inferrable(deepcopy(primals)...; deepcopy(fkwargs)...)
112-
_test_inferred(frule, deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
122+
_test_inferred(frule_f, deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
113123
end
114124

115-
res = frule(deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
116-
res === nothing && throw(MethodError(frule, typeof(primals)))
125+
res = frule_f(deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
126+
res === nothing && throw(MethodError(frule_f, typeof(primals)))
117127
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
118128
Ω_ad, dΩ_ad = res
119129
Ω = call_on_copy(primals...)
@@ -139,9 +149,10 @@ function test_frule(
139149
end
140150

141151
"""
142-
test_rrule(f, args...; kwargs...)
152+
test_rrule([config::RuleConfig,] f, args...; kwargs...)
143153
144154
# Arguments
155+
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
145156
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
146157
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
147158
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
@@ -152,16 +163,25 @@ end
152163
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
153164
should be a differential for the output of `f`. Is set automatically if not provided.
154165
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
166+
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
167+
Used for testing gradients from AD systems.
155168
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
156169
— if `f` is itself inferrable — along with the inferrability of the pullback it returns.
157170
- `fkwargs` are passed to `f` as keyword arguments.
158171
- All remaining keyword arguments are passed to `isapprox`.
159172
"""
173+
function test_rrule(args...; kwargs...)
174+
config = ChainRulesTestUtils.ADviaRuleConfig()
175+
test_rrule(config, args...; kwargs...)
176+
end
177+
160178
function test_rrule(
179+
config::RuleConfig,
161180
f,
162181
args...;
163182
output_tangent=Auto(),
164183
fdm=_fdm,
184+
rrule_f=ChainRulesCore.rrule,
165185
check_inferred::Bool=true,
166186
fkwargs::NamedTuple=NamedTuple(),
167187
rtol::Real=1e-9,
@@ -182,15 +202,15 @@ function test_rrule(
182202
accum_cotangents = tangent.(primals_and_tangents)
183203

184204
if check_inferred && _is_inferrable(primals...; fkwargs...)
185-
_test_inferred(rrule, primals...; fkwargs...)
205+
_test_inferred(rrule_f, config, primals...; fkwargs...)
186206
end
187-
res = rrule(primals...; fkwargs...)
188-
res === nothing && throw(MethodError(rrule, typeof((primals...))))
207+
res = rrule_f(config, primals...; fkwargs...)
208+
res === nothing && throw(MethodError(rrule_f, typeof(primals)))
189209
y_ad, pullback = res
190210
y = call(primals...)
191211
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
192212

193-
= output_tangent isa Auto ? rand_tangent(y) : output_tangent
213+
ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
194214

195215
check_inferred && _test_inferred(pullback, ȳ)
196216
ad_cotangents = pullback(ȳ)

test/testers.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
5353
return f(x), Δf.a + Δx
5454
end
5555

56+
# testing configs
57+
abstract type MySpecialTrait end
58+
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
59+
60+
5661
@testset "testers.jl" begin
5762
@testset "test_scalar" begin
5863
@testset "Ensure correct rules succeed" begin
@@ -599,6 +604,53 @@ end
599604
test_rrule(f_notimplemented, randn(), randn())
600605
end
601606

607+
@testset "custom rrule_f" begin
608+
only2x(x, y) = 2x
609+
custom(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ, ZeroTangent())
610+
wrong1(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (ZeroTangent(), 2Δ, ZeroTangent())
611+
wrong2(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2.1Δ, ZeroTangent())
612+
wrong3(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ)
613+
614+
test_rrule(only2x, 2.0, 3.0; rrule_f=custom, check_inferred=false)
615+
@test errors(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong1, check_inferred=false))
616+
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong2, check_inferred=false))
617+
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong3, check_inferred=false))
618+
end
619+
620+
@testset "custom frule_f" begin
621+
mytuple(x, y) = return 2x, 1.0
622+
T = Tuple{Float64, Float64}
623+
custom(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, ZeroTangent())
624+
wrong1(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2.1Δx, ZeroTangent())
625+
wrong2(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, 1.0)
626+
627+
test_frule(mytuple, 2.0, 3.0; frule_f=custom, check_inferred=false)
628+
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong1, check_inferred=false))
629+
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong2, check_inferred=false))
630+
end
631+
632+
@testset "custom_config" begin
633+
has_config(x) = 2x
634+
function ChainRulesCore.rrule(::MySpecialConfig, ::typeof(has_config), x)
635+
has_config_pullback(ȳ) = return (NoTangent(), 2ȳ)
636+
return has_config(x), has_config_pullback
637+
end
638+
639+
has_trait(x) = 2x
640+
function ChainRulesCore.rrule(::RuleConfig{<:MySpecialTrait}, ::typeof(has_trait), x)
641+
has_trait_pullback(ȳ) = return (NoTangent(), 2ȳ)
642+
return has_trait(x), has_trait_pullback
643+
end
644+
645+
# it works if the special config is provided
646+
test_rrule(MySpecialConfig(), has_config, rand())
647+
test_rrule(MySpecialConfig(), has_trait, rand())
648+
649+
# but it doesn't work for the default config
650+
errors(() -> test_rrule(has_config, rand()), "no method matching rrule")
651+
errors(() -> test_rrule(has_trait, rand()), "no method matching rrule")
652+
end
653+
602654
@testset "@maybe_inferred" begin
603655
f_noninferrable(x) = Ref{Real}(x)[]
604656

0 commit comments

Comments
 (0)