Skip to content

Commit 7b1ef4c

Browse files
authored
Test other tangents (#176)
1 parent 8c5da46 commit 7b1ef4c

File tree

5 files changed

+90
-3
lines changed

5 files changed

+90
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.7.9"
3+
version = "0.7.10"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import FiniteDifferences: rand_tangent
1313

1414
# Finite differencing method used as the default for the `fdm` keyword of
# `test_frule` / `test_rrule`.
const _fdm = central_fdm(5, 1; max_range=1e-2)

# Mutable global flag, initially `true`.
# NOTE(review): presumably toggles the type-inference checks globally — confirm at the use site.
const TEST_INFERRED = Ref{Bool}(true)

# Default for the `tangent_transforms` keyword of `test_frule` / `test_rrule`; empty, so no
# alternative tangents are tested unless transforms are pushed here or passed explicitly.
# e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)]
const TRANSFORMS_TO_ALT_TANGENTS = Vector{Function}()
1617

1718
export TestIterator
1819
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix

src/check_result.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ end
3333
# When only one side is an `AbstractZero`, compare the other side against its own `zero`.
function test_approx(::AbstractZero, x, msg=""; kwargs...)
    return test_approx(zero(x), x, msg; kwargs...)
end
function test_approx(x, ::AbstractZero, msg=""; kwargs...)
    return test_approx(x, zero(x), msg; kwargs...)
end

# Two zero-like tangents of the same kind always match.
# NOTE(review): these specific methods presumably also resolve the dispatch ambiguity
# between the two `AbstractZero` methods above — confirm against the type hierarchy.
function test_approx(::ZeroTangent, ::ZeroTangent, msg=""; kwargs...)
    return @test true
end
function test_approx(::NoTangent, ::NoTangent, msg=""; kwargs...)
    return @test true
end

# Remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 is resolved.
function test_approx(::NoTangent, ::Nothing, msg=""; kwargs...)
    return @test true
end

src/testers.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ end
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
8282
should be a differential for the output of `f`. Is set automatically if not provided.
83+
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84+
transform the passed argument tangents into alternative tangents that should be tested.
85+
Note that the alternative tangents are only tested for not erroring when passed to
86+
frule. Testing for correctness using finite differencing can be done using a
87+
separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88+
`test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
8389
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
8490
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
8591
`frule`). Used for testing gradients from AD systems.
@@ -98,6 +104,7 @@ function test_frule(
98104
f,
99105
args...;
100106
output_tangent=Auto(),
107+
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
101108
fdm=_fdm,
102109
frule_f=ChainRulesCore.frule,
103110
check_inferred::Bool=true,
@@ -122,7 +129,7 @@ function test_frule(
122129
_test_inferred(frule_f, deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
123130
end
124131

125-
res = frule_f(deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
132+
res = call_on_copy(frule_f, config, tangents, primals...)
126133
res === nothing && throw(MethodError(frule_f, typeof(primals)))
127134
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
128135
Ω_ad, dΩ_ad = res
@@ -144,10 +151,26 @@ function test_frule(
144151
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)
145152

146153
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
147-
_test_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
154+
_test_add!!_behaviour(acc, dΩ_ad; isapprox_kwargs...)
155+
156+
# test that rules work for other tangents
157+
_test_frule_alt_tangents(
158+
call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
159+
isapprox_kwargs...
160+
)
148161
end # top-level testset
149162
end
150163

164+
"""
    _test_frule_alt_tangents(
        call, frule_f, config, tangent_transforms, tangents, primals, acc;
        isapprox_kwargs...
    )

Run `frule_f` once per function in `tangent_transforms`, each time with every element of
`tangents` mapped through that function, inside its own testset.

`call` is invoked as `call(frule_f, config, transformed_tangents, primals...)` and must
return something that destructures into `(Ω, dΩ)`. Only `dΩ` is used: it is passed together
with `acc` to `_test_add!!_behaviour`, forwarding `isapprox_kwargs`. No finite-difference
comparison is made here.
"""
function _test_frule_alt_tangents(
    call,
    frule_f,
    config,
    tangent_transforms,
    tangents,
    primals,
    acc;
    isapprox_kwargs...,
)
    @testset "ȧrgs = $(tsf.(tangents))" for tsf in tangent_transforms
        alt_tangents = tsf.(tangents)
        _, dΩ_alt = call(frule_f, config, alt_tangents, primals...)
        _test_add!!_behaviour(acc, dΩ_alt; isapprox_kwargs...)
    end
end
173+
151174
"""
152175
test_rrule([config::RuleConfig,] f, args...; kwargs...)
153176
@@ -162,6 +185,12 @@ end
162185
# Keyword Arguments
163186
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
164187
should be a differential for the output of `f`. Is set automatically if not provided.
188+
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
189+
transform the passed `output_tangent` into alternative tangents that should be tested.
190+
Note that the alternative tangents are only tested for not erroring when passed to
191+
rrule. Testing for correctness using finite differencing can be done using a
192+
separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
193+
`test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
165194
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
166195
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
167196
Used for testing gradients from AD systems.
@@ -180,6 +209,7 @@ function test_rrule(
180209
f,
181210
args...;
182211
output_tangent=Auto(),
212+
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
183213
fdm=_fdm,
184214
rrule_f=ChainRulesCore.rrule,
185215
check_inferred::Bool=true,
@@ -249,9 +279,24 @@ function test_rrule(
249279
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
250280
end
251281
end
282+
283+
# test other tangents don't error when passed to the pullback
284+
_test_rrule_alt_tangents(pullback, tangent_transforms, ȳ, accum_cotangents)
252285
end # top-level testset
253286
end
254287

288+
"""
    _test_rrule_alt_tangents(
        pullback, tangent_transforms, ȳ, accum_cotangents;
        isapprox_kwargs...
    )

Call `pullback` once per function in `tangent_transforms`, each time on the transformed
seed `tsf(ȳ)`, inside its own testset.

The cotangents returned by the pullback are paired positionally (via `zip`) with
`accum_cotangents`, and every pair is passed to `_test_add!!_behaviour`, forwarding
`isapprox_kwargs`. No finite-difference comparison is made here.
"""
function _test_rrule_alt_tangents(
    pullback,
    tangent_transforms,
    ȳ,
    accum_cotangents;
    isapprox_kwargs...,
)
    @testset "ȳ = $(tsf(ȳ))" for tsf in tangent_transforms
        alt_seed = tsf(ȳ)
        alt_cotangents = pullback(alt_seed)
        for (accum, alt) in zip(accum_cotangents, alt_cotangents)
            _test_add!!_behaviour(accum, alt; isapprox_kwargs...)
        end
    end
end
299+
255300
"""
256301
@maybe_inferred [Type] f(...)
257302

test/testers.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,46 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
570570
end
571571
end
572572

573+
@testset "tangent_transforms frule" begin
    # Wraps a tangent in a thunk; used below as an alternative-tangent transform.
    as_thunk = x -> @thunk(x)

    # A rule that produces a tangent for whatever input tangent it receives.
    others_work(x) = 2 * x
    function ChainRulesCore.frule((Δself, Δx), ::typeof(others_work), x)
        Ω = others_work(x)
        ∂Ω = 2 * Δx
        return Ω, ∂Ω
    end

    # A rule that always throws while building its tangent output.
    others_nowork(x) = 2 * x
    function ChainRulesCore.frule((Δself, Δx), ::typeof(others_nowork), x)
        Ω = others_nowork(x)
        return Ω, error("nope")
    end

    test_frule(others_work, rand(); tangent_transforms=[identity, as_thunk])

    # The error raised by the rule must surface through `test_frule`.
    @test errors("nope") do
        test_frule(others_nowork, 2.3; tangent_transforms=[as_thunk])
    end
end
589+
590+
@testset "tangent_transforms rrule" begin
    # Wraps a cotangent in a thunk; used below as an alternative-tangent transform.
    as_thunk = x -> @thunk(x)

    # A rule whose pullback handles whatever seed it is given.
    others_work(x) = 2 * x
    function ChainRulesCore.rrule(::typeof(others_work), x)
        function others_work_pullback(ȳ)
            return (NoTangent(), 2 * ȳ)
        end
        return others_work(x), others_work_pullback
    end

    # A rule whose pullback always throws.
    others_nowork(x) = [x, x]
    function ChainRulesCore.rrule(::typeof(others_nowork), x)
        function others_nowork_pullback(ȳ)
            return (NoTangent(), error("nope"))
        end
        return others_nowork(x), others_nowork_pullback
    end

    test_rrule(others_work, 2.3; tangent_transforms=[_ -> ZeroTangent()])
    test_rrule(others_work, 2.3; tangent_transforms=[as_thunk])

    # The error raised by the pullback must surface through `test_rrule`.
    @test errors("nope") do
        test_rrule(others_nowork, 2.3; tangent_transforms=[as_thunk])
    end
end
612+
573613
@testset "Tuple primal that is not equal to differential backing" begin
574614
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/288
575615
forwards_trouble(x) = (1, 2.0 * x)

0 commit comments

Comments
 (0)