Skip to content

Commit db9d14a

Browse files
committed
Fix
1 parent dbd8a73 commit db9d14a

9 files changed

+32
-21
lines changed

ext/SparseDiffToolsEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module SparseDiffToolsEnzymeExt
22

33
import ArrayInterface: fast_scalar_indexing
44
import SparseDiffTools: __f̂, __maybe_copy_x, __jacobian!, __gradient, __gradient!,
5-
AutoSparse{<:AutoEnzyme}, __test_backend_loaded
5+
__test_backend_loaded
66
# FIXME: For Enzyme we currently assume reverse mode
77
import ADTypes: AutoSparse, AutoEnzyme
88
using Enzyme

ext/SparseDiffToolsPolyesterForwardDiffExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
1717
x::X
1818
end
1919

20-
function sparse_jacobian_cache(
20+
function sparse_jacobian_cache_aux(::ADTypes.ForwardMode,
2121
ad::Union{AutoSparse{<:AutoPolyesterForwardDiff}, AutoPolyesterForwardDiff},
2222
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
2323
coloring_result = sd(ad, f, x)
@@ -35,7 +35,7 @@ function sparse_jacobian_cache(
3535
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
3636
end
3737

38-
function sparse_jacobian_cache(
38+
function sparse_jacobian_cache_aux(::ADTypes.ForwardMode,
3939
ad::Union{AutoSparse{<:AutoPolyesterForwardDiff}, AutoPolyesterForwardDiff},
4040
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
4141
coloring_result = sd(ad, f!, fx, x)

ext/SparseDiffToolsZygoteExt.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import SparseDiffTools: numback_hesvec!,
1111
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!,
1212
auto_vecjac
1313
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
14-
import ADTypes: AutoZygote, AutoSparse{<:AutoZygote}
14+
import ADTypes: AutoZygote, AutoSparse
1515

1616
@inline __test_backend_loaded(::Union{AutoSparse{<:AutoZygote}, AutoZygote}) = nothing
1717

@@ -21,15 +21,17 @@ function __gradient(::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F, x, cols
2121
return vec(∂x)
2222
end
2323

24-
function __gradient!(::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x, cols) where {F}
24+
function __gradient!(
25+
::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x, cols) where {F}
2526
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
2627
end
2728

2829
# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
2930
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
3031
import Zygote: _jvec, _eyelike, _gradcopy!
3132

32-
@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F,
33+
@views function __jacobian!(
34+
J::AbstractMatrix, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F,
3335
x) where {F}
3436
y, back = Zygote.pullback(_jvec f, x)
3537
δ = _eyelike(y)
@@ -40,7 +42,8 @@ import Zygote: _jvec, _eyelike, _gradcopy!
4042
return J
4143
end
4244

43-
function __jacobian!(_, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x) where {F}
45+
function __jacobian!(
46+
_, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x) where {F}
4447
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
4548
end
4649

src/highlevel/coloring.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ end
2424

2525
# Prespecified Colorvecs
2626
function (alg::PrecomputedJacobianColorvec)(ad::AutoSparse, args...; kwargs...)
27-
colorvec = _get_colorvec(alg, ad)
27+
colorvec = _get_colorvec(alg, mode(ad))
2828
J = alg.jac_prototype
2929
(nz_rows, nz_cols) = ArrayInterface.findstructralnz(J)
3030
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)

src/highlevel/common.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ If `fx` is not specified, it will be computed by calling `f(x)`.
170170
171171
A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`.
172172
"""
173-
function sparse_jacobian_cache end
173+
function sparse_jacobian_cache(ad::AbstractADType, sd::AbstractSparsityDetection, args...)
174+
return sparse_jacobian_cache_aux(mode(ad), ad, sd, args...)
175+
end
174176

175177
function sparse_jacobian_static_array(ad, cache, f, x::SArray)
176178
# Not the most performant fallback

src/highlevel/finite_diff.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ end
88

99
__getfield(c::FiniteDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
1010

11-
function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
11+
function sparse_jacobian_cache_aux(
12+
::ForwardMode, fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
1213
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1314
coloring_result = sd(fd, f, x)
1415
fx = fx === nothing ? similar(f(x)) : fx
@@ -23,7 +24,8 @@ function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFinit
2324
return FiniteDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
2425
end
2526

26-
function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
27+
function sparse_jacobian_cache_aux(
28+
::ForwardMode, fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
2729
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
2830
coloring_result = sd(fd, f!, fx, x)
2931
if coloring_result isa NoMatrixColoring

src/highlevel/forward_mode.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ __standard_tag(::Nothing, f::F, x) where {F} = ForwardDiff.Tag(f, eltype(x))
1212
__standard_tag(tag::ForwardDiff.Tag, ::F, _) where {F} = tag
1313
__standard_tag(tag, f::F, x) where {F} = ForwardDiff.Tag(f, eltype(x))
1414

15-
function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
15+
function sparse_jacobian_cache_aux(
16+
::ForwardMode, ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
1617
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1718
coloring_result = sd(ad, f, x)
1819
fx = fx === nothing ? similar(f(x)) : fx
@@ -29,7 +30,8 @@ function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForw
2930
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
3031
end
3132

32-
function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
33+
function sparse_jacobian_cache_aux(
34+
::ForwardMode, ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
3335
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
3436
coloring_result = sd(ad, f!, fx, x)
3537
tag = __standard_tag(ad.tag, f!, x)

src/highlevel/reverse_mode.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
__getfield(c::ReverseModeJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
1111

12-
function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
12+
function sparse_jacobian_cache_aux(::ReverseMode, ad::AbstractADType,
1313
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1414
fx = fx === nothing ? similar(f(x)) : fx
1515
coloring_result = sd(ad, f, x)
@@ -18,7 +18,7 @@ function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
1818
collect(1:length(fx)))
1919
end
2020

21-
function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
21+
function sparse_jacobian_cache_aux(::ReverseMode, ad::AbstractADType,
2222
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
2323
coloring_result = sd(ad, f!, fx, x)
2424
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))

test/test_sparse_jacobian.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,17 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
5656
@info "Out of Place Function"
5757

5858
DIFFTYPES = [AutoSparse(AutoZygote()), AutoZygote(), AutoSparse(AutoForwardDiff()),
59-
AutoForwardDiff(), AutoSparse{<:AutoForwardDiff}(; chunksize = 0),
60-
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(; chunksize = 4),
59+
AutoForwardDiff(), AutoSparse(AutoForwardDiff(; chunksize = 0)),
60+
AutoForwardDiff(; chunksize = 0), AutoSparse(AutoForwardDiff(; chunksize = 4)),
6161
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
6262
AutoEnzyme(), AutoSparse(AutoEnzyme())]
6363

6464
if VERSION v"1.9"
6565
append!(DIFFTYPES,
6666
[AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
67-
AutoSparse{<:AutoPolyesterForwardDiff}(; chunksize = 0),
67+
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
6868
AutoPolyesterForwardDiff(; chunksize = 0),
69-
AutoSparse{<:AutoPolyesterForwardDiff}(; chunksize = 4),
69+
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
7070
AutoPolyesterForwardDiff(; chunksize = 4)])
7171
end
7272

@@ -124,7 +124,8 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
124124
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (
125125
AutoSparse(AutoForwardDiff()),
126126
AutoForwardDiff(), AutoSparse{<:AutoForwardDiff}(; chunksize = 0),
127-
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(; chunksize = 4),
127+
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(;
128+
chunksize = 4),
128129
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
129130
AutoEnzyme(), AutoSparse(AutoEnzyme()))
130131
y = similar(x)
@@ -211,7 +212,8 @@ end
211212
end
212213

213214
@testset "Static Arrays" begin
214-
@testset "No Allocations: $(difftype)" for difftype in (AutoSparse(AutoForwardDiff()),
215+
@testset "No Allocations: $(difftype)" for difftype in (
216+
AutoSparse(AutoForwardDiff()),
215217
AutoForwardDiff())
216218
J = __sparse_jacobian_no_allocs(difftype, NoSparsityDetection(), fvcat, x_sa)
217219
@test J J_true_sa

0 commit comments

Comments
 (0)