File tree Expand file tree Collapse file tree 3 files changed +18
-2
lines changed Expand file tree Collapse file tree 3 files changed +18
-2
lines changed Original file line number Diff line number Diff line change 1
1
name = " ChainRulesTestUtils"
2
2
uuid = " cdddcdb0-9152-4a09-a978-84456f9df70a"
3
- version = " 1.2.1 "
3
+ version = " 1.2.2 "
4
4
5
5
[deps ]
6
6
ChainRulesCore = " d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Original file line number Diff line number Diff line change @@ -242,6 +242,7 @@ function test_rrule(
242
242
243
243
if check_thunked_output_tangent
244
244
test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk:" )
245
+ check_inferred && _test_inferred (pullback, @thunk (ȳ))
245
246
end
246
247
end # top-level testset
247
248
end
Original file line number Diff line number Diff line change 57
57
abstract type MySpecialTrait end
58
58
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
59
59
60
-
61
60
@testset " testers.jl" begin
62
61
@testset " test_scalar" begin
63
62
@testset " Ensure correct rules succeed" begin
@@ -711,4 +710,20 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
711
710
712
711
ChainRulesTestUtils. TEST_INFERRED[] = true
713
712
end
713
+
714
+ @testset " inference of thunked cotangents" begin
715
+ my_id (x) = x
716
+ function ChainRulesCore. rrule (:: typeof (my_id), x)
717
+ my_id_pb (ȳ) = (NoTangent (), ȳ)
718
+ function my_id_pb (ȳ:: AbstractThunk )
719
+ precision = rand () > 0.5 ? Float64 : Float32
720
+ return (NoTangent (), precision (unthunk (ȳ)))
721
+ end
722
+ return x, my_id_pb
723
+ end
724
+
725
+ @test errors (() -> test_rrule (my_id, 2.0 ))
726
+ test_rrule (my_id, 2.0 ; check_inferred= false )
727
+ test_rrule (my_id, 2.0 ; check_thunked_output_tangent= false )
728
+ end
714
729
end
You can’t perform that action at this time.
0 commit comments