Skip to content

Commit c1726af

Browse files
committed
First pass at something that maybe works
1 parent 4d16be9 commit c1726af

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations
1515
export add!!, is_inplaceable_destination # gradient accumulation operations
1616
export ignore_derivatives, @ignore_derivatives
1717
# tangents
18-
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
18+
export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1919

2020
include("debug_mode.jl")
2121

src/tangent_types/structural_tangent.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,15 @@ It itself is also mutable.
418418
However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419419
If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
420420
"""
421-
mutable struct MutableTangent{P}
421+
mutable struct MutableTangent{P} <: StructuralTangent{P}
422422
#TODO: we may want to absolutely lock the type of this down
423423
backing::NamedTuple
424424
end
425425

426-
Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx))
426+
MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs))
427+
Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx)
427428
function Base.setproperty!(tangent::MutableTangent, name::Symbol, x)
428429
new_backing = Base.setindex(backing(tangent), x, name)
430+
setfield!(tangent, :backing, new_backing)
431+
return x
429432
end

test/tangent_types/structural_tangent.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,21 @@ end
421421
@test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole
422422
end
423423
end
424+
425+
@testset "MutableTangent" begin
426+
mutable struct MDemo
427+
x::Float64
428+
end
429+
function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x)
430+
y = setfield!(obj, field, x)
431+
= setproperty!(ȯbj, field, ẋ)
432+
return y, ẏ
433+
end
434+
435+
obj = MDemo(99.0)
436+
∂obj = MutableTangent{MDemo}(;x=1.5)
437+
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
438+
439+
@test ∂obj.x == 10.0
440+
@test obj.x == 95.0
441+
end

0 commit comments

Comments
 (0)