Skip to content

Commit 6df5395

Browse files
authored
Merge pull request #180 from JuliaDiff/mz/0.8
Remove Deprecations. Replace tangent_transforms functionality
2 parents 492364c + 83c01ee commit 6df5395

File tree

11 files changed

+66
-497
lines changed

11 files changed

+66
-497
lines changed

docs/Manifest.toml

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# This file is machine-generated - editing it directly is not advised
22

3+
[[ANSIColoredPrinters]]
4+
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
5+
uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
6+
version = "0.0.1"
7+
38
[[ArgTools]]
49
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
510

@@ -11,15 +16,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1116

1217
[[ChainRulesCore]]
1318
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "0b0aa9d61456940511416b59a0e902c57b154956"
19+
git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54"
1520
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.12"
21+
version = "0.10.13"
1722

1823
[[ChainRulesTestUtils]]
1924
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2025
path = ".."
2126
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.7.13"
27+
version = "1.0.0-DEV"
2328

2429
[[Compat]]
2530
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -46,20 +51,20 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
4651
version = "0.8.5"
4752

4853
[[Documenter]]
49-
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
50-
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
54+
deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
55+
git-tree-sha1 = "95265abf7d7bf06dfdb8d58525a23ea5fb0bdeee"
5156
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
52-
version = "0.27.3"
57+
version = "0.27.4"
5358

5459
[[Downloads]]
5560
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
5661
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
5762

5863
[[FiniteDifferences]]
5964
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60-
git-tree-sha1 = "12417e4754486a547d98d65293dc0fafdfcc0736"
65+
git-tree-sha1 = "18761c465ef2e87d9091c0fefb61f70d532d4cc0"
6166
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62-
version = "0.12.14"
67+
version = "0.12.16"
6368

6469
[[IOCapture]]
6570
deps = ["Logging", "Random"]
@@ -167,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
167172

168173
[[StaticArrays]]
169174
deps = ["LinearAlgebra", "Random", "Statistics"]
170-
git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
175+
git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397"
171176
uuid = "90137ffa-7385-5640-81b9-e52037218182"
172-
version = "1.2.6"
177+
version = "1.2.7"
173178

174179
[[Statistics]]
175180
deps = ["LinearAlgebra", "SparseArrays"]

