@@ -598,6 +598,35 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
598
598
test_rrule (rev_trouble, (3 , 3.0 ) ⊢ Tangent {Tuple{Int,Float64}} (ZeroTangent (), 1.0 ))
599
599
end
600
600
601
+ @testset " check_thunked_output_tangent" begin
602
+ @testset " no method for thunk" begin
603
+ does_not_accept_thunk_id (x) = x
604
+ function ChainRulesCore. rrule (:: typeof (does_not_accept_thunk_id), x)
605
+ does_not_accept_thunk_id_pullback (ȳ:: AbstractArray ) = (NoTangent () ,ȳ)
606
+ return does_not_accept_thunk_id (x), does_not_accept_thunk_id_pullback
607
+ end
608
+
609
+ test_rrule (
610
+ does_not_accept_thunk_id, [1.0 , 2.0 ]; check_thunked_output_tangent= false
611
+ )
612
+ @test errors (r" MethodError.*Thunk" ) do
613
+ test_rrule (does_not_accept_thunk_id, [1.0 , 2.0 ])
614
+ end
615
+ end
616
+
617
+ @testset " Thunk wrong" begin
618
+ bad_thunk_id (x) = x
619
+ function ChainRulesCore. rrule (:: typeof (bad_thunk_id), x)
620
+ bad_thunk_id_pullback (ȳ:: AbstractArray ) = (NoTangent (), ȳ)
621
+ bad_thunk_id_pullback (ȳ:: AbstractThunk ) = (NoTangent (), 2 * ȳ)
622
+ return bad_thunk_id (x), bad_thunk_id_pullback
623
+ end
624
+
625
+ test_rrule (bad_thunk_id, [1.0 , 2.0 ]; check_thunked_output_tangent= false )
626
+ @test fails (()-> test_rrule (bad_thunk_id, [1.0 , 2.0 ]))
627
+ end
628
+ end
629
+
601
630
@testset " error message about incorrectly using ZeroTangent()" begin
602
631
foo (a, i) = a[i]
603
632
function ChainRulesCore. rrule (:: typeof (foo), a, i)
0 commit comments