Skip to content

Commit a724bbb

Browse files
authored
Merge pull request #495 from JuliaDiff/mz/error
Error inside `Tangent` constructor if incorrect backing type is used
2 parents bde3166 + d1a9f5b commit a724bbb

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.10.0"
3+
version = "1.10.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/tangent.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,18 @@ struct Tangent{P,T} <: AbstractTangent
2525
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
2626
# (but potentially a different one, as it doesn't contain differentials)
2727
backing::T
28+
29+
function Tangent{P,T}(backing) where {P,T}
30+
if P <: Tuple
31+
T <: Tuple || _backing_error(P, T, Tuple)
32+
elseif P <: AbstractDict
33+
T <: AbstractDict || _backing_error(P, T, AbstractDict)
34+
elseif P === Any # can be anything
35+
else # Any other struct (including NamedTuple)
36+
T <: NamedTuple || _backing_error(P, T, NamedTuple)
37+
end
38+
return new(backing)
39+
end
2840
end
2941

3042
function Tangent{P}(; kwargs...) where {P}
@@ -45,6 +57,11 @@ function Tangent{P}(d::Dict) where {P<:Dict}
4557
return Tangent{P,typeof(d)}(d)
4658
end
4759

60+
function _backing_error(P, G, E)
61+
msg = "Tangent for the primal $P should be backed by a $E type, not by $G."
62+
return throw(ArgumentError(msg))
63+
end
64+
4865
function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T}
4966
return backing(a) == backing(b)
5067
end

test/tangent_types/tangent.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@ end
2323
@test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}}
2424
end
2525

26+
@testset "constructor" begin
27+
t = (1.0, 2.0)
28+
nt = (x=1, y=2.0)
29+
d = Dict(:x => 1.0, :y => 2.0)
30+
vals = [1, 2]
31+
32+
@test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt)
33+
@test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d)
34+
35+
@test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt)
36+
@test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t)
37+
38+
@test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals)
39+
@test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d)
40+
@test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t)
41+
42+
@test_throws ArgumentError Tangent{Foo,typeof(d)}(d)
43+
@test_throws ArgumentError Tangent{Foo,typeof(t)}(t)
44+
end
45+
2646
@testset "==" begin
2747
@test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5)
2848
@test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1)
@@ -110,7 +130,7 @@ end
110130
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))
111131

112132
d = Dict(:x => 1, :y => 2.0)
113-
cdict = Tangent{Foo,typeof(d)}(d)
133+
cdict = Tangent{typeof(d),typeof(d)}(d)
114134
@test_throws MethodError reverse(Tangent{Foo}())
115135
end
116136

0 commit comments

Comments
 (0)