Skip to content

Commit dd3f1ab

Browse files
committed
handle abstract fields right in mutable tangents outside of zero tangent
1 parent b3562c6 commit dd3f1ab

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

src/tangent_types/structural_tangent.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@ It itself is also mutable.
7373
struct MutableTangent{P,F} <: StructuralTangent{P}
7474
backing::F
7575

76-
function MutableTangent{P}(fieldvals) where P
77-
backing = map(Ref, fieldvals)
78-
return new{P, typeof(backing)}(backing)
79-
end
8076
function MutableTangent{P}(
8177
any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names}
8278
) where {names, P}
@@ -91,8 +87,14 @@ struct MutableTangent{P,F} <: StructuralTangent{P}
9187
end
9288
return new{P, typeof(backing)}(backing)
9389
end
90+
91+
function MutableTangent{P}(fvals) where P
92+
any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P)))
93+
return MutableTangent{P}(any_mask, fvals)
94+
end
9495
end
9596

97+
9698
####################################################################
9799
# StructuralTangent Common
98100

test/tangent_types/structural_tangent.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ struct Foo
44
y::Float64
55
end
66

7+
mutable struct MFoo
8+
x::Float64
9+
y
10+
end
11+
712
# For testing Primal + Tangent performance
813
struct Bar
914
x::Float64
@@ -452,14 +457,40 @@ end
452457
end
453458

454459
@testset "== and hash" begin
455-
@test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0)
456-
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0)
457-
@test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
458-
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0)
460+
@test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0)
461+
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0)
462+
@test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
463+
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0)
459464

460465
nt = (; x=1.0)
461466
@test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0)
462467

463-
@test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0))
468+
@test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0))
469+
end
470+
471+
@testset "Mutation" begin
472+
v = MutableTangent{MFoo}(x=1.5, y=2.4)
473+
v.x = 1.6
474+
@test v == MutableTangent{MFoo}(x=1.6, y=2.4)
475+
v.y = [1.0, 2.0] # change type, because primal can change type
476+
@test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0])
477+
end
478+
end
479+
480+
@testset "map" begin
481+
@testset "Tangent" begin
482+
∂foo = Tangent{Foo}(x=1.5, y=2.4)
483+
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8)
484+
485+
∂foo = Tangent{Foo}(x=1.5)
486+
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0)
487+
end
488+
@testset "MutableTangent" begin
489+
∂foo = MutableTangent{MFoo}(x=1.5, y=2.4)
490+
∂foo2 = map(v->2*v, ∂foo)
491+
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8)
492+
# Check can still be mutated to new typ
493+
∂foo2.y=[1.0, 2.0]
494+
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0])
464495
end
465496
end

0 commit comments

Comments
 (0)