Skip to content

Commit 08411c9

Browse files
authored
Merge pull request #183 from JuliaDiff/ox/gconfig
Add public API to enable testing on thunks
2 parents dffcfff + ff6ee1c commit 08411c9

File tree

5 files changed

+53
-21
lines changed

5 files changed

+53
-21
lines changed

docs/Manifest.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "dbc9aae1227cfddaa9d2552f3ecba5b641f6cce9"
14+
git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.5"
16+
version = "0.10.9"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
@@ -23,9 +23,9 @@ version = "0.7.12"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
26-
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
26+
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
2727
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
28-
version = "3.30.0"
28+
version = "3.31.0"
2929

3030
[[Dates]]
3131
deps = ["Printf"]
@@ -47,9 +47,9 @@ version = "0.8.5"
4747

4848
[[Documenter]]
4949
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
50-
git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780"
50+
git-tree-sha1 = "621850838b3e74dd6dd047b5432d2e976877104e"
5151
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
52-
version = "0.27.0"
52+
version = "0.27.2"
5353

5454
[[Downloads]]
5555
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
@@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
167167

168168
[[StaticArrays]]
169169
deps = ["LinearAlgebra", "Random", "Statistics"]
170-
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
170+
git-tree-sha1 = "745914ebcd610da69f3cb6bf76cb7bb83dcb8c9a"
171171
uuid = "90137ffa-7385-5640-81b9-e52037218182"
172-
version = "1.2.2"
172+
version = "1.2.4"
173173

174174
[[Statistics]]
175175
deps = ["LinearAlgebra", "SparseArrays"]

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,9 @@
44
Modules = [ChainRulesTestUtils]
55
Private = false
66
```
7+
8+
9+
## Global Configuration
10+
```@docs
11+
ChainRulesTestUtils.enable_tangent_transform!
12+
```

src/ChainRulesTestUtils.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,14 @@ using Test
1111

1212
import FiniteDifferences: rand_tangent
1313

14-
const _fdm = central_fdm(5, 1; max_range=1e-2)
15-
const TEST_INFERRED = Ref(true)
16-
const TRANSFORMS_TO_ALT_TANGENTS = Function[] # e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)]
17-
1814
export TestIterator
1915
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
2016
export
2117
export @maybe_inferred
2218

23-
function __init__()
24-
TEST_INFERRED[] = if haskey(ENV, "CHAINRULES_TEST_INFERRED")
25-
parse(Bool, "CHAINRULES_TEST_INFERRED")
26-
else
27-
!parse(Bool, get(ENV, "JULIA_PKGEVAL", "false"))
28-
end
19+
__init__() = init_test_inferred_setting!()
2920

30-
!TEST_INFERRED[] && @warn "inference tests have been disabled"
31-
end
21+
include("global_config.jl")
3222

3323
include("generate_tangent.jl")
3424
include("data_generation.jl")

src/global_config.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
const _fdm = central_fdm(5, 1; max_range=1e-2)
2+
const TEST_INFERRED = Ref(true)
3+
const TRANSFORMS_TO_ALT_TANGENTS = Function[] # e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)]
4+
5+
"""
6+
enable_tangent_transform!(Thunk)
7+
8+
Adds a alt-tangent tranform to the list of default `tangent_transforms` for
9+
[`test_frule`](@ref) and [`test_rrule`](@ref) to test.
10+
This list of defaults is overwritten by the `tangent_transforms` keyword argument.
11+
12+
!!! info "Transitional Feature"
13+
ChainRulesCore v1.0 will require that all well-behaved rules work for a variety of
14+
tangent representations. In turn, the corresponding release of ChainRulesTestUtils will
15+
test all the different tangent representations by default.
16+
At that stage `enable_tangent_transform!(Thunk)` will have no effect, as it will already
17+
be enabled.
18+
We provide this configuration as a transitional feature to help migrate your packages
19+
one feature at a time, prior to the breaking release of ChainRulesTestUtils that will
20+
enforce it.
21+
"""
22+
function enable_tangent_transform!(::Type{Thunk})
23+
push!(TRANSFORMS_TO_ALT_TANGENTS, x->@thunk(x))
24+
unique!(TRANSFORMS_TO_ALT_TANGENTS)
25+
end
26+
27+
"sets up TEST_INFERRED based ion enviroment variables"
28+
function init_test_inferred_setting!()
29+
TEST_INFERRED[] = if haskey(ENV, "CHAINRULES_TEST_INFERRED")
30+
parse(Bool, "CHAINRULES_TEST_INFERRED")
31+
else
32+
!parse(Bool, get(ENV, "JULIA_PKGEVAL", "false"))
33+
end
34+
35+
!TEST_INFERRED[] && @warn "inference tests have been disabled"
36+
end

src/testers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ end
316316
"""
317317
_test_inferred(f, args...; kwargs...)
318318
319-
Simple wrapper for `@inferred f(args...: kwargs...)`, avoiding the type-instability in not
319+
Simple wrapper for [`@maybe_inferred f(args...: kwargs...)`](@ref `@maybe_inferred`), avoiding the type-instability in not
320320
knowing how many `kwargs` there are.
321321
"""
322322
function _test_inferred(f, args...; kwargs...)

0 commit comments

Comments
 (0)