Skip to content

Commit ef44800

Browse files
committed
WIP mutable Tangent (squash me)
1 parent 8afdd49 commit ef44800

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

src/tangent_types/structural_tangent.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,31 @@
33
44
Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`).
55
as an object with mirroring fields.
6+
7+
!!!!!! warning Exprimental
8+
The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs.
9+
`MutableTangent` is an experimental feature.
10+
Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental.
11+
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
12+
613
"""
714
abstract type StructuralTangent{P} <: AbstractTangent end
815

916
function StructuralTangent{P}(nt::NamedTuple) where {P}
10-
return Tangent{P,typeof(nt)}(nt)
17+
if ismutabletype(P)
18+
return MutableTangent{P}(nt)
19+
else
20+
return Tangent{P,typeof(nt)}(nt)
21+
end
1122
end
1223

13-
StructuralTangent{P}(tup::Tuple) where {P} = Tangent{P,typeof(tup)}(tup)
14-
StructuralTangent{P}(dict::Dict) where {P} = Tangent{P}(dict)
24+
ismutabletype(::Type{P}) where P = ismutable(P)
25+
ismutabletype(::Type{String}) = false
26+
ismutabletype(::Type{Symbol}) = false
27+
28+
29+
StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup)
30+
StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict)
1531

1632
Base.keys(tangent::StructuralTangent) = keys(backing(tangent))
1733
Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent))
@@ -29,10 +45,10 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P}
2945
L = propertynames(backing(tangent))
3046
vals = map(f, Tuple(backing(tangent)))
3147
named_vals = NamedTuple{L,typeof(vals)}(vals)
32-
return if tangent isa Tangent
33-
Tangent{P,typeof(named_vals)}(named_vals)
48+
return if tangent isa MutableTangent
49+
MutableTangent{P}(named_vals)
3450
else
35-
# Handle MutableTangent
51+
Tangent{P,typeof(named_vals)}(named_vals)
3652
end
3753
end
3854

@@ -386,3 +402,24 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent
386402
canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent
387403
canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent
388404
canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent
405+
406+
407+
"""
408+
MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
409+
410+
This type represents the tangent to a mutable struct.
411+
It itself is also mutable.
412+
413+
!!!!!! warning Exprimental
414+
MutableTangent is an experimental feature.
415+
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
416+
Exactly how it should be used (e.g. is it forward-mode only?)
417+
"""
418+
mutable struct MutableTangent{P}
419+
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
420+
# (but potentially a different one, as it doesn't contain tangents)
421+
backing::NamedTuple
422+
end
423+
424+
Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx))
425+
Base.setproperty!

0 commit comments

Comments
 (0)