@@ -220,24 +220,8 @@ function test_rrule(
220
220
# Correctness testing via finite differencing.
221
221
is_ignored = isa .(accum_cotangents, NoTangent)
222
222
fd_cotangents = _make_j′vp_call (fdm, call, ȳ, primals, is_ignored)
223
-
224
- for (accum_cotangent, ad_cotangent, fd_cotangent) in zip (
225
- accum_cotangents, ad_cotangents, fd_cotangents
226
- )
227
- if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
228
- @assert fd_cotangent === NoTangent ()
229
- ad_cotangent isa ZeroTangent && error (
230
- " The pullback in the rrule should use NoTangent()" *
231
- " rather than ZeroTangent() for non-perturbable arguments." ,
232
- )
233
- @test ad_cotangent isa NoTangent # we said it wasn't differentiable.
234
- else
235
- ad_cotangent isa AbstractThunk && check_inferred && _test_inferred (unthunk, ad_cotangent)
236
-
237
- # The main test of the actual derivative being correct:
238
- test_approx (ad_cotangent, fd_cotangent; isapprox_kwargs... )
239
- _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
240
- end
223
+ foreach (accum_cotangents, ad_cotangents, fd_cotangents) do args...
224
+ _test_cotangent (args... ; check_inferred= check_inferred, isapprox_kwargs... )
241
225
end
242
226
243
227
if check_thunked_output_tangent
@@ -285,3 +269,58 @@ function _is_inferrable(f, args...; kwargs...)
285
269
return false
286
270
end
287
271
end
272
+
273
+ """
274
+ _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
275
+
276
+ Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
277
+ approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.
278
+
279
+ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
280
+ `ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.
281
+
282
+ # Keyword arguments
283
+ - If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
284
+ its content can be inferred.
285
+ - All remaining keyword arguments are passed to `isapprox`.
286
+ """
287
+ function _test_cotangent (
288
+ accum_cotangent,
289
+ ad_cotangent,
290
+ fd_cotangent;
291
+ check_inferred= true ,
292
+ kwargs... ,
293
+ )
294
+ ad_cotangent isa AbstractThunk && check_inferred && _test_inferred (unthunk, ad_cotangent)
295
+
296
+ # The main test of the actual derivative being correct:
297
+ test_approx (ad_cotangent, fd_cotangent; kwargs... )
298
+ _test_add!!_behaviour (accum_cotangent, ad_cotangent; kwargs... )
299
+ end
300
+
301
+ # we marked the argument as non-differentiable
302
+ function _test_cotangent (:: NoTangent , ad_cotangent, :: NoTangent ; kwargs... )
303
+ @test ad_cotangent isa NoTangent
304
+ end
305
+ function _test_cotangent (:: NoTangent , :: ZeroTangent , :: NoTangent ; kwargs... )
306
+ error (
307
+ " The pullback in the rrule should use NoTangent()" *
308
+ " rather than ZeroTangent() for non-perturbable arguments."
309
+ )
310
+ end
311
+ function _test_cotangent (
312
+ :: NoTangent ,
313
+ ad_cotangent:: ChainRulesCore.NotImplemented ,
314
+ :: NoTangent ;
315
+ kwargs... ,
316
+ )
317
+ # this situation can occur if a cotangent is not implemented and
318
+ # the default `rand_tangent` is `NoTangent`: e.g. due to having no fields
319
+ # the `@test_broken` below should tell them that there is an easy implementation for
320
+ # this case of `NoTangent()` (`@test_broken false` would be less useful!)
321
+ # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
322
+ @test_broken ad_cotangent isa NoTangent
323
+ end
324
+ function _test_cotangent (:: NoTangent , ad_cotangent, fd_cotangent; kwargs... )
325
+ error (" cotangent obtained with finite differencing has to be NoTangent()" )
326
+ end
0 commit comments