Skip to content

Commit 63ab3e1

Browse files
author
Miha Zgubic
committed
error inside Tangent constructor if incorrect type is used
1 parent bde3166 commit 63ab3e1

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/tangent_types/tangent.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ struct Tangent{P,T} <: AbstractTangent
2525
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
2626
# (but potentially a different one, as it doesn't contain differentials)
2727
backing::T
28+
29+
function Tangent{P,T}(backing) where {P,T}
30+
function backing_error(P, G, E)
31+
msg = "Tangent for the primal $P should be backed by a $E type, not by $G."
32+
throw(ArgumentError(msg))
33+
end
34+
35+
if P <: Tuple
36+
T <: Tuple || backing_error(P, T, Tuple)
37+
elseif P <: AbstractDict
38+
T <: AbstractDict || backing_error(P, T, AbstractDict)
39+
else
40+
T <: NamedTuple || backing_error(P, T, NamedTuple)
41+
end
42+
return new(backing)
43+
end
2844
end
2945

3046
function Tangent{P}(; kwargs...) where {P}

test/tangent_types/tangent.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@ end
2323
@test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}}
2424
end
2525

26+
@testset "constructor" begin
27+
t = (1.0, 2.0)
28+
nt = (x = 1, y=2.0)
29+
d = Dict(:x => 1.0, :y => 2.0)
30+
vals = [1, 2]
31+
32+
@test_throws ArgumentError Tangent{typeof(t), typeof(nt)}(nt)
33+
@test_throws ArgumentError Tangent{typeof(t), typeof(d)}(d)
34+
35+
@test_throws ArgumentError Tangent{typeof(d), typeof(nt)}(nt)
36+
@test_throws ArgumentError Tangent{typeof(d), typeof(t)}(t)
37+
38+
@test_throws ArgumentError Tangent{typeof(nt), typeof(vals)}(vals)
39+
@test_throws ArgumentError Tangent{typeof(nt), typeof(d)}(d)
40+
@test_throws ArgumentError Tangent{typeof(nt), typeof(t)}(t)
41+
42+
@test_throws ArgumentError Tangent{Foo, typeof(d)}(d)
43+
@test_throws ArgumentError Tangent{Foo, typeof(t)}(t)
44+
end
45+
2646
@testset "==" begin
2747
@test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5)
2848
@test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1)

0 commit comments

Comments
 (0)