Skip to content

Commit ccb6a60

Browse files
committed
Add sparse Enzyme Mode
1 parent 5a87fc9 commit ccb6a60

File tree

6 files changed

+83
-31
lines changed

6 files changed

+83
-31
lines changed

Project.toml

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
3-
authors = [
4-
"Pankaj Mishra <pankajmishra1511@gmail.com>",
5-
"Chris Rackauckas <contact@chrisrackauckas.com>",
6-
]
3+
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
74
version = "2.5.0"
85

96
[deps]
@@ -28,10 +25,12 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2825
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2926

3027
[weakdeps]
28+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3129
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3230
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3331

3432
[extensions]
33+
SparseDiffToolsEnzymeExt = "Enzyme"
3534
SparseDiffToolsSymbolicsExt = "Symbolics"
3635
SparseDiffToolsZygoteExt = "Zygote"
3736

@@ -41,6 +40,7 @@ Adapt = "3.0"
4140
ArrayInterface = "7.4.2"
4241
Compat = "4"
4342
DataStructures = "0.18"
43+
Enzyme = "0.11"
4444
FiniteDiff = "2.8.1"
4545
ForwardDiff = "0.10"
4646
Graphs = "1"
@@ -61,6 +61,7 @@ ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
6161
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
6262
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6363
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
64+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6465
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
6566
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6667
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -71,17 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7172
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7273

7374
[targets]
74-
test = [
75-
"Test",
76-
"ArrayInterfaceBandedMatrices",
77-
"ArrayInterfaceBlockBandedMatrices",
78-
"BandedMatrices",
79-
"BlockBandedMatrices",
80-
"IterativeSolvers",
81-
"Pkg",
82-
"Random",
83-
"SafeTestsets",
84-
"Symbolics",
85-
"Zygote",
86-
"StaticArrays",
87-
]
75+
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]

ext/SparseDiffToolsEnzymeExt.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
module SparseDiffToolsEnzymeExt
2+
3+
import ArrayInterface: fast_scalar_indexing
4+
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!, AutoSparseEnzyme
5+
# FIXME: For Enzyme we currently assume reverse mode
6+
import ADTypes: AutoEnzyme
7+
using Enzyme
8+
9+
using ForwardDiff
10+
11+
## Satisfying High-Level Interface for Sparse Jacobians
12+
function __gradient(::Union{AutoSparseEnzyme, AutoEnzyme}, f, x, cols)
13+
dx = zero(x)
14+
autodiff(Reverse, __f̂, Const(f), Duplicated(x, dx), Const(cols))
15+
return vec(dx)
16+
end
17+
18+
function __gradient!(::Union{AutoSparseEnzyme, AutoEnzyme}, f!, fx, x, cols)
19+
dx = zero(x)
20+
dfx = zero(fx)
21+
autodiff(Reverse, __f̂, Active, Const(f!), Duplicated(fx, dfx), Duplicated(x, dx),
22+
Const(cols))
23+
return dx
24+
end
25+
26+
function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseEnzyme, AutoEnzyme}, f, x)
27+
J .= jacobian(Reverse, f, x, Val(size(J, 1)))
28+
return J
29+
end
30+
31+
@views function __jacobian!(J, ad::Union{AutoSparseEnzyme, AutoEnzyme}, f!, fx, x)
32+
# This version is slowish not sure how to do jacobians for inplace functions
33+
@warn "Current code for computing jacobian for inplace functions in Enzyme is slow." maxlog=1
34+
dfx = zero(fx)
35+
dx = zero(x)
36+
37+
function __f_row_idx(f!, fx, x, row_idx)
38+
f!(fx, x)
39+
if fast_scalar_indexing(fx)
40+
return fx[row_idx]
41+
else
42+
# Avoid the GPU Arrays scalar indexing error
43+
return sum(selectdim(fx, 1, row_idx:row_idx))
44+
end
45+
end
46+
47+
for row_idx in 1:size(J, 1)
48+
autodiff(Reverse, __f_row_idx, Const(f!), DuplicatedNoNeed(fx, dfx),
49+
Duplicated(x, dx), Const(row_idx))
50+
J[row_idx, :] .= dx
51+
fill!(dfx, 0)
52+
fill!(dx, 0)
53+
end
54+
55+
return J
56+
end
57+
58+
end

