@@ -14,16 +14,14 @@ as an object with mirroring fields.
14
14
abstract type StructuralTangent{P} <: AbstractTangent end
15
15
16
16
function StructuralTangent {P} (nt:: NamedTuple ) where {P}
17
- if ismutabletype (P)
17
+ if has_mutable_tangent (P)
18
18
return MutableTangent {P} (nt)
19
19
else
20
20
return Tangent {P,typeof(nt)} (nt)
21
21
end
22
22
end
23
23
24
- ismutabletype (:: Type{P} ) where P = ismutable (P)
25
- ismutabletype (:: Type{String} ) = false
26
- ismutabletype (:: Type{Symbol} ) = false
24
+ has_mutable_tangent (:: Type{P} ) where P = ismutabletype (P) && (! isabstracttype (P) && fieldcount (T) > 0 )
27
25
28
26
29
27
StructuralTangent {P} (tup:: Tuple ) where P = Tangent {P,typeof(tup)} (tup)
@@ -410,16 +408,22 @@ canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent
410
408
This type represents the tangent to a mutable struct.
411
409
It itself is also mutable.
412
410
413
- !!!!!! warning Exprimental
411
+ !!! warning Exprimental
414
412
MutableTangent is an experimental feature.
415
413
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
416
414
Exactly how it should be used (e.g. is it forward-mode only?)
415
+
416
+ !!! warning Do not directly mess with the tangent backing data
417
+ It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
418
+ However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419
+ If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
417
420
"""
418
421
mutable struct MutableTangent{P}
419
- # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
420
- # (but potentially a different one, as it doesn't contain tangents)
421
- backing:: NamedTuple
422
+ # TODO : we may want to absolutely lock the type of this down
423
+ backing:: NamedTuple
422
424
end
423
425
424
426
Base. getproperty (tangent:: MutableTangent , idx:: Symbol ) = unthunk (getfield (backing (tangent), idx))
425
- Base. setproperty!
427
+ function Base. setproperty! (tangent:: MutableTangent , name:: Symbol , x)
428
+ new_backing = Base. setindex (backing (tangent), x, name)
429
+ end
0 commit comments