Skip to content

Commit d2142fa

Browse files
committed
add == and hash for MutableTangent
1 parent 6f15d27 commit d2142fa

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

src/tangent_types/structural_tangent.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple
55
as an object with mirroring fields.
66
77
!!!!!! warning Exprimental
8+
`StructuralTangent` is an experimental feature, and is part of the mutation support featureset.
89
The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs.
910
`MutableTangent` is an experimental feature.
1011
Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental.
@@ -409,7 +410,7 @@ This type represents the tangent to a mutable struct.
409410
It itself is also mutable.
410411
411412
!!! warning Exprimental
412-
MutableTangent is an experimental feature.
413+
MutableTangent is an experimental feature, and is part of the mutation support featureset.
413414
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
414415
Exactly how it should be used (e.g. is it forward-mode only?)
415416
@@ -440,4 +441,10 @@ function Base.setproperty!(tangent::MutableTangent, idx::Int, x)
440441
return setproperty!(tangent, name, x)
441442
end
442443

443-
idx2sym(::NamedTuple{names}, idx) where names = names[idx]
444+
idx2sym(::NamedTuple{names}, idx) where names = names[idx]
445+
446+
Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h)
447+
function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2}
448+
typeintersect(T1, T2) == Union{} && return false
449+
backing(t1)==backing(t2)
450+
end

test/tangent_types/structural_tangent.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,28 @@ end
436436
return y, ẏ
437437
end
438438

439-
obj = MDemo(99.0)
440-
∂obj = MutableTangent{MDemo}(;x=1.5)
441-
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
442-
@test ∂obj.x == 10.0
443-
@test obj.x == 95.0
444-
445-
frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0)
446-
@test ∂obj.x == 20.0
447-
@test getproperty(∂obj, 1) == 20.0
448-
@test obj.x == 96.0
439+
@testset "usecase" begin
440+
obj = MDemo(99.0)
441+
∂obj = MutableTangent{MDemo}(;x=1.5)
442+
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
443+
@test ∂obj.x == 10.0
444+
@test obj.x == 95.0
445+
446+
frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0)
447+
@test ∂obj.x == 20.0
448+
@test getproperty(∂obj, 1) == 20.0
449+
@test obj.x == 96.0
450+
end
451+
452+
@testset "== and hash" begin
453+
@test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0)
454+
@test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0)
455+
@test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0)
456+
@test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0)
457+
458+
nt = (;x=1.0)
459+
@test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0)
460+
461+
@test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0))
462+
end
449463
end

0 commit comments

Comments
 (0)