Skip to content

Commit 06668f8

Browse files
committed
Add method for checking rules added to DataType etc
1 parent 63bbd48 commit 06668f8

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ export TestIterator
1515
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
1616
export , rand_tangent
1717
export @maybe_inferred
18+
export test_method_tables_sensibility
1819

1920
__init__() = init_test_inferred_setting!()
2021

@@ -31,4 +32,6 @@ include("check_result.jl")
3132
include("rule_config.jl")
3233
include("finite_difference_calls.jl")
3334
include("testers.jl")
35+
36+
include("global_checks.jl")
3437
end # module

src/global_checks.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
_parameters(type)
3+
Extracts the type-parameters of the `type`.
4+
e.g. `_parameters(Foo{A, B, C}) == [A, B, C]`
5+
"""
6+
_parameters(sig::UnionAll) = _parameters(sig.body)
7+
_parameters(sig::DataType) = sig.parameters
8+
_parameters(sig::Union) = Base.uniontypes(sig)
9+
10+
"""
11+
test_method_tables_sensibility()
12+
13+
Checks that the method tables for `rrule` and `frule` are sensible.
14+
This in future may carry out a number of checks, but presently just checks to make sure that
15+
no rules have been added to the very general `DataType`, `Union` or `UnionAll` types.
16+
This is easy to do when writing rules for constructors.
17+
It happens if you writeg. `rrule(::typeof(Foo), x)` rather than `rrule(::Type{<:Foo}, x)`:
18+
This would then actually define `rrule(::DataType, x)`. (or `UnionAll` if `Foo`
19+
was parametric, or `Union` if `Foo` was a type alias for a `Union`)
20+
"""
21+
function test_method_tables_sensibility()
22+
@testset "Make sure methods haven't been added to DataType/UnionAll/Union" begin
23+
# if someone wrote e.g. `rrule(::typeof(Foo), x)` rather than
24+
# `rrule(::Type{<:Foo}, x)` then that would actually define `rrule(::DataType, x)`
25+
# which would be bad. This test checks for that and fails if such a method exists.
26+
for method in methods(rrule)
27+
function_type = if method.sig <: Tuple{Any, RuleConfig, Type, Vararg}
28+
_parameters(method.sig)[3]
29+
elseif method.sig <: Tuple{Any, Type, Vararg}
30+
_parameters(method.sig)[2]
31+
else
32+
nothing
33+
end
34+
35+
if function_type (DataType, UnionAll Union)
36+
@error "Bad constructor rrule. typeof(T)` not `Type{T}`" method
37+
@test false
38+
end
39+
end
40+
41+
# frule
42+
for method in methods(frule)
43+
function_type = if method.sig <: Tuple{Any, RuleConfig, Any, Type, Vararg}
44+
_parameters(method.sig)[4]
45+
elseif method.sig <: Tuple{Any, Any, Type, Vararg}
46+
@show _parameters(method.sig)[3]
47+
else
48+
nothing
49+
end
50+
51+
if function_type (DataType, UnionAll Union)
52+
@error "Bad constructor frule. typeof(T)` not `Type{T}`" method
53+
@test false
54+
end
55+
end
56+
end

test/global_checks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@testset "global_checks.jl" begin
2+
# Just check it doesn't error.
3+
test_method_tables_sensibility()
4+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ ChainRulesTestUtils.TEST_INFERRED[] = true
1616
include("testers.jl")
1717
include("data_generation.jl")
1818
include("rand_tangent.jl")
19+
20+
include("global_checks.jl")
1921
end

0 commit comments

Comments
 (0)