Skip to content

Commit 7f99ce4

Browse files
committed
add == and hash for MutableTangent
1 parent b2adc2a commit 7f99ce4

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

src/tangent_types/structural_tangent.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple
55
as an object with mirroring fields.
66
77
!!!!!! warning Exprimental
8+
`StructuralTangent` is an experimental feature, and is part of the mutation support featureset.
89
The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs.
910
`MutableTangent` is an experimental feature.
1011
Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental.
@@ -409,7 +410,7 @@ This type represents the tangent to a mutable struct.
409410
It itself is also mutable.
410411
411412
!!! warning Exprimental
412-
MutableTangent is an experimental feature.
413+
MutableTangent is an experimental feature, and is part of the mutation support featureset.
413414
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
414415
Exactly how it should be used (e.g. is it forward-mode only?)
415416
@@ -440,4 +441,10 @@ function Base.setproperty!(tangent::MutableTangent, idx::Int, x)
440441
return setproperty!(tangent, name, x)
441442
end
442443

443-
idx2sym(::NamedTuple{names}, idx) where names = names[idx]
444+
idx2sym(::NamedTuple{names}, idx) where names = names[idx]
445+
446+
Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h)
447+
function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2}
448+
typeintersect(T1, T2) == Union{} && return false
449+
backing(t1)==backing(t2)
450+
end

test/tangent_types/structural_tangent.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -432,14 +432,28 @@ end
432432
return y, ẏ
433433
end
434434

435-
obj = MDemo(99.0)
436-
∂obj = MutableTangent{MDemo}(;x=1.5)
437-
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
438-
@test ∂obj.x == 10.0
439-
@test obj.x == 95.0
440-
441-
frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0)
442-
@test ∂obj.x == 20.0
443-
@test getproperty(∂obj, 1) == 20.0
444-
@test obj.x == 96.0
435+
@testset "usecase" begin
436+
obj = MDemo(99.0)
437+
∂obj = MutableTangent{MDemo}(;x=1.5)
438+
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
439+
@test ∂obj.x == 10.0
440+
@test obj.x == 95.0
441+
442+
frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0)
443+
@test ∂obj.x == 20.0
444+
@test getproperty(∂obj, 1) == 20.0
445+
@test obj.x == 96.0
446+
end
447+
448+
@testset "== and hash" begin
449+
@test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0)
450+
@test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0)
451+
@test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0)
452+
@test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0)
453+
454+
nt = (;x=1.0)
455+
@test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0)
456+
457+
@test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0))
458+
end
445459
end

0 commit comments

Comments
 (0)