|
38 | 38 | is_inplaceable_destination(x) -> Bool
|
39 | 39 |
|
40 | 40 | Returns true if `x` is suitable for for storing inplace accumulation of gradients.
|
41 |
| -For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate |
42 |
| -tangent. |
43 |
| -Wrapper array types do not need to overload this if they overload `Base.parent`, and are |
44 |
| -`is_inplaceable_destination` if and only if their parent array is. |
45 |
| -Other types should overload this, as it defaults to `false`. |
| 41 | +For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent. |
| 42 | +
|
| 43 | +Here "appropriate" means that `y` cannot be complex unless `x` is too, |
| 44 | +and that for structured matrices like `x isa Diagonal`, `y` shares this structure. |
| 45 | +
|
| 46 | +!!! note "history" |
| 47 | + Wrapper array types should overload this function if they can be written into. |
| 48 | + Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`, |
| 49 | + but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl. |
| 50 | +
|
| 51 | +There must always be a correct non-mutating path, so in uncertain cases, |
| 52 | +this function returns `false`. |
46 | 53 | """
|
47 | 54 | is_inplaceable_destination(::Any) = false
|
| 55 | + |
48 | 56 | is_inplaceable_destination(::Array) = true
|
| 57 | +is_inplaceable_destination(:: Array{<:Integer}) = false |
| 58 | + |
49 | 59 | is_inplaceable_destination(::SparseVector) = true
|
50 | 60 | is_inplaceable_destination(::SparseMatrixCSC) = true
|
51 |
| -is_inplaceable_destination(::BitArray) = true |
52 |
| -function is_inplaceable_destination(x::AbstractArray) |
53 |
| - p = parent(x) |
54 |
| - p === x && return false # no parent |
55 |
| - # basically all wrapper types delegate `setindex!` to their `parent` after some |
56 |
| - # processing and so are mutable if their `parent` is. |
57 |
| - return is_inplaceable_destination(p) |
| 61 | + |
| 62 | +function is_inplaceable_destination(x::SubArray) |
| 63 | + alpha = is_inplaceable_destination(parent(x)) |
| 64 | + beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}} |
| 65 | + return alpha && beta |
58 | 66 | end
|
59 | 67 |
|
| 68 | +for T in [:PermutedDimsArray, :ReshapedArray] |
| 69 | + @eval is_inplaceable_destination(x::Base.$T) = is_inplaceable_destination(parent(x)) |
| 70 | +end |
| 71 | +for T in [:Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular] |
| 72 | + @eval is_inplaceable_destination(x::LinearAlgebra.$T) = is_inplaceable_destination(parent(x)) |
| 73 | +end |
60 | 74 | # Hermitian and Symmetric are too fussy to deal with right now
|
61 | 75 | # https://github.com/JuliaLang/julia/issues/38056
|
62 |
| -# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236 |
63 |
| -is_inplaceable_destination(::LinearAlgebra.Hermitian) = false |
64 |
| -is_inplaceable_destination(::LinearAlgebra.Symmetric) = false |
65 | 76 |
|
66 | 77 | function debug_add!(accumuland, t::InplaceableThunk)
|
67 | 78 | returned_value = t.add!(accumuland)
|
|
0 commit comments