@@ -53,10 +53,6 @@ function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
53
53
return f (x), Δf. a + Δx
54
54
end
55
55
56
- # testing configs
57
- abstract type MySpecialTrait end
58
- struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
59
-
60
56
# Type-stable derivative for test below
61
57
struct FVecOfTuplesPullback{T} end
62
58
function (f:: FVecOfTuplesPullback{T} )(Δ) where {T}
664
660
test_rrule (f_notimplemented2, identity, randn ())
665
661
end
666
662
667
- @testset " custom rrule_f" begin
668
- only2x (x, y) = 2 x
669
- custom (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2 Δ, ZeroTangent ())
670
- wrong1 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (ZeroTangent (), 2 Δ, ZeroTangent ())
671
- wrong2 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2.1 Δ, ZeroTangent ())
672
- wrong3 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2 Δ)
673
-
674
- test_rrule (only2x, 2.0 , 3.0 ; rrule_f= custom, check_inferred= false )
675
- @test errors (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong1, check_inferred= false ))
676
- @test fails (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong2, check_inferred= false ))
677
- @test fails (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong3, check_inferred= false ))
678
- end
679
-
680
- @testset " custom frule_f" begin
681
- mytuple (x, y) = return 2 x, 1.0
682
- T = Tuple{Float64, Float64}
683
- custom (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2 Δx, ZeroTangent ())
684
- wrong1 (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2.1 Δx, ZeroTangent ())
685
- wrong2 (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2 Δx, 1.0 )
686
-
687
- test_frule (mytuple, 2.0 , 3.0 ; frule_f= custom, check_inferred= false )
688
- @test fails (() -> test_frule (mytuple, 2.0 , 3.0 ; frule_f= wrong1, check_inferred= false ))
689
- @test fails (() -> test_frule (mytuple, 2.0 , 3.0 ; frule_f= wrong2, check_inferred= false ))
690
- end
691
-
692
- @testset " custom_config" begin
693
- has_config (x) = 2 x
694
- function ChainRulesCore. rrule (:: MySpecialConfig , :: typeof (has_config), x)
695
- has_config_pullback (ȳ) = return (NoTangent (), 2 ȳ)
696
- return has_config (x), has_config_pullback
697
- end
698
-
699
- has_trait (x) = 2 x
700
- function ChainRulesCore. rrule (:: RuleConfig{<:MySpecialTrait} , :: typeof (has_trait), x)
701
- has_trait_pullback (ȳ) = return (NoTangent (), 2 ȳ)
702
- return has_trait (x), has_trait_pullback
703
- end
704
-
705
- # it works if the special config is provided
706
- test_rrule (MySpecialConfig (), has_config, rand ())
707
- test_rrule (MySpecialConfig (), has_trait, rand ())
708
-
709
- # but it doesn't work for the default config
710
- errors (() -> test_rrule (has_config, rand ()), " no method matching rrule" )
711
- errors (() -> test_rrule (has_trait, rand ()), " no method matching rrule" )
712
- end
713
-
714
663
@testset " @maybe_inferred" begin
715
664
f_noninferrable (x) = Ref {Real} (x)[]
716
665
0 commit comments