Skip to content

Commit 005dae2

Browse files
committed
wip
1 parent ef44800 commit 005dae2

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

src/tangent_types/structural_tangent.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@ as an object with mirroring fields.
1414
abstract type StructuralTangent{P} <: AbstractTangent end
1515

1616
function StructuralTangent{P}(nt::NamedTuple) where {P}
17-
if ismutabletype(P)
17+
if has_mutable_tangent(P)
1818
return MutableTangent{P}(nt)
1919
else
2020
return Tangent{P,typeof(nt)}(nt)
2121
end
2222
end
2323

24-
ismutabletype(::Type{P}) where P = ismutable(P)
25-
ismutabletype(::Type{String}) = false
26-
ismutabletype(::Type{Symbol}) = false
24+
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0)
2725

2826

2927
StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup)
@@ -410,16 +408,22 @@ canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent
410408
This type represents the tangent to a mutable struct.
411409
It itself is also mutable.
412410
413-
!!!!!! warning Exprimental
411+
!!! warning Exprimental
414412
MutableTangent is an experimental feature.
415413
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
416414
Exactly how it should be used (e.g. is it forward-mode only?)
415+
416+
!!! warning Do not directly mess with the tangent backing data
417+
It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
418+
However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419+
If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
417420
"""
418421
mutable struct MutableTangent{P}
419-
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
420-
# (but potentially a different one, as it doesn't contain tangents)
421-
backing::NamedTuple
422+
#TODO: we may want to absolutely lock the type of this down
423+
backing::NamedTuple
422424
end
423425

424426
Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx))
425-
Base.setproperty!
427+
function Base.setproperty!(tangent::MutableTangent, name::Symbol, x)
428+
new_backing = Base.setindex(backing(tangent), x, name)
429+
end

0 commit comments

Comments
 (0)