Skip to content

Commit 834d58e

Browse files
committed
Handle Subarrays
1 parent ccb6a60 commit 834d58e

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

ext/SparseDiffToolsEnzymeExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module SparseDiffToolsEnzymeExt
22

33
import ArrayInterface: fast_scalar_indexing
4-
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!, AutoSparseEnzyme
4+
import SparseDiffTools: __f̂,
5+
__maybe_copy_x, __jacobian!, __gradient, __gradient!, AutoSparseEnzyme
56
# FIXME: For Enzyme we currently assume reverse mode
67
import ADTypes: AutoEnzyme
78
using Enzyme
@@ -55,4 +56,6 @@ end
5556
return J
5657
end
5758

59+
__maybe_copy_x(::Union{AutoSparseEnzyme, AutoEnzyme}, x::SubArray) = copy(x)
60+
5861
end

src/highlevel/common.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,6 @@ function __init_𝒥(c::AbstractMaybeSparseJacobianCache)
129129
end
130130
__init_𝒥(::Nothing, ::Type{T}, fx, x) where {T} = similar(fx, T, length(fx), length(x))
131131
__init_𝒥(J, ::Type{T}, _, _) where {T} = similar(J, T, size(J, 1), size(J, 2))
132+
133+
__maybe_copy_x(_, x) = x
134+
__maybe_copy_x(_, ::Nothing) = nothing

src/highlevel/reverse_mode.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,22 @@ function sparse_jacobian!(J::AbstractMatrix, ad, cache::ReverseModeJacobianCache
3434
end
3535

3636
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
37-
cache::MatrixColoringResult, f, x, fx = nothing)
37+
cache::MatrixColoringResult, f, x)
38+
return __sparse_jacobian_reverse_impl!(J, ad, idx_vec, cache, f, nothing, x)
39+
end
40+
41+
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
42+
cache::MatrixColoringResult, f, fx, x)
3843
# If `fx` is `nothing` then assume `f` is not in-place
44+
x_ = __maybe_copy_x(ad, x)
45+
fx_ = __maybe_copy_x(ad, fx)
3946
@unpack colorvec, nz_rows, nz_cols = cache
4047
for c in 1:maximum(colorvec)
4148
@. idx_vec = colorvec == c
4249
if fx === nothing
43-
gs = __gradient(ad, f, x, idx_vec)
50+
gs = __gradient(ad, f, x_, idx_vec)
4451
else
45-
gs = __gradient!(ad, f, fx, x, idx_vec)
52+
gs = __gradient!(ad, f, fx_, x_, idx_vec)
4653
end
4754
pick_idxs = filter(i -> colorvec[nz_rows[i]] == c, 1:length(nz_rows))
4855
row_idxs = nz_rows[pick_idxs]

0 commit comments

Comments
 (0)