Skip to content

Commit 487890f

Browse files
authored
move ruleconfig tests into a separate file (#242)
1 parent 86d8bfe commit 487890f

File tree

3 files changed

+53
-51
lines changed

3 files changed

+53
-51
lines changed

test/rule_config.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# testing configs
2+
abstract type MySpecialTrait end
3+
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
4+
5+
@testset "rule_config.jl" begin
6+
@testset "custom rrule_f" begin
7+
only2x(x, y) = 2x
8+
custom(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ, ZeroTangent())
9+
wrong1(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (ZeroTangent(), 2Δ, ZeroTangent())
10+
wrong2(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2.1Δ, ZeroTangent())
11+
wrong3(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ)
12+
13+
test_rrule(only2x, 2.0, 3.0; rrule_f=custom, check_inferred=false)
14+
@test errors(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong1, check_inferred=false))
15+
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong2, check_inferred=false))
16+
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong3, check_inferred=false))
17+
end
18+
19+
@testset "custom frule_f" begin
20+
mytuple(x, y) = return 2x, 1.0
21+
T = Tuple{Float64, Float64}
22+
custom(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, ZeroTangent())
23+
wrong1(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2.1Δx, ZeroTangent())
24+
wrong2(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, 1.0)
25+
26+
test_frule(mytuple, 2.0, 3.0; frule_f=custom, check_inferred=false)
27+
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong1, check_inferred=false))
28+
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong2, check_inferred=false))
29+
end
30+
31+
@testset "custom_config" begin
32+
has_config(x) = 2x
33+
function ChainRulesCore.rrule(::MySpecialConfig, ::typeof(has_config), x)
34+
has_config_pullback(ȳ) = return (NoTangent(), 2ȳ)
35+
return has_config(x), has_config_pullback
36+
end
37+
38+
has_trait(x) = 2x
39+
function ChainRulesCore.rrule(::RuleConfig{<:MySpecialTrait}, ::typeof(has_trait), x)
40+
has_trait_pullback(ȳ) = return (NoTangent(), 2ȳ)
41+
return has_trait(x), has_trait_pullback
42+
end
43+
44+
# it works if the special config is provided
45+
test_rrule(MySpecialConfig(), has_config, rand())
46+
test_rrule(MySpecialConfig(), has_trait, rand())
47+
48+
# but it doesn't work for the default config
49+
errors(() -> test_rrule(has_config, rand()), "no method matching rrule")
50+
errors(() -> test_rrule(has_trait, rand()), "no method matching rrule")
51+
end
52+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ChainRulesTestUtils.TEST_INFERRED[] = true
1616
include("testers.jl")
1717
include("data_generation.jl")
1818
include("rand_tangent.jl")
19+
include("rule_config.jl")
1920

2021
include("global_checks.jl")
2122
end

test/testers.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
5353
return f(x), Δf.a + Δx
5454
end
5555

56-
# testing configs
57-
abstract type MySpecialTrait end
58-
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
59-
6056
# Type-stable derivative for test below
6157
struct FVecOfTuplesPullback{T} end
6258
function (f::FVecOfTuplesPullback{T})(Δ) where {T}
@@ -664,53 +660,6 @@ end
664660
test_rrule(f_notimplemented2, identity, randn())
665661
end
666662

667-
@testset "custom rrule_f" begin
668-
only2x(x, y) = 2x
669-
custom(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ, ZeroTangent())
670-
wrong1(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (ZeroTangent(), 2Δ, ZeroTangent())
671-
wrong2(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2.1Δ, ZeroTangent())
672-
wrong3(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ)
673-
674-
test_rrule(only2x, 2.0, 3.0; rrule_f=custom, check_inferred=false)
675-
@test errors(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong1, check_inferred=false))
676-
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong2, check_inferred=false))
677-
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong3, check_inferred=false))
678-
end
679-
680-
@testset "custom frule_f" begin
681-
mytuple(x, y) = return 2x, 1.0
682-
T = Tuple{Float64, Float64}
683-
custom(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, ZeroTangent())
684-
wrong1(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2.1Δx, ZeroTangent())
685-
wrong2(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, 1.0)
686-
687-
test_frule(mytuple, 2.0, 3.0; frule_f=custom, check_inferred=false)
688-
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong1, check_inferred=false))
689-
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong2, check_inferred=false))
690-
end
691-
692-
@testset "custom_config" begin
693-
has_config(x) = 2x
694-
function ChainRulesCore.rrule(::MySpecialConfig, ::typeof(has_config), x)
695-
has_config_pullback(ȳ) = return (NoTangent(), 2ȳ)
696-
return has_config(x), has_config_pullback
697-
end
698-
699-
has_trait(x) = 2x
700-
function ChainRulesCore.rrule(::RuleConfig{<:MySpecialTrait}, ::typeof(has_trait), x)
701-
has_trait_pullback(ȳ) = return (NoTangent(), 2ȳ)
702-
return has_trait(x), has_trait_pullback
703-
end
704-
705-
# it works if the special config is provided
706-
test_rrule(MySpecialConfig(), has_config, rand())
707-
test_rrule(MySpecialConfig(), has_trait, rand())
708-
709-
# but it doesn't work for the default config
710-
errors(() -> test_rrule(has_config, rand()), "no method matching rrule")
711-
errors(() -> test_rrule(has_trait, rand()), "no method matching rrule")
712-
end
713-
714663
@testset "@maybe_inferred" begin
715664
f_noninferrable(x) = Ref{Real}(x)[]
716665

0 commit comments

Comments
 (0)