Skip to content

Commit 1365d5e

Browse files
committed
Last fixes
1 parent 6c5811f commit 1365d5e

File tree

4 files changed

+26
-17
lines changed

4 files changed

+26
-17
lines changed

ext/SparseDiffToolsPolyesterForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ADTypes, SparseDiffTools, PolyesterForwardDiff, UnPack, Random, SparseArra
44
import ForwardDiff
55
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
66
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache,
7+
sparse_jacobian_cache_aux,
78
sparse_jacobian!,
89
sparse_jacobian_static_array, __standard_tag, __chunksize,
910
polyesterforwarddiff_color_jacobian

src/highlevel/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ const __init_𝒥 = init_jacobian
273273

274274
# Misc Functions
275275
function __chunksize(
276-
::Union{AutoSparse{<:AutoForwardDiff}{C}, AutoForwardDiff{C},
277-
AutoSparse{<:AutoPolyesterForwardDiff}{C}, AutoPolyesterForwardDiff{C}},
276+
::Union{AutoSparse{<:AutoForwardDiff{C}}, AutoForwardDiff{C},
277+
AutoSparse{<:AutoPolyesterForwardDiff{C}}, AutoPolyesterForwardDiff{C}},
278278
x) where {C}
279279
C isa ForwardDiff.Chunk && return C
280280
return __chunksize(Val(C), x)

src/highlevel/forward_or_reverse_mode.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
22
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
3-
if ad isa AutoEnzyme
3+
if ad isa Union{AutoEnzyme, AutoSparse{<:AutoEnzyme}}
44
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f, x; fx)
5-
elseif ad isa AutoDiffractor
5+
elseif ad isa Union{AutoDiffractor, AutoSparse{<:AutoDiffractor}}
66
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f, x; fx)
77
else
88
error("Unknown mixed mode AD")
@@ -11,9 +11,9 @@ end
1111

1212
function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
1313
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
14-
if ad isa AutoEnzyme
14+
if ad isa Union{AutoEnzyme, AutoSparse{<:AutoEnzyme}}
1515
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f!, fx, x)
16-
elseif ad isa AutoDiffractor
16+
elseif ad isa Union{AutoDiffractor, AutoSparse{<:AutoDiffractor}}
1717
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f!, fx, x)
1818
else
1919
error("Unknown mixed mode AD")

test/test_sparse_jacobian.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ using ADTypes, SparseDiffTools,
55
StaticArrays
66
using ADTypes: dense_ad
77

8+
function nice_string(ad::AbstractADType)
9+
if ad isa AutoSparse
10+
return "AutoSparse($(nice_string(dense_ad(ad))))"
11+
else
12+
return nameof(typeof(ad))
13+
end
14+
end
15+
816
function __chunksize(::Union{
917
AutoSparse{<:AutoForwardDiff{C}}, AutoForwardDiff{C},
1018
AutoSparse{<:AutoPolyesterForwardDiff{C}}, AutoPolyesterForwardDiff{C}
@@ -70,7 +78,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
7078
AutoPolyesterForwardDiff(; chunksize = 4)
7179
]
7280

73-
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in DIFFTYPES
81+
@testset "sparse_jacobian $(nice_string(difftype)): Out of Place" for difftype in DIFFTYPES
7482
@testset "Cache & Reuse" begin
7583
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
7684
J = init_jacobian(cache)
@@ -81,7 +89,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
8189
@inferred sparse_jacobian!(J, difftype, cache, fdiff, x)
8290

8391
t₁ = @elapsed sparse_jacobian!(J, difftype, cache, fdiff, x)
84-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian!` (only differentiation) time: $(t₁)s"
92+
@info "$(nice_string(difftype))() `sparse_jacobian!` (only differentiation) time: $(t₁)s"
8593

8694
J = sparse_jacobian(difftype, cache, fdiff, x)
8795

@@ -92,7 +100,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
92100
end
93101

94102
t₂ = @elapsed sparse_jacobian(difftype, cache, fdiff, x)
95-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (with matrix allocation) time: $(t₂)s"
103+
@info "$(nice_string(difftype))() `sparse_jacobian` (with matrix allocation) time: $(t₂)s"
96104
end
97105

98106
@testset "Single Use" begin
@@ -104,7 +112,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
104112
end
105113

106114
t₁ = @elapsed sparse_jacobian(difftype, sd, fdiff, x)
107-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (complete) time: $(t₁)s"
115+
@info "$(nice_string(difftype))() `sparse_jacobian` (complete) time: $(t₁)s"
108116

109117
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
110118
J = init_jacobian(cache)
@@ -115,13 +123,13 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
115123
@inferred sparse_jacobian!(J, difftype, sd, fdiff, x)
116124

117125
t₂ = @elapsed sparse_jacobian!(J, difftype, sd, fdiff, x)
118-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian!` (with matrix coloring) time: $(t₂)s"
126+
@info "$(nice_string(difftype))() `sparse_jacobian!` (with matrix coloring) time: $(t₂)s"
119127
end
120128
end
121129

122130
@info "Inplace Place Function"
123131

124-
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (
132+
@testset "sparse_jacobian $(nice_string(difftype)): In place" for difftype in (
125133
AutoSparse(AutoForwardDiff()), AutoForwardDiff(),
126134
AutoSparse(AutoForwardDiff(; chunksize = 0)), AutoForwardDiff(; chunksize = 0),
127135
AutoSparse(AutoForwardDiff(; chunksize = 4)), AutoForwardDiff(; chunksize = 4),
@@ -138,7 +146,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
138146
@inferred sparse_jacobian!(J, difftype, cache, fdiff, y, x)
139147

140148
t₁ = @elapsed sparse_jacobian!(J, difftype, cache, fdiff, y, x)
141-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian!` (only differentiation) time: $(t₁)s"
149+
@info "$(nice_string(difftype))() `sparse_jacobian!` (only differentiation) time: $(t₁)s"
142150

143151
J = sparse_jacobian(difftype, cache, fdiff, y, x)
144152

@@ -148,7 +156,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
148156
end
149157

150158
t₂ = @elapsed sparse_jacobian(difftype, cache, fdiff, y, x)
151-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (with jacobian allocation) time: $(t₂)s"
159+
@info "$(nice_string(difftype))() `sparse_jacobian` (with jacobian allocation) time: $(t₂)s"
152160
end
153161

154162
@testset "Single Use" begin
@@ -160,7 +168,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
160168
end
161169

162170
t₁ = @elapsed sparse_jacobian(difftype, sd, fdiff, y, x)
163-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (complete) time: $(t₁)s"
171+
@info "$(nice_string(difftype))() `sparse_jacobian` (complete) time: $(t₁)s"
164172

165173
J = init_jacobian(cache)
166174

@@ -170,11 +178,11 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
170178
@inferred sparse_jacobian!(J, difftype, sd, fdiff, y, x)
171179

172180
t₂ = @elapsed sparse_jacobian!(J, difftype, sd, fdiff, y, x)
173-
@info "$(nameof(typeof(difftype)))() `sparse_jacobian!` (with matrix coloring) time: $(t₂)s"
181+
@info "$(nice_string(difftype))() `sparse_jacobian!` (with matrix coloring) time: $(t₂)s"
174182
end
175183
end
176184

177-
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (
185+
@testset "sparse_jacobian $(nice_string(difftype)): In place" for difftype in (
178186
AutoSparse(AutoZygote()),
179187
AutoZygote())
180188
y = similar(x)

0 commit comments

Comments
 (0)