Skip to content

Commit 9c2e90f

Browse files
committed
First pass at something that maybe works
1 parent 005dae2 commit 9c2e90f

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations
1414
export add!!, is_inplaceable_destination # gradient accumulation operations
1515
export ignore_derivatives, @ignore_derivatives
1616
# tangents
17-
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
17+
export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1818

1919
include("debug_mode.jl")
2020

src/tangent_types/structural_tangent.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,15 @@ It itself is also mutable.
418418
However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419419
If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
420420
"""
421-
mutable struct MutableTangent{P}
421+
mutable struct MutableTangent{P} <: StructuralTangent{P}
422422
#TODO: we may want to absolutely lock the type of this down
423423
backing::NamedTuple
424424
end
425425

426-
Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx))
426+
MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs))
427+
Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx)
427428
function Base.setproperty!(tangent::MutableTangent, name::Symbol, x)
428429
new_backing = Base.setindex(backing(tangent), x, name)
430+
setfield!(tangent, :backing, new_backing)
431+
return x
429432
end

test/tangent_types/structural_tangent.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,21 @@ end
425425
@test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole
426426
end
427427
end
428+
429+
@testset "MutableTangent" begin
430+
mutable struct MDemo
431+
x::Float64
432+
end
433+
function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x)
434+
y = setfield!(obj, field, x)
435+
= setproperty!(ȯbj, field, ẋ)
436+
return y, ẏ
437+
end
438+
439+
obj = MDemo(99.0)
440+
∂obj = MutableTangent{MDemo}(;x=1.5)
441+
frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0)
442+
443+
@test ∂obj.x == 10.0
444+
@test obj.x == 95.0
445+
end

0 commit comments

Comments
 (0)