Skip to content

Commit 5cb9d1c

Browse files
committed
Change to just testing output tangent thunks in the test_rrule only
1 parent e0d3cbe commit 5cb9d1c

File tree

2 files changed

+7
-86
lines changed

2 files changed

+7
-86
lines changed

src/testers.jl

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,6 @@ end
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
8282
should be a differential for the output of `f`. Is set automatically if not provided.
83-
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84-
transform the passed argument tangents into alternative tangents that should be tested.
85-
Note that the alternative tangents are only tested for not erroring when passed to
86-
frule. Testing for correctness using finite differencing can be done using a
87-
separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88-
`test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
8983
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
9084
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
9185
`frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
10498
f,
10599
args...;
106100
output_tangent=Auto(),
107-
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
108101
fdm=_fdm,
109102
frule_f=ChainRulesCore.frule,
110103
check_inferred::Bool=true,
@@ -143,25 +136,9 @@ function test_frule(
143136

144137
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
145138
_test_add!!_behaviour(acc, dΩ_ad; isapprox_kwargs...)
146-
147-
# test that rules work for other tangents
148-
_test_frule_alt_tangents(
149-
call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
150-
isapprox_kwargs...
151-
)
152139
end # top-level testset
153140
end
154141

155-
function _test_frule_alt_tangents(
156-
call, frule_f, config, tangent_transforms, tangents, primals, acc;
157-
isapprox_kwargs...
158-
)
159-
@testset "ȧrgs = $(_string_typeof(tsf.(tangents)))" for tsf in tangent_transforms
160-
_, dΩ = call(frule_f, config, tsf.(tangents), primals...)
161-
_test_add!!_behaviour(acc, dΩ; isapprox_kwargs...)
162-
end
163-
end
164-
165142
"""
166143
test_rrule([config::RuleConfig,] f, args...; kwargs...)
167144
@@ -176,12 +153,8 @@ end
176153
# Keyword Arguments
177154
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
178155
should be a differential for the output of `f`. Is set automatically if not provided.
179-
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
180-
transform the passed `output_tangent` into alternative tangents that should be tested.
181-
Note that the alternative tangents are only tested for not erroring when passed to
182-
rrule. Testing for correctness using finite differencing can be done using a
183-
separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
184-
`test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156+
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157+
output tangent to the pullback returns the same result.
185158
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
186159
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
187160
Used for testing gradients from AD systems.
@@ -200,7 +173,7 @@ function test_rrule(
200173
f,
201174
args...;
202175
output_tangent=Auto(),
203-
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
176+
check_thunked_output_tangent=true,
204177
fdm=_fdm,
205178
rrule_f=ChainRulesCore.rrule,
206179
check_inferred::Bool=true,
@@ -267,21 +240,10 @@ function test_rrule(
267240
end
268241
end
269242

270-
# test other tangents don't error when passed to the pullback
271-
_test_rrule_alt_tangents(pullback, tangent_transforms, ȳ, accum_cotangents)
272-
end # top-level testset
273-
end
274-
275-
function _test_rrule_alt_tangents(
276-
pullback, tangent_transforms, ȳ, accum_cotangents;
277-
isapprox_kwargs...
278-
)
279-
@testset "ȳ = $(_string_typeof(tsf(ȳ)))" for tsf in tangent_transforms
280-
ad_cotangents = pullback(tsf(ȳ))
281-
for (accum_cotangent, ad_cotangent) in zip(accum_cotangents, ad_cotangents)
282-
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
243+
if check_thunked_output_tangent
244+
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk")
283245
end
284-
end
246+
end # top-level testset
285247
end
286248

287249
"""

test/testers.jl

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
444444
= [4.0, 5.0, 6.0]
445445
xcopy, ẋcopy = copy(x), copy(ẋ)
446446
y = [1, 2]
447-
# Don't test tangent transforms, we do not support thunks for mutating frules
448-
# TODO: Should we disable testing thunks for frules in general
449-
test_frule(finplace!, x ẋ; fkwargs=(y=y,), tangent_transforms=[])
447+
test_frule(finplace!, x ẋ; fkwargs=(y=y,))
450448
@test x == xcopy
451449
@test== ẋcopy
452450
@test y == [1, 2]
@@ -585,45 +583,6 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
585583
end
586584
end
587585

588-
@testset "tangent_transforms frule" begin
589-
others_work(x) = 2x
590-
function ChainRulesCore.frule((Δd, Δx), ::typeof(others_work), x)
591-
return others_work(x), 2Δx
592-
end
593-
594-
others_nowork(x) = 2x
595-
function ChainRulesCore.frule((Δd, Δx), ::typeof(others_nowork), x)
596-
return others_nowork(x), error("nope")
597-
end
598-
599-
test_frule(others_work, rand(); tangent_transforms=[identity, x -> @thunk(x)])
600-
@test errors("nope") do
601-
test_frule(others_nowork, 2.3; tangent_transforms=[x -> @thunk(x)])
602-
end
603-
end
604-
605-
@testset "tangent_transforms rrule" begin
606-
others_work(x) = 2x
607-
function ChainRulesCore.rrule(::typeof(others_work), x)
608-
y = others_work(x)
609-
others_work_pullback(ȳ) = return (NoTangent(), 2ȳ)
610-
return y, others_work_pullback
611-
end
612-
613-
others_nowork(x) = [x, x]
614-
function ChainRulesCore.rrule(::typeof(others_nowork), x)
615-
y = others_nowork(x)
616-
others_nowork_pullback(ȳ) = return (NoTangent(), error("nope"))
617-
return y, others_nowork_pullback
618-
end
619-
620-
test_rrule(others_work, 2.3; tangent_transforms=[_ -> ZeroTangent()])
621-
test_rrule(others_work, 2.3; tangent_transforms=[x -> @thunk(x)])
622-
623-
@test errors("nope") do
624-
test_rrule(others_nowork, 2.3; tangent_transforms=[x -> @thunk(x)])
625-
end
626-
end
627586

628587
@testset "Tuple primal that is not equal to differential backing" begin
629588
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/288

0 commit comments

Comments
 (0)