src/SparseDiffTools.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ export update_coefficients, update_coefficients!, value!
8787

8888
# High Level Interface: sparse_jacobian
8989
export AutoSparseZygote # FIXME: Remove once https://github.com/SciML/ADTypes.jl/pull/16 is merged
90+
export AutoSparseEnzyme
91+
9092
export NoSparsityDetection,
9193
SymbolicsSparsityDetection, JacPrototypeSparsityDetection, AutoSparsityDetection
9294
export sparse_jacobian, sparse_jacobian_cache, sparse_jacobian!

src/highlevel/common.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
const AbstractSparseADType = Union{AbstractSparseForwardMode, AbstractSparseReverseMode,
22
AbstractSparseFiniteDifferences}
33

4+
struct AutoSparseEnzyme <: AbstractSparseReverseMode end
5+
46
# Sparsity Detection
57
abstract type AbstractMaybeSparsityDetection end
68
abstract type AbstractSparsityDetection <: AbstractMaybeSparsityDetection end

src/highlevel/reverse_mode.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob
77
idx_vec::I
88
end
99

10-
function sparse_jacobian_cache(ad::AbstractReverseMode, sd::AbstractMaybeSparsityDetection,
11-
f, x; fx = nothing)
10+
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
11+
sd::AbstractMaybeSparsityDetection, f, x; fx = nothing)
1212
fx = fx === nothing ? similar(f(x)) : fx
1313
coloring_result = sd(ad, f, x)
1414
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
1515
return ReverseModeJacobianCache(coloring_result, nothing, jac_prototype, fx, x,
1616
collect(1:length(fx)))
1717
end
1818

19-
function sparse_jacobian_cache(ad::AbstractReverseMode, sd::AbstractMaybeSparsityDetection,
20-
f!, fx, x)
19+
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
20+
sd::AbstractMaybeSparsityDetection, f!, fx, x)
2121
coloring_result = sd(ad, f!, fx, x)
2222
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
2323
return ReverseModeJacobianCache(coloring_result, nothing, jac_prototype, fx, x,

test/test_sparse_jacobian.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
## Sparse Jacobian tests
2-
using SparseDiffTools, Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote
2+
using SparseDiffTools, Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme
33
using Test
44

55
@views function fdiff(y, x) # in-place
6-
y[(begin + 1):(end - 1)] .= x[begin:(end - 2)] .- 2 .* x[(begin + 1):(end - 1)] .+
7-
x[(begin + 2):end]
8-
y[begin] = -2 * x[begin] + x[begin + 1]
9-
y[end] = x[end - 1] - 2 * x[end]
6+
L = length(x)
7+
y[2:(L - 1)] .= x[1:(L - 2)] .- 2 .* x[2:(L - 1)] .+ x[3:L]
8+
y[1] = -2 * x[1] + x[2]
9+
y[L] = x[L - 1] - 2 * x[L]
1010
return nothing
1111
end
1212

1313
@views function fdiff(x) # out-of-place
14-
y₂ = x[begin:(end - 2)] .- 2 .* x[(begin + 1):(end - 1)] .+ x[(begin + 2):end]
14+
L = length(x)
15+
y₂ = x[1:(L - 2)] .- 2 .* x[2:(L - 1)] .+ x[3:L]
1516
y₁ = -2x[1] + x[2]
16-
y₃ = x[end - 1] - 2x[end]
17+
y₃ = x[L - 1] - 2x[L]
1718
return vcat(y₁, y₂, y₃)
1819
end
1920

@@ -36,7 +37,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
3637

3738
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in (AutoSparseZygote(),
3839
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
39-
AutoSparseFiniteDiff(), AutoFiniteDiff())
40+
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
4041
@testset "Cache & Reuse" begin
4142
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
4243
J = SparseDiffTools.__init_𝒥(cache)
@@ -88,7 +89,8 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
8889
@info "Inplace Place Function"
8990

9091
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (AutoSparseForwardDiff(),
91-
AutoForwardDiff(), AutoSparseFiniteDiff(), AutoFiniteDiff())
92+
AutoForwardDiff(), AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(),
93+
AutoSparseEnzyme())
9294
y = similar(x)
9395
cache = sparse_jacobian_cache(difftype, sd, fdiff, y, x)
9496

0 commit comments

Comments
 (0)