Skip to content

Commit abc55f1

Browse files
committed
confirm that check_thunked_output_tangent catches problems
1 parent 9918793 commit abc55f1

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

src/check_result.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
143143
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
144144
throw(MethodError, test_approx, (actual, expected))
145145
end
146-
test_approx(c_actual, c_expected; kwargs...)
146+
test_approx(c_actual, c_expected, msg; kwargs...)
147147
end
148148
end
149149

src/testers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function test_rrule(
241241
end
242242

243243
if check_thunked_output_tangent
244-
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk")
244+
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
245245
end
246246
end # top-level testset
247247
end

test/testers.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,35 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
598598
test_rrule(rev_trouble, (3, 3.0) Tangent{Tuple{Int,Float64}}(ZeroTangent(), 1.0))
599599
end
600600

601+
@testset "check_thunked_output_tangent" begin
602+
@testset "no method for thunk" begin
603+
does_not_accept_thunk_id(x) = x
604+
function ChainRulesCore.rrule(::typeof(does_not_accept_thunk_id), x)
605+
does_not_accept_thunk_id_pullback(ȳ::AbstractArray) = (NoTangent() ,ȳ)
606+
return does_not_accept_thunk_id(x), does_not_accept_thunk_id_pullback
607+
end
608+
609+
test_rrule(
610+
does_not_accept_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false
611+
)
612+
@test errors(r"MethodError.*Thunk") do
613+
test_rrule(does_not_accept_thunk_id, [1.0, 2.0])
614+
end
615+
end
616+
617+
@testset "Thunk wrong" begin
618+
bad_thunk_id(x) = x
619+
function ChainRulesCore.rrule(::typeof(bad_thunk_id), x)
620+
bad_thunk_id_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)
621+
bad_thunk_id_pullback(ȳ::AbstractThunk) = (NoTangent(), 2 * ȳ)
622+
return bad_thunk_id(x), bad_thunk_id_pullback
623+
end
624+
625+
test_rrule(bad_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false)
626+
@test fails(()->test_rrule(bad_thunk_id, [1.0, 2.0]))
627+
end
628+
end
629+
601630
@testset "error message about incorrectly using ZeroTangent()" begin
602631
foo(a, i) = a[i]
603632
function ChainRulesCore.rrule(::typeof(foo), a, i)

0 commit comments

Comments
 (0)