Skip to content

Commit 56f9b53

Browse files
authored
Merge pull request #228 from JuliaDiff/ox/global_check
Add method for checking rules added to DataType etc
2 parents 38a9dac + fb8f193 commit 56f9b53

File tree

5 files changed

+129
-1
lines changed

5 files changed

+129
-1
lines changed

docs/src/index.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@ For information about ChainRules, including how to write rules, refer to the gen
99
[![](https://img.shields.io/badge/docs-main-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev)
1010
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable)
1111

12-
## Canonical example
12+
## Testing Method Table Sensibility
13+
A basic feature of ChainRulesTestUtils is its ability to check that the method tables for `rrule` and `frule` remain sensible.
14+
This searches the method tables for methods that should not exist and when it fails tells you where they were defined.
15+
By calling [`test_method_tables`](@ref) ChainRulesTestUtils will check for things such as having attracted a rule to `DataType` rather than attaching it to a constructor.
16+
Basically all packages using ChainRulesTestUtils can use [`test_method_tables`](@ref), as it is independent of what rules you have written.
17+
18+
## Canonical example of testing frule and rrule
1319

1420
Let's suppose a custom transformation has been defined
1521
```jldoctest ex
@@ -274,3 +280,4 @@ Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)
274280

275281
This behavior can also be overridden globally by setting the environment variable `CHAINRULES_TEST_INFERRED` before ChainRulesTestUtils is loaded or by changing `ChainRulesTestUtils.TEST_INFERRED[]` from inside Julia.
276282
ChainRulesTestUtils can detect whether a test is run as part of [PkgEval](https://github.com/JuliaCI/PkgEval.jl) and in this case disables inference tests automatically. Packages can use [`@maybe_inferred`](@ref) to get the same behavior for other inference tests.
283+

src/ChainRulesTestUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ export TestIterator
1515
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1616
export , rand_tangent
1717
export @maybe_inferred
18+
export test_method_tables
1819

1920
__init__() = init_test_inferred_setting!()
2021

@@ -31,4 +32,6 @@ include("check_result.jl")
3132
include("rule_config.jl")
3233
include("finite_difference_calls.jl")
3334
include("testers.jl")
35+
36+
include("global_checks.jl")
3437
end # module

src/global_checks.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
_parameters(type)
3+
Extracts the type-parameters of the `type`.
4+
e.g. `_parameters(Foo{A, B, C}) == [A, B, C]`
5+
"""
6+
_parameters(sig::UnionAll) = _parameters(sig.body)
7+
_parameters(sig::DataType) = sig.parameters
8+
_parameters(sig::Union) = Base.uniontypes(sig)
9+
10+
11+
"""
12+
test_method_signature(frule|rrule, method)
13+
14+
Tests that the method signature is sensible.
15+
Right now this just means checking the rule is not being applied to `DataType`, `Union`, or
16+
`UnionAll`.
17+
which is easy to do accidentally when writing rules for constructors.
18+
It happens if you write e.g. `rrule(::typeof(Foo), x)` rather than `rrule(::Type{<:Foo}, x)`.
19+
This would then actually define `rrule(::DataType, x)`. (or `UnionAll` if `Foo`
20+
was parametric, or `Union` if `Foo` was a type alias for a `Union`)
21+
"""
22+
function test_method_signature end
23+
24+
function test_method_signature(::typeof(rrule), method::Method)
25+
@testset "Sensible Constructors" begin
26+
function_type = if method.sig <: Tuple{Any, RuleConfig, Type, Vararg}
27+
_parameters(method.sig)[3]
28+
elseif method.sig <: Tuple{Any, Type, Vararg}
29+
_parameters(method.sig)[2]
30+
else
31+
nothing
32+
end
33+
34+
@test_msg(
35+
"Bad constructor rrule. `typeof(T)` used rather than `Type{T}`. $method",
36+
function_type (DataType, UnionAll, Union)
37+
)
38+
end
39+
end
40+
41+
function test_method_signature(::typeof(frule), method::Method)
42+
@testset "Sensible Constructors" begin
43+
function_type = if method.sig <: Tuple{Any, RuleConfig, Any, Type, Vararg}
44+
_parameters(method.sig)[4]
45+
elseif method.sig <: Tuple{Any, Any, Type, Vararg}
46+
_parameters(method.sig)[3]
47+
else
48+
nothing
49+
end
50+
51+
@test_msg(
52+
"Bad constructor frule. `typeof(T)` used rather than `Type{T}`. $method",
53+
function_type (DataType, UnionAll, Union)
54+
)
55+
end
56+
end
57+
58+
"""
59+
test_method_tables()
60+
61+
Checks that the method tables for `rrule` and `frule` are sensible.
62+
This in future may carry out a number of checks, but presently just checks to make sure that
63+
no rules have been added to the very general `DataType`, `Union` or `UnionAll` types,
64+
which is easy to do accidentally when writing rules for constructors.
65+
It happens if you write e.g. `rrule(::typeof(Foo), x)` rather than `rrule(::Type{<:Foo}, x)`.
66+
This would then actually define `rrule(::DataType, x)`. (or `UnionAll` if `Foo`
67+
was parametric, or `Union` if `Foo` was a type alias for a `Union`)
68+
"""
69+
function test_method_tables()
70+
@testset "Sensible Constructors" begin
71+
# if someone wrote e.g. `rrule(::typeof(Foo), x)` rather than
72+
# `rrule(::Type{<:Foo}, x)` then that would actually define `rrule(::DataType, x)`
73+
# which would be bad. This test checks for that and fails if such a method exists.
74+
for method in methods(rrule)
75+
test_method_signature(rrule, method)
76+
end
77+
# frule
78+
for method in methods(frule)
79+
test_method_signature(frule, method)
80+
end
81+
end
82+
end

test/global_checks.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
struct BadConstructor
2+
x
3+
end
4+
5+
if VERSION >= v"1.3"
6+
@testset "global_checks.jl" begin
7+
test_method_tables()
8+
ChainRulesCore.rrule(::typeof(BadConstructor), x) = nothing
9+
@test fails(test_method_tables)
10+
Base.delete_method(Base.which(rrule, (DataType, Any)))
11+
test_method_tables() # make sure delete worked
12+
13+
ChainRulesCore.rrule(::RuleConfig, ::typeof(BadConstructor), x) = nothing
14+
@test fails(test_method_tables)
15+
Base.delete_method(Base.which(rrule, (RuleConfig, DataType, Any)))
16+
test_method_tables() # make sure delete worked
17+
18+
19+
20+
ChainRulesCore.frule(::Any, ::typeof(BadConstructor), x) = nothing
21+
@test fails(test_method_tables)
22+
Base.delete_method(Base.which(frule, (Any, DataType, Any)))
23+
test_method_tables() # make sure delete worked
24+
25+
ChainRulesCore.frule(::RuleConfig, ::Any, ::typeof(BadConstructor), x) = nothing
26+
@test fails(test_method_tables)
27+
Base.delete_method(Base.which(frule, (RuleConfig, Any, DataType, Any)))
28+
test_method_tables() # make sure delete worked
29+
end
30+
else # pre 1.3, so no `delete_method` so just test happy path
31+
@testset "global_checks.jl" begin
32+
test_method_tables()
33+
end
34+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ ChainRulesTestUtils.TEST_INFERRED[] = true
1616
include("testers.jl")
1717
include("data_generation.jl")
1818
include("rand_tangent.jl")
19+
20+
include("global_checks.jl")
1921
end

0 commit comments

Comments
 (0)