Skip to content

Commit 8431671

Browse files
authored
Implement ADviaFDConfig (#243)
1 parent 487890f commit 8431671

File tree

6 files changed

+108
-18
lines changed

6 files changed

+108
-18
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@ include("rule_config.jl")
3333
include("finite_difference_calls.jl")
3434
include("testers.jl")
3535

36+
include("deprecated.jl")
3637
include("global_checks.jl")
3738
end # module

src/deprecated.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Base.@deprecate_binding ADviaRuleConfig TestConfig false

src/rule_config.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,44 @@
1-
# For testing this config re-dispatches Xrule_via_ad to Xrule without config argument
2-
struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
1+
# For testing this config uses finite differences to evaluate the frule and rrule
2+
struct TestConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}}
3+
fdm
4+
end
5+
TestConfig() = TestConfig(central_fdm(5, 1))
6+
7+
function ChainRulesCore.frule_via_ad(config::TestConfig, ȧrgs, f, args...; kws...)
38

4-
function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...)
9+
# try using a rule
510
ret = frule(config, ȧrgs, f, args...; kws...)
6-
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
7-
ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...)))
8-
return ret
11+
ret === nothing || return ret
12+
13+
# but if the rule doesn't exist, use finite differencing instead
14+
call_on_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(kws)...)
15+
16+
primals = (f, args...)
17+
is_ignored = isa.(ȧrgs, NoTangent)
18+
19+
Ω = call_on_copy(f, args...)
20+
ΔΩ = _make_jvp_call(config.fdm, call_on_copy, Ω, primals, ȧrgs, is_ignored)
21+
22+
return Ω, ΔΩ
923
end
1024

11-
function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...)
25+
function ChainRulesCore.rrule_via_ad(config::TestConfig, f, args...; kws...)
26+
27+
# try using a rule
1228
ret = rrule(config, f, args...; kws...)
13-
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
14-
ret === nothing && throw(MethodError(rrule, (f, args...)))
15-
return ret
29+
ret === nothing || return ret
30+
31+
# but if the rule doesn't exist, use finite differencing instead
32+
call(f, xs...) = f(xs...; kws...)
33+
34+
# this block is here just to work out which tangents should be ignored
35+
primals = (f, args...)
36+
primals_and_tangents = auto_primal_and_tangent.(primals)
37+
is_ignored = isa.(tangent.(primals_and_tangents), NoTangent)
38+
39+
function f_pb(ȳ)
40+
return _make_j′vp_call(config.fdm, call, ȳ, primals, is_ignored)
41+
end
42+
43+
return call(f, args...), f_pb
1644
end

src/testers.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
test_frule([config::RuleConfig,] f, args..; kwargs...)
7373
7474
# Arguments
75-
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
75+
- `config`: defaults to `ChainRulesTestUtils.TestConfig`.
7676
- `f`: function for which the `frule` should be tested. Its tangent can be provided using `f ⊢ ḟ`.
7777
(You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
7878
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
@@ -92,8 +92,7 @@ end
9292
- All remaining keyword arguments are passed to `isapprox`.
9393
"""
9494
function test_frule(args...; kwargs...)
95-
config = ChainRulesTestUtils.ADviaRuleConfig()
96-
test_frule(config, args...; kwargs...)
95+
test_frule(TestConfig(), args...; kwargs...)
9796
end
9897

9998
function test_frule(
@@ -146,7 +145,7 @@ end
146145
test_rrule([config::RuleConfig,] f, args...; kwargs...)
147146
148147
# Arguments
149-
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
148+
- `config`: defaults to `ChainRulesTestUtils.TestConfig`.
150149
- `f`: function for which the `rrule` should be tested. Its tangent can be provided using `f ⊢ f̄`.
151150
(You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
152151
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
@@ -168,8 +167,7 @@ end
168167
- All remaining keyword arguments are passed to `isapprox`.
169168
"""
170169
function test_rrule(args...; kwargs...)
171-
config = ChainRulesTestUtils.ADviaRuleConfig()
172-
test_rrule(config, args...; kwargs...)
170+
test_rrule(TestConfig(), args...; kwargs...)
173171
end
174172

175173
function test_rrule(

test/rule_config.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,66 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
4949
errors(() -> test_rrule(has_config, rand()), "no method matching rrule")
5050
errors(() -> test_rrule(has_trait, rand()), "no method matching rrule")
5151
end
52+
53+
@testset "TestConfig direct" begin
54+
poly(x) = x^2 + 3.2x
55+
56+
x = 2.1
57+
config = ChainRulesTestUtils.TestConfig()
58+
59+
@testset "rrule" begin
60+
y, pb = rrule_via_ad(config, poly, x)
61+
@test y == poly(x)
62+
test_approx(pb(1.0), (NoTangent(), (2*x + 3.2) * 1.0))
63+
# and automatically
64+
test_rrule(config, poly, rand(); rrule_f=rrule_via_ad, check_inferred=false)
65+
end
66+
67+
@testset "frule" begin
68+
ḟ, ẋ = (NoTangent(), rand())
69+
Ω, ΔΩ = frule_via_ad(config, (ḟ, ẋ), poly, x)
70+
@test Ω == poly(x)
71+
test_approx(ΔΩ, (2*x + 3.2) * ẋ)
72+
# and automatically
73+
test_frule(config, poly, x; frule_f=frule_via_ad, check_inferred=false)
74+
end
75+
76+
# more functions
77+
simo(x) = (x, 2x, 3x)
78+
miso(x, y, z) = x+y
79+
test_rrule(config, simo, x; rrule_f=rrule_via_ad, check_inferred=false)
80+
test_rrule(config, miso, x, 2x, "s"; rrule_f=rrule_via_ad, check_inferred=false)
81+
82+
test_frule(config, simo, x; frule_f=frule_via_ad, check_inferred=false)
83+
test_frule(config, miso, x, x, "s"; frule_f=frule_via_ad, check_inferred=false)
84+
end
85+
86+
@testset "TestConfig in a rule" begin
87+
inner(x, y) = x^2 + 2*y + 3
88+
outer(f, x) = 2 * f(x, 3.2)
89+
90+
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(outer), f, x)
91+
fx, pb_f = rrule_via_ad(config, f, x, 3.2)
92+
outer_pb(ȳ) = (NoTangent(), pb_f(2 * ȳ)[1:2]...)
93+
return outer(f, x), outer_pb
94+
end
95+
96+
function ChainRulesCore.frule(config::RuleConfig{>:HasForwardsMode}, (ȯuter, ḟ, ẋ), ::typeof(outer), f, x)
97+
inner, inner_dot = frule_via_ad(config, (ḟ, ẋ, NoTangent()), f, x, 3.2)
98+
return 2 * inner, 2 * inner_dot
99+
end
100+
101+
config = ChainRulesTestUtils.TestConfig()
102+
test_rrule(config, outer, inner, rand(); rrule_f=rrule_via_ad, check_inferred=false)
103+
test_frule(config, outer, inner, rand(); frule_f=frule_via_ad, check_inferred=false)
104+
end
105+
106+
@testset "Catch incorrect rules" begin
107+
myid(x) = x
108+
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(myid), x)
109+
wrong_pb(dy) = (NoTangent(), 8dy)
110+
return x, wrong_pb
111+
end
112+
@test fails(() -> test_rrule(myid, 3.0; rrule_f=rrule_via_ad, check_inferred=false))
113+
end
52114
end

test/testers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,11 @@ end
285285

286286
@testset "single input, multiple output" begin
287287
simo(x) = (x, 2x)
288-
function ChainRulesCore.rrule(simo, x)
288+
function ChainRulesCore.rrule(::typeof(simo), x)
289289
simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b)
290290
return simo(x), simo_pullback
291291
end
292-
function ChainRulesCore.frule((_, ẋ), simo, x)
292+
function ChainRulesCore.frule((_, ẋ), ::typeof(simo), x)
293293
y = simo(x)
294294
return y, Tangent{typeof(y)}(ẋ, 2ẋ)
295295
end

0 commit comments

Comments
 (0)