Skip to content

Commit 354ecbd

Browse files
authored
Merge pull request #738 from JuliaDiff/ox/setfield
Add rules needed for mutation
2 parents 32bf53d + 810c633 commit 354ecbd

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2020

2121
[compat]
2222
Adapt = "3.4.0, 4"
23-
ChainRulesCore = "1.15.3"
23+
ChainRulesCore = "1.20"
2424
ChainRulesTestUtils = "1.5"
2525
Compat = "3.46, 4.2"
2626
Distributed = "1"

src/rulesets/Base/array.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,13 @@ _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs
8686
#####
8787

8888
function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
89-
return copyto!(y, x), copyto!(ẏ, ẋ)
89+
ifisa AbstractZero
90+
# it's allowed to have an imutable zero tangent for ẏ as long as ẋ is zero
91+
@assert iszero(ẋ)
92+
else
93+
copyto!(ẏ, ẋ)
94+
end
95+
return copyto!(y, x), ẏ
9096
end
9197

9298
function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)

src/rulesets/Base/base.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ function rrule(::typeof(one), x)
2626
return (one(x), one_pullback)
2727
end
2828

29+
30+
function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x)
31+
ȯbj::MutableTangent
32+
y = setfield!(obj, field, x)
33+
= setproperty!(ȯbj, field, ẋ)
34+
return y, ẏ
35+
end
36+
2937
# `adjoint`
3038

3139
frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

test/rulesets/Base/base.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
mutable struct MDemo
2+
x::Float64
3+
end
4+
15
@testset "base.jl" begin
26
@testset "zero/one" begin
37
for f in [zero, one]
@@ -18,6 +22,11 @@
1822
end
1923
end
2024
end
25+
26+
@testset "setfield!" begin
27+
test_frule(setfield!, MDemo(3.5) MutableTangent{MDemo}(; x=2.0), :x, 5.0)
28+
test_frule(setfield!, MDemo(3.5) MutableTangent{MDemo}(; x=2.0), 1, 5.0)
29+
end
2130

2231
@testset "Trig" begin
2332
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))

0 commit comments

Comments
 (0)