Skip to content

Commit e3ff6b4

Browse files
devmotionoxinaboxst--
authored
Allow NotImplemented tangents for things that have a correct tangent of NoTangent (#218)
* Fix `test_rrule` if cotangent is not implemented but `rand_tangent` returns `NoTangent` * Bump version * Update src/testers.jl Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * Refactor checks of cotangents * Improve error message * Extend comment * Fix spelling error * Update src/testers.jl Co-authored-by: st-- <st--@users.noreply.github.com> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: st-- <st--@users.noreply.github.com>
1 parent a46fbbc commit e3ff6b4

File tree

3 files changed

+68
-19
lines changed

3 files changed

+68
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.2.2"
3+
version = "1.2.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,8 @@ function test_rrule(
220220
# Correctness testing via finite differencing.
221221
is_ignored = isa.(accum_cotangents, NoTangent)
222222
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
223-
224-
for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
225-
accum_cotangents, ad_cotangents, fd_cotangents
226-
)
227-
if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
228-
@assert fd_cotangent === NoTangent()
229-
ad_cotangent isa ZeroTangent && error(
230-
"The pullback in the rrule should use NoTangent()" *
231-
" rather than ZeroTangent() for non-perturbable arguments.",
232-
)
233-
@test ad_cotangent isa NoTangent # we said it wasn't differentiable.
234-
else
235-
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)
236-
237-
# The main test of the actual derivative being correct:
238-
test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
239-
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
240-
end
223+
foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
224+
_test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
241225
end
242226

243227
if check_thunked_output_tangent
@@ -285,3 +269,58 @@ function _is_inferrable(f, args...; kwargs...)
285269
return false
286270
end
287271
end
272+
273+
"""
274+
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
275+
276+
Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
277+
approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.
278+
279+
If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
280+
`ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.
281+
282+
# Keyword arguments
283+
- If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
284+
its content can be inferred.
285+
- All remaining keyword arguments are passed to `isapprox`.
286+
"""
287+
function _test_cotangent(
288+
accum_cotangent,
289+
ad_cotangent,
290+
fd_cotangent;
291+
check_inferred=true,
292+
kwargs...,
293+
)
294+
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)
295+
296+
# The main test of the actual derivative being correct:
297+
test_approx(ad_cotangent, fd_cotangent; kwargs...)
298+
_test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
299+
end
300+
301+
# we marked the argument as non-differentiable
302+
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...)
303+
@test ad_cotangent isa NoTangent
304+
end
305+
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
306+
error(
307+
"The pullback in the rrule should use NoTangent()" *
308+
" rather than ZeroTangent() for non-perturbable arguments."
309+
)
310+
end
311+
function _test_cotangent(
312+
::NoTangent,
313+
ad_cotangent::ChainRulesCore.NotImplemented,
314+
::NoTangent;
315+
kwargs...,
316+
)
317+
# this situation can occur if a cotangent is not implemented and
318+
# the default `rand_tangent` is `NoTangent`: e.g. due to having no fields
319+
# the `@test_broken` below should tell them that there is an easy implementation for
320+
# this case of `NoTangent()` (`@test_broken false` would be less useful!)
321+
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
322+
@test_broken ad_cotangent isa NoTangent
323+
end
324+
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...)
325+
error("cotangent obtained with finite differencing has to be NoTangent()")
326+
end

test/testers.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,16 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
644644
@scalar_rule f_notimplemented(x, y) (@not_implemented(""), 1) (1, -1)
645645
test_frule(f_notimplemented, randn(), randn())
646646
test_rrule(f_notimplemented, randn(), randn())
647+
648+
f_notimplemented2(f, x) = 2 * f(x)
649+
function ChainRulesCore.rrule(::typeof(f_notimplemented2), f::typeof(identity), x)
650+
Ω = f_notimplemented2(f, x)
651+
function f_notimplemented2_pullback(Ω̄)
652+
return NoTangent(), @not_implemented("TODO: implement this!"), 2 * Ω̄
653+
end
654+
return Ω, f_notimplemented2_pullback
655+
end
656+
test_rrule(f_notimplemented2, identity, randn())
647657
end
648658

649659
@testset "custom rrule_f" begin

0 commit comments

Comments
 (0)