Skip to content

Commit a46fbbc

Browse files
authored
Catch inference problems for thunked cotangents (#215)
1 parent fc7408c commit a46fbbc

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.2.1"
3+
version = "1.2.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ function test_rrule(
242242

243243
if check_thunked_output_tangent
244244
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
245+
check_inferred && _test_inferred(pullback, @thunk(ȳ))
245246
end
246247
end # top-level testset
247248
end

test/testers.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ end
5757
abstract type MySpecialTrait end
5858
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
5959

60-
6160
@testset "testers.jl" begin
6261
@testset "test_scalar" begin
6362
@testset "Ensure correct rules succeed" begin
@@ -711,4 +710,20 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
711710

712711
ChainRulesTestUtils.TEST_INFERRED[] = true
713712
end
713+
714+
@testset "inference of thunked cotangents" begin
715+
my_id(x) = x
716+
function ChainRulesCore.rrule(::typeof(my_id), x)
717+
my_id_pb(ȳ) = (NoTangent(), ȳ)
718+
function my_id_pb(ȳ::AbstractThunk)
719+
precision = rand() > 0.5 ? Float64 : Float32
720+
return (NoTangent(), precision(unthunk(ȳ)))
721+
end
722+
return x, my_id_pb
723+
end
724+
725+
@test errors(() -> test_rrule(my_id, 2.0))
726+
test_rrule(my_id, 2.0; check_inferred=false)
727+
test_rrule(my_id, 2.0; check_thunked_output_tangent=false)
728+
end
714729
end

0 commit comments

Comments
 (0)