From bedaf09b63f8964cd65a03c525ba4d6c7edc9bec Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 23 May 2022 11:35:26 +0200 Subject: [PATCH 1/3] Pass keyword arguments to `test_approx` when checking chunks --- src/testers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/testers.jl b/src/testers.jl index 31168fc..0a28ae6 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -227,7 +227,7 @@ function test_rrule( end if check_thunked_output_tangent - test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:") + test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:"; isapprox_kwargs...) check_inferred && _test_inferred(pullback, @thunk(ȳ)) end end # top-level testset From f274dab23c533b64c333a8c2b2d97231a1bdda76 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 23 May 2022 11:36:49 +0200 Subject: [PATCH 2/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ee6a92b..e1f69a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.7.0" +version = "1.7.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 69f1950028fdc2dd2df79eece5b73434e64f1cde Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 23 May 2022 14:30:24 +0200 Subject: [PATCH 3/3] Fix tests --- test/testers.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/testers.jl b/test/testers.jl index 00d542a..16230bd 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -683,7 +683,10 @@ end function ChainRulesCore.rrule(::typeof(my_id), x) my_id_pb(ȳ) = (NoTangent(), ȳ) function my_id_pb(ȳ::AbstractThunk) - precision = rand() > 0.5 ? Float64 : Float32 + # We use a condition that always evaluates to true to avoid issues with tolerances + # (see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/247) + # The function is type unstable for `Float64` inputs nevertheless + precision = rand() >= 0.0 ? Float64 : Float32 return (NoTangent(), precision(unthunk(ȳ))) end return x, my_id_pb