3
3
4
4
Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`).
5
5
as an object with mirroring fields.
6
+
7
+ !!!!!! warning Exprimental
8
+ The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs.
9
+ `MutableTangent` is an experimental feature.
10
+ Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental.
11
+ While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
12
+
6
13
"""
7
14
abstract type StructuralTangent{P} <: AbstractTangent end
8
15
9
16
function StructuralTangent {P} (nt:: NamedTuple ) where {P}
10
- return Tangent {P,typeof(nt)} (nt)
17
+ if ismutabletype (P)
18
+ return MutableTangent {P} (nt)
19
+ else
20
+ return Tangent {P,typeof(nt)} (nt)
21
+ end
11
22
end
12
23
13
- StructuralTangent {P} (tup:: Tuple ) where {P} = Tangent {P,typeof(tup)} (tup)
14
- StructuralTangent {P} (dict:: Dict ) where {P} = Tangent {P} (dict)
24
+ ismutabletype (:: Type{P} ) where P = ismutable (P)
25
+ ismutabletype (:: Type{String} ) = false
26
+ ismutabletype (:: Type{Symbol} ) = false
27
+
28
+
29
+ StructuralTangent {P} (tup:: Tuple ) where P = Tangent {P,typeof(tup)} (tup)
30
+ StructuralTangent {P} (dict:: Dict ) where P = Tangent {P} (dict)
15
31
16
32
Base. keys (tangent:: StructuralTangent ) = keys (backing (tangent))
17
33
Base. propertynames (tangent:: StructuralTangent ) = propertynames (backing (tangent))
@@ -29,10 +45,10 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P}
29
45
L = propertynames (backing (tangent))
30
46
vals = map (f, Tuple (backing (tangent)))
31
47
named_vals = NamedTuple {L,typeof(vals)} (vals)
32
- return if tangent isa Tangent
33
- Tangent {P,typeof(named_vals) } (named_vals)
48
+ return if tangent isa MutableTangent
49
+ MutableTangent {P } (named_vals)
34
50
else
35
- # Handle MutableTangent
51
+ Tangent {P,typeof(named_vals)} (named_vals)
36
52
end
37
53
end
38
54
@@ -386,3 +402,24 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent
386
402
canonicalize (tangent:: Tangent{Any,<:NamedTuple{L}} ) where {L} = tangent
387
403
canonicalize (tangent:: Tangent{Any,<:Tuple} ) = tangent
388
404
canonicalize (tangent:: Tangent{Any,<:AbstractDict} ) = tangent
405
+
406
+
407
+ """
408
+ MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
409
+
410
+ This type represents the tangent to a mutable struct.
411
+ It itself is also mutable.
412
+
413
+ !!!!!! warning Exprimental
414
+ MutableTangent is an experimental feature.
415
+ While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
416
+ Exactly how it should be used (e.g. is it forward-mode only?)
417
+ """
418
+ mutable struct MutableTangent{P}
419
+ # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
420
+ # (but potentially a different one, as it doesn't contain tangents)
421
+ backing:: NamedTuple
422
+ end
423
+
424
+ Base. getproperty (tangent:: MutableTangent , idx:: Symbol ) = unthunk (getfield (backing (tangent), idx))
425
+ Base. setproperty!
0 commit comments