docs/src/api.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,3 @@
44
Modules = [ChainRulesTestUtils]
55
Private = false
66
```
7-
8-
9-
## Global Configuration
10-
```@docs
11-
ChainRulesTestUtils.enable_tangent_transform!
12-
```

docs/src/index.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly
7272
```jldoctest ex
7373
julia> test_rrule(two2three, 3.33, -7.77);
7474
Test Summary: | Pass Total
75-
test_rrule: two2three on Float64,Float64 | 8 8
75+
test_rrule: two2three on Float64,Float64 | 9 9
7676
7777
```
7878

@@ -100,12 +100,12 @@ call.
100100
```jldoctest ex
101101
julia> test_scalar(relu, 0.5);
102102
Test Summary: | Pass Total
103-
test_scalar: relu at 0.5 | 10 10
103+
test_scalar: relu at 0.5 | 11 11
104104
105105
106106
julia> test_scalar(relu, -0.5);
107107
Test Summary: | Pass Total
108-
test_scalar: relu at -0.5 | 10 10
108+
test_scalar: relu at -0.5 | 11 11
109109
110110
```
111111

src/ChainRulesTestUtils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,4 @@ include("check_result.jl")
2929
include("rule_config.jl")
3030
include("finite_difference_calls.jl")
3131
include("testers.jl")
32-
33-
include("deprecated.jl")
3432
end # module

src/check_result.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
143143
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
144144
throw(MethodError, test_approx, (actual, expected))
145145
end
146-
test_approx(c_actual, c_expected; kwargs...)
146+
test_approx(c_actual, c_expected, msg; kwargs...)
147147
end
148148
end
149149

src/deprecated.jl

Lines changed: 0 additions & 93 deletions
This file was deleted.

src/global_config.jl

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,6 @@
11
const _fdm = central_fdm(5, 1; max_range=1e-2)
22
const TEST_INFERRED = Ref(true)
3-
const TRANSFORMS_TO_ALT_TANGENTS = Function[] # e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)]
4-
5-
"""
6-
enable_tangent_transform!(Thunk)
7-
8-
Adds a alt-tangent tranform to the list of default `tangent_transforms` for
9-
[`test_frule`](@ref) and [`test_rrule`](@ref) to test.
10-
This list of defaults is overwritten by the `tangent_transforms` keyword argument.
11-
12-
!!! info "Transitional Feature"
13-
ChainRulesCore v1.0 will require that all well-behaved rules work for a variety of
14-
tangent representations. In turn, the corresponding release of ChainRulesTestUtils will
15-
test all the different tangent representations by default.
16-
At that stage `enable_tangent_transform!(Thunk)` will have no effect, as it will already
17-
be enabled.
18-
We provide this configuration as a transitional feature to help migrate your packages
19-
one feature at a time, prior to the breaking release of ChainRulesTestUtils that will
20-
enforce it.
21-
"""
22-
function enable_tangent_transform!(::Type{Thunk})
23-
push!(TRANSFORMS_TO_ALT_TANGENTS, x->@thunk(x))
24-
unique!(TRANSFORMS_TO_ALT_TANGENTS)
25-
end
3+
const TRANSFORMS_TO_ALT_TANGENTS = Function[x->@thunk(x)] # e.g. [_ -> ZeroTangent(), x -> rebasis(x)]
264

275
"sets up TEST_INFERRED based ion enviroment variables"
286
function init_test_inferred_setting!()

src/testers.jl

Lines changed: 9 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,6 @@ end
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
8282
should be a differential for the output of `f`. Is set automatically if not provided.
83-
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84-
transform the passed argument tangents into alternative tangents that should be tested.
85-
Note that the alternative tangents are only tested for not erroring when passed to
86-
frule. Testing for correctness using finite differencing can be done using a
87-
separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88-
`test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
8983
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
9084
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
9185
`frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
10498
f,
10599
args...;
106100
output_tangent=Auto(),
107-
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
108101
fdm=_fdm,
109102
frule_f=ChainRulesCore.frule,
110103
check_inferred::Bool=true,
@@ -136,41 +129,16 @@ function test_frule(
136129
Ω = call_on_copy(primals...)
137130
test_approx(Ω_ad, Ω; isapprox_kwargs...)
138131

139-
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
140-
is_ignored = isa.(tangents, Union{Nothing,NoTangent})
141-
if any(tangents .== nothing)
142-
Base.depwarn(
143-
"test_frule(f, k ⊢ nothing) is deprecated, use " *
144-
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
145-
:test_frule,
146-
)
147-
end
148-
149132
# Correctness testing via finite differencing.
133+
is_ignored = isa.(tangents, NoTangent)
150134
dΩ_fd = _make_jvp_call(fdm, call_on_copy, Ω, primals, tangents, is_ignored)
151135
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)
152136

153137
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
154138
_test_add!!_behaviour(acc, dΩ_ad; isapprox_kwargs...)
155-
156-
# test that rules work for other tangents
157-
_test_frule_alt_tangents(
158-
call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
159-
isapprox_kwargs...
160-
)
161139
end # top-level testset
162140
end
163141

164-
function _test_frule_alt_tangents(
165-
call, frule_f, config, tangent_transforms, tangents, primals, acc;
166-
isapprox_kwargs...
167-
)
168-
@testset "ȧrgs = $(_string_typeof(tsf.(tangents)))" for tsf in tangent_transforms
169-
_, dΩ = call(frule_f, config, tsf.(tangents), primals...)
170-
_test_add!!_behaviour(acc, dΩ; isapprox_kwargs...)
171-
end
172-
end
173-
174142
"""
175143
test_rrule([config::RuleConfig,] f, args...; kwargs...)
176144
@@ -185,12 +153,8 @@ end
185153
# Keyword Arguments
186154
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
187155
should be a differential for the output of `f`. Is set automatically if not provided.
188-
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
189-
transform the passed `output_tangent` into alternative tangents that should be tested.
190-
Note that the alternative tangents are only tested for not erroring when passed to
191-
rrule. Testing for correctness using finite differencing can be done using a
192-
separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
193-
`test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156+
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157+
output tangent to the pullback returns the same result.
194158
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
195159
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
196160
Used for testing gradients from AD systems.
@@ -209,7 +173,7 @@ function test_rrule(
209173
f,
210174
args...;
211175
output_tangent=Auto(),
212-
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
176+
check_thunked_output_tangent=true,
213177
fdm=_fdm,
214178
rrule_f=ChainRulesCore.rrule,
215179
check_inferred::Bool=true,
@@ -254,22 +218,13 @@ function test_rrule(
254218
)
255219

256220
# Correctness testing via finite differencing.
257-
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
258-
is_ignored = isa.(accum_cotangents, Union{Nothing, NoTangent})
259-
if any(accum_cotangents .== nothing)
260-
Base.depwarn(
261-
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
262-
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
263-
:test_rrule,
264-
)
265-
end
266-
221+
is_ignored = isa.(accum_cotangents, NoTangent)
267222
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
268223

269224
for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
270225
accum_cotangents, ad_cotangents, fd_cotangents
271226
)
272-
if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
227+
if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
273228
@assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
274229
ad_cotangent isa ZeroTangent && error(
275230
"The pullback in the rrule should use NoTangent()" *
@@ -285,21 +240,10 @@ function test_rrule(
285240
end
286241
end
287242

288-
# test other tangents don't error when passed to the pullback
289-
_test_rrule_alt_tangents(pullback, tangent_transforms, ȳ, accum_cotangents)
290-
end # top-level testset
291-
end
292-
293-
function _test_rrule_alt_tangents(
294-
pullback, tangent_transforms, ȳ, accum_cotangents;
295-
isapprox_kwargs...
296-
)
297-
@testset "ȳ = $(_string_typeof(tsf(ȳ)))" for tsf in tangent_transforms
298-
ad_cotangents = pullback(tsf(ȳ))
299-
for (accum_cotangent, ad_cotangent) in zip(accum_cotangents, ad_cotangents)
300-
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
243+
if check_thunked_output_tangent
244+
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
301245
end
302-
end
246+
end # top-level testset
303247
end
304248

305249
"""

0 commit comments

Comments
 (0)