Skip to content

Commit aba2fcb

Browse files
mcabbottmzgubicoxinabox
authored
Take is_inplaceable_destination seriously (#577)
* take is_inplaceable_destination seriously * fix tests on 1.9 * Update src/accumulation.jl Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au> * three tweaks Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
1 parent fbb4936 commit aba2fcb

File tree

4 files changed

+55
-37
lines changed

4 files changed

+55
-37
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export frule_via_ad, rrule_via_ad
1212
# definition helper macros
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1414
export ProjectTo, canonicalize, unthunk # tangent operations
15-
export add!! # gradient accumulation operations
15+
export add!!, is_inplaceable_destination # gradient accumulation operations
1616
export ignore_derivatives, @ignore_derivatives
1717
# tangents
1818
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk

src/accumulation.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,41 @@ end
3838
is_inplaceable_destination(x) -> Bool
3939
4040
Returns true if `x` is suitable for for storing inplace accumulation of gradients.
41-
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
42-
tangent.
43-
Wrapper array types do not need to overload this if they overload `Base.parent`, and are
44-
`is_inplaceable_destination` if and only if their parent array is.
45-
Other types should overload this, as it defaults to `false`.
41+
For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent.
42+
43+
Here "appropriate" means that `y` cannot be complex unless `x` is too,
44+
and that for structured matrices like `x isa Diagonal`, `y` shares this structure.
45+
46+
!!! note "history"
47+
Wrapper array types should overload this function if they can be written into.
48+
Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`,
49+
but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl.
50+
51+
There must always be a correct non-mutating path, so in uncertain cases,
52+
this function returns `false`.
4653
"""
4754
is_inplaceable_destination(::Any) = false
55+
4856
is_inplaceable_destination(::Array) = true
57+
is_inplaceable_destination(:: Array{<:Integer}) = false
58+
4959
is_inplaceable_destination(::SparseVector) = true
5060
is_inplaceable_destination(::SparseMatrixCSC) = true
51-
is_inplaceable_destination(::BitArray) = true
52-
function is_inplaceable_destination(x::AbstractArray)
53-
p = parent(x)
54-
p === x && return false # no parent
55-
# basically all wrapper types delegate `setindex!` to their `parent` after some
56-
# processing and so are mutable if their `parent` is.
57-
return is_inplaceable_destination(p)
61+
62+
function is_inplaceable_destination(x::SubArray)
63+
alpha = is_inplaceable_destination(parent(x))
64+
beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}}
65+
return alpha && beta
5866
end
5967

68+
for T in [:PermutedDimsArray, :ReshapedArray]
69+
@eval is_inplaceable_destination(x::Base.$T) = is_inplaceable_destination(parent(x))
70+
end
71+
for T in [:Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular]
72+
@eval is_inplaceable_destination(x::LinearAlgebra.$T) = is_inplaceable_destination(parent(x))
73+
end
6074
# Hermitian and Symmetric are too fussy to deal with right now
6175
# https://github.com/JuliaLang/julia/issues/38056
62-
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236
63-
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false
64-
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false
6576

6677
function debug_add!(accumuland, t::InplaceableThunk)
6778
returned_value = t.add!(accumuland)

test/accumulation.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
@testset "is_inplaceable_destination" begin
33
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination
44

5-
@test is_inplaceable_destination([1, 2, 3, 4])
6-
@test !is_inplaceable_destination(1:4)
5+
@test is_inplaceable_destination([1.0, 2.0, 3.0])
6+
@test !is_inplaceable_destination([1, 2, 3, 4]) # gradients cannot reliably be written into integer arrays
7+
@test !is_inplaceable_destination(1:4.0)
78

8-
@test is_inplaceable_destination(Diagonal([1, 2, 3, 4]))
9-
@test !is_inplaceable_destination(Diagonal(1:4))
9+
@test is_inplaceable_destination(Diagonal([1.0, 2.0, 3.0]))
10+
@test !is_inplaceable_destination(Diagonal(1:4.0))
1011

11-
@test is_inplaceable_destination(view([1, 2, 3, 4], :, :))
12-
@test !is_inplaceable_destination(view(1:4, :, :))
12+
@test is_inplaceable_destination(view([1.0, 2.0, 3.0], :, :))
13+
@test is_inplaceable_destination(view([1.0 2.0; 3.0 4.0], :, 2))
14+
@test !is_inplaceable_destination(view(1:4.0, :, :))
15+
mat = view([1.0, 2.0, 3.0], :, fill(1, 10))
16+
@test !is_inplaceable_destination(mat) # The concern is that `mat .+= x` is unsafe on GPU / parallel.
1317

14-
@test is_inplaceable_destination(falses(4))
18+
@test !is_inplaceable_destination(falses(4)) # gradients can never be written into boolean
1519
@test is_inplaceable_destination(spzeros(4))
1620
@test is_inplaceable_destination(spzeros(2, 2))
1721

18-
@test !is_inplaceable_destination(1.3)
19-
@test !is_inplaceable_destination(@SVector [1, 2, 3])
20-
@test !is_inplaceable_destination(Hermitian([1 2; 2 4]))
21-
@test !is_inplaceable_destination(Symmetric([1 2; 2 4]))
22+
@test !is_inplaceable_destination(1:3.0)
23+
@test !is_inplaceable_destination(@SVector [1.0, 2.0, 3.0])
24+
@test !is_inplaceable_destination(Hermitian([1.0 2.0; 2.0 4.0]))
2225
end
2326

2427
@testset "add!!" begin

test/tangent_types/tangent.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ end
109109
@test NoTangent() === @inferred Base.tail(ntang1)
110110

111111
# TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516
112-
if VERSION >= v"1.8-"
113-
@test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
114-
else
115-
@test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
116-
end
112+
# if VERSION >= v"1.8-"
113+
# @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
114+
# else
115+
# @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
116+
# end
117117
@test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false
118118

119119
@test length(Tangent{Foo}(; x=2.5)) == 1
@@ -148,12 +148,16 @@ end
148148
cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1)
149149
@test reverse(c) === cr
150150

151-
# can't reverse a named tuple or a dict
152-
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))
151+
if VERSION < v"1.9-"
152+
# can't reverse a named tuple or a dict
153+
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))
153154

154-
d = Dict(:x => 1, :y => 2.0)
155-
cdict = Tangent{typeof(d),typeof(d)}(d)
156-
@test_throws MethodError reverse(Tangent{Foo}())
155+
d = Dict(:x => 1, :y => 2.0)
156+
cdict = Tangent{typeof(d),typeof(d)}(d)
157+
@test_throws MethodError reverse(Tangent{Foo}())
158+
else
159+
# These now work but do we care?
160+
end
157161
end
158162

159163
@testset "unset properties" begin

0 commit comments

Comments
 (0)