Skip to content

Commit ad7cafc

Browse files
committed
Test for things adding rules to DataType/UnionAll
1 parent 4023ec0 commit ad7cafc

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

test/global_state.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
parameters(type)
3+
Extracts the type-parameters of the `type`.
4+
e.g. `parameters(Foo{A, B, C}) == [A, B, C]`
5+
"""
6+
parameters(sig::UnionAll) = parameters(sig.body)
7+
parameters(sig::DataType) = sig.parameters
8+
parameters(sig::Union) = Base.uniontypes(sig)
9+
10+
@testset "Make sure methods haven't been added to DataType/UnionAll/Union" begin
11+
# if someone wrote e.g. `rrule(::typeof(Foo), x)` rather than `rrule(::Type{<:Foo}, x)`
12+
# then that would actually define `rrule(::DataType, x)` which would be bad
13+
# This test checks for that and fails if such a method exists.
14+
for method in methods(rrule)
15+
function_type = if method.sig <: Tuple{typeof(rrule), RuleConfig, Type, Vararg}
16+
parameters(method.sig)[3]
17+
elseif method.sig <: Tuple{typeof(rrule), Type, Vararg}
18+
@show parameters(method.sig)[2]
19+
else
20+
nothing
21+
end
22+
23+
if function_type == DataType || function_type == UnionAll || function_type == Union
24+
@error "Bad constructor rrule. typeof(T)` not `Type{T}`" method
25+
@test false
26+
end
27+
end
28+
29+
# frule
30+
for method in methods(frule)
31+
function_type = if method.sig <: Tuple{typeof(frule), RuleConfig, Any, Type, Vararg}
32+
parameters(method.sig)[3]
33+
elseif method.sig <: Tuple{typeof(frule), Any, Type, Vararg}
34+
@show parameters(method.sig)[2]
35+
else
36+
nothing
37+
end
38+
39+
if function_type == DataType || function_type == UnionAll || function_type == Union
40+
@error "Bad constructor frule. typeof(T)` not `Type{T}`" method
41+
@test false
42+
end
43+
end
44+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ end
4646
include("test_helpers.jl")
4747
println()
4848

49+
include_test("global_state.jl")
50+
4951
# Each file puts all tests inside one or more @testset blocks
5052
include_test("rulesets/Base/base.jl")
5153
include_test("rulesets/Base/fastmath_able.jl")

0 commit comments

Comments
 (0)