Skip to content

Commit e4581ec

Browse files
committed
Dont break code for non-functions
1 parent 62ff85d commit e4581ec

File tree

6 files changed

+22
-18
lines changed

6 files changed

+22
-18
lines changed

src/highlevel/coloring.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
# Approximate Jacobian Sparsity Detection
3434
## Right now we hardcode it to use `ForwardDiff`
3535
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, x; fx = nothing,
36-
kwargs...) where {F <: Function}
36+
kwargs...) where {F}
3737
@unpack ntrials, rng = alg
3838
fx = fx === nothing ? f(x) : fx
3939
J = fill!(similar(fx, length(fx), length(x)), 0)
@@ -48,7 +48,7 @@ function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, x; f
4848
end
4949

5050
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f!::F, fx, x;
51-
kwargs...) where {F <: Function}
51+
kwargs...) where {F}
5252
@unpack ntrials, rng = alg
5353
cfg = ForwardDiff.JacobianConfig(f!, fx, x)
5454
J = fill!(similar(fx, length(fx), length(x)), 0)

src/highlevel/common.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,11 @@ function __chunksize(::Val{C}, x) where {C}
252252
error("$(C)::$(typeof(C)) is not a valid chunksize!")
253253
end
254254
end
255-
__chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}) where {C} = C
255+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}) where {C}
256+
C === nothing && return nothing
257+
C isa Integer && !(C isa Bool) && return C 0 ? nothing : Val(C)
258+
return nothing
259+
end
256260

257261
__f̂(f, x, idxs) = dot(vec(f(x)), idxs)
258262

src/highlevel/finite_diff.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct FiniteDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobianC
77
end
88

99
function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
10-
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
10+
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1111
coloring_result = sd(fd, f, x)
1212
fx = fx === nothing ? similar(f(x)) : fx
1313
if coloring_result isa NoMatrixColoring
@@ -22,7 +22,7 @@ function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
2222
end
2323

2424
function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
25-
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
25+
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
2626
coloring_result = sd(fd, f!, fx, x)
2727
if coloring_result isa NoMatrixColoring
2828
cache = FiniteDiff.JacobianCache(x, fx)
@@ -36,13 +36,13 @@ function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
3636
end
3737

3838
function sparse_jacobian!(J::AbstractMatrix, fd, cache::FiniteDiffJacobianCache, f::F,
39-
x) where {F <: Function}
39+
x) where {F}
4040
f!(y, x) = (y .= f(x))
4141
return sparse_jacobian!(J, fd, cache, f!, cache.fx, x)
4242
end
4343

4444
function sparse_jacobian!(J::AbstractMatrix, _, cache::FiniteDiffJacobianCache, f!::F, _,
45-
x) where {F <: Function}
45+
x) where {F}
4646
FiniteDiff.finite_difference_jacobian!(J, f!, x, cache.cache)
4747
return J
4848
end

src/highlevel/forward_mode.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99
struct SparseDiffToolsTag end
1010

1111
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
12-
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
12+
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1313
coloring_result = sd(ad, f, x)
1414
fx = fx === nothing ? similar(f(x)) : fx
1515
if coloring_result isa NoMatrixColoring
@@ -25,7 +25,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
2525
end
2626

2727
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
28-
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
28+
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
2929
coloring_result = sd(ad, f!, fx, x)
3030
if coloring_result isa NoMatrixColoring
3131
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
@@ -40,7 +40,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
4040
end
4141

4242
function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, f::F,
43-
x) where {F <: Function}
43+
x) where {F}
4444
if cache.cache isa ForwardColorJacCache
4545
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
4646
else
@@ -50,7 +50,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
5050
end
5151

5252
function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, f!::F, fx,
53-
x) where {F <: Function}
53+
x) where {F}
5454
if cache.cache isa ForwardColorJacCache
5555
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
5656
else

src/highlevel/reverse_mode.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob
88
end
99

1010
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
11-
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
11+
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1212
fx = fx === nothing ? similar(f(x)) : fx
1313
coloring_result = sd(ad, f, x)
1414
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
@@ -17,7 +17,7 @@ function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
1717
end
1818

1919
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
20-
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
20+
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
2121
coloring_result = sd(ad, f!, fx, x)
2222
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
2323
return ReverseModeJacobianCache(coloring_result, nothing, jac_prototype, fx, x,
@@ -34,12 +34,12 @@ function sparse_jacobian!(J::AbstractMatrix, ad, cache::ReverseModeJacobianCache
3434
end
3535

3636
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
37-
cache::MatrixColoringResult, f::F, x) where {F <: Function}
37+
cache::MatrixColoringResult, f::F, x) where {F}
3838
return __sparse_jacobian_reverse_impl!(J, ad, idx_vec, cache, f, nothing, x)
3939
end
4040

4141
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
42-
cache::MatrixColoringResult, f::F, fx, x) where {F <: Function}
42+
cache::MatrixColoringResult, f::F, fx, x) where {F}
4343
# If `fx` is `nothing` then assume `f` is not in-place
4444
x_ = __maybe_copy_x(ad, x)
4545
fx_ = __maybe_copy_x(ad, fx)

test/test_sparse_jacobian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
4141
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in (AutoSparseZygote(),
4242
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
4343
AutoSparseForwardDiff(; chunksize = 0), AutoForwardDiff(; chunksize = 0),
44-
AutoSparseForwardDiff(; chunksize = 8), AutoForwardDiff(; chunksize = 8),
44+
AutoSparseForwardDiff(; chunksize = 4), AutoForwardDiff(; chunksize = 4),
4545
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
4646
@testset "Cache & Reuse" begin
4747
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
@@ -95,8 +95,8 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
9595

9696
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (AutoSparseForwardDiff(),
9797
AutoForwardDiff(), AutoSparseForwardDiff(; chunksize = 0),
98-
AutoForwardDiff(; chunksize = 0), AutoSparseForwardDiff(; chunksize = 8),
99-
AutoForwardDiff(; chunksize = 8), AutoSparseFiniteDiff(), AutoFiniteDiff(),
98+
AutoForwardDiff(; chunksize = 0), AutoSparseForwardDiff(; chunksize = 4),
99+
AutoForwardDiff(; chunksize = 4), AutoSparseFiniteDiff(), AutoFiniteDiff(),
100100
AutoEnzyme(), AutoSparseEnzyme())
101101
y = similar(x)
102102
cache = sparse_jacobian_cache(difftype, sd, fdiff, y, x)

0 commit comments

Comments
 (0)