Skip to content

Commit da37d43

Browse files
committed
Add AutoSparseJacobian algorithm for implicit solver
1 parent fc1bacd commit da37d43

File tree

10 files changed

+471
-58
lines changed

10 files changed

+471
-58
lines changed

.buildkite/Manifest-v1.11.toml

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ version = "0.5.17"
363363
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
364364

365365
[[deps.ClimaAtmos]]
366-
deps = ["Adapt", "ArgParse", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "ForwardDiff", "Insolation", "Interpolations", "LazyArtifacts", "LazyBroadcast", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "NullBroadcasts", "RRTMGP", "Random", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"]
366+
deps = ["Adapt", "ArgParse", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "ForwardDiff", "Insolation", "Interpolations", "LazyArtifacts", "LazyBroadcast", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "NullBroadcasts", "RRTMGP", "Random", "SciMLBase", "SparseMatrixColorings", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"]
367367
path = ".."
368368
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
369369
version = "0.30.3"
@@ -381,9 +381,11 @@ weakdeps = ["CUDA", "MPI"]
381381

382382
[[deps.ClimaCore]]
383383
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LazyBroadcast", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"]
384-
git-tree-sha1 = "c6ab151ea66f3756566abc039c76ae767e490446"
384+
git-tree-sha1 = "371c3801d16438113f22e3db3d6da69cee5b3a0b"
385+
repo-rev = "tr/refactor-fm-internal-index"
386+
repo-url = "https://github.com/CliMA/ClimaCore.jl.git"
385387
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
386-
version = "0.14.34"
388+
version = "0.14.33"
387389
weakdeps = ["CUDA", "Krylov"]
388390

389391
[deps.ClimaCore.extensions]
@@ -503,11 +505,6 @@ git-tree-sha1 = "37ea44092930b1811e666c3bc38065d7d87fcc74"
503505
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
504506
version = "0.13.1"
505507

506-
[[deps.Combinatorics]]
507-
git-tree-sha1 = "8010b6bb3388abe68d95743dcbea77650bb2eddf"
508-
uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
509-
version = "1.0.3"
510-
511508
[[deps.CommonDataModel]]
512509
deps = ["CFTime", "DataStructures", "Dates", "Preferences", "Printf", "Statistics"]
513510
git-tree-sha1 = "358bf5a7d5c1387b995a43577673290c5d344758"
@@ -1027,12 +1024,6 @@ git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2"
10271024
uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe"
10281025
version = "1.0.2"
10291026

1030-
[[deps.HCubature]]
1031-
deps = ["Combinatorics", "DataStructures", "LinearAlgebra", "QuadGK", "StaticArrays"]
1032-
git-tree-sha1 = "19ef9f0cb324eed957b7fe7257ac84e8ed8a48ec"
1033-
uuid = "19dc6840-f33b-545b-b366-655c7e3ffd49"
1034-
version = "1.7.0"
1035-
10361027
[[deps.HDF5]]
10371028
deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"]
10381029
git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c"
@@ -2306,6 +2297,20 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
23062297
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
23072298
version = "1.11.0"
23082299

2300+
[[deps.SparseMatrixColorings]]
2301+
deps = ["ADTypes", "DocStringExtensions", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays"]
2302+
git-tree-sha1 = "ab958b4fec46d1f1d057bb8e2a99bfdb90744646"
2303+
uuid = "0a514795-09f3-496d-8182-132a7b665d35"
2304+
version = "0.4.20"
2305+
2306+
[deps.SparseMatrixColorings.extensions]
2307+
SparseMatrixColoringsCliqueTreesExt = "CliqueTrees"
2308+
SparseMatrixColoringsColorsExt = "Colors"
2309+
2310+
[deps.SparseMatrixColorings.weakdeps]
2311+
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
2312+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
2313+
23092314
[[deps.SpecialFunctions]]
23102315
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
23112316
git-tree-sha1 = "41852b8679f78c8d8961eeadc8f62cef861a52e3"

.buildkite/ci_driver.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,63 @@ include(joinpath(pkgdir(CA), "post_processing", "ci_plots.jl"))
4545
ref_job_id = config.parsed_args["reference_job_id"]
4646
reference_job_id = isnothing(ref_job_id) ? simulation.job_id : ref_job_id
4747

48+
if true # TODO: Add a debug_manual_jacobian flag.
49+
Y_end = integrator.u
50+
t_end = integrator.t
51+
dt = integrator.dt
52+
timestepper_algorithm = integrator.alg
53+
tableau_coefficients =
54+
timestepper_algorithm isa CA.CTS.RosenbrockAlgorithm ?
55+
timestepper_algorithm.tableau.Γ : timestepper_algorithm.tableau.a_imp
56+
γs = unique(filter(!iszero, CA.LinearAlgebra.diag(tableau_coefficients)))
57+
dtγ = float(dt) * γs[end]
58+
59+
auto_jac_alg = integrator.cache.newtons_method_cache.j.alg
60+
manual_jac_alg = auto_jac_alg.sparse_alg
61+
auto_jac = CA.Jacobian(auto_jac_alg, Y_end, atmos)
62+
manual_jac = CA.Jacobian(manual_jac_alg, Y_end, atmos)
63+
CA.update_jacobian!(auto_jac, Y_end, p, dtγ, t_end)
64+
CA.update_jacobian!(manual_jac, Y_end, p, dtγ, t_end)
65+
auto_matrix = auto_jac.cache.matrix.matrix
66+
manual_matrix = manual_jac.cache.matrix.matrix
67+
auto_scalar_matrix = CA.MatrixFields.scalar_field_matrix(auto_matrix)
68+
manual_scalar_matrix = CA.MatrixFields.scalar_field_matrix(manual_matrix)
69+
difference_scalar_matrix = auto_scalar_matrix .- manual_scalar_matrix
70+
71+
@info "Debugging manual Jacobian"
72+
for scalar_block_name in keys(auto_scalar_matrix)
73+
auto_block = auto_scalar_matrix[scalar_block_name]
74+
manual_block = manual_scalar_matrix[scalar_block_name]
75+
difference_block = difference_scalar_matrix[scalar_block_name]
76+
77+
auto_block isa CA.Fields.Field || continue
78+
79+
println("$scalar_block_name:")
80+
println("\t$(eltype(auto_block))")
81+
82+
(_, _, auto_lower_band, auto_upper_band) =
83+
CA.MatrixFields.band_matrix_info(auto_block)
84+
(_, _, manual_lower_band, manual_upper_band) =
85+
CA.MatrixFields.band_matrix_info(manual_block)
86+
for band in auto_lower_band:auto_upper_band
87+
auto_band_index = band - auto_lower_band + 1
88+
auto_band_average =
89+
mean(abs, auto_block.entries.:($auto_band_index))
90+
is_manual = band in manual_lower_band:manual_upper_band
91+
println("\tAverage of$(is_manual ? "" : " padding") band $band:")
92+
println("\t\t$auto_band_average (auto)")
93+
is_manual || continue
94+
manual_band_index = band - manual_lower_band + 1
95+
manual_band_average =
96+
mean(abs, manual_block.entries.:($manual_band_index))
97+
difference_band_average =
98+
mean(abs, difference_block.entries.:($auto_band_index))
99+
println("\t\t$manual_band_average (manual)")
100+
println("\t\t$difference_band_average (difference)")
101+
end
102+
end
103+
end
104+
48105
if sol_res.ret_code == :simulation_crashed
49106
error(
50107
"The ClimaAtmos simulation has crashed. See the stack trace for details.",

.buildkite/pipeline.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ steps:
2828

2929
- echo "--- Instantiate .buildkite"
3030
- "julia --project=.buildkite -e 'using Pkg; Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true); using CUDA; CUDA.precompile_runtime(); Pkg.status()'"
31+
- "julia --project=.buildkite -e 'using Pkg; Pkg.add(Pkg.PackageSpec(;name=\"ClimaCore\", rev=\"tr/refactor-fm-internal-index\"))'"
3132

3233
agents:
3334
slurm_cpus_per_task: 8
@@ -631,7 +632,7 @@ steps:
631632
--job_id amip_target_edonly_nonequil
632633
artifact_paths: "amip_target_edonly_nonequil/output_active/*"
633634
agents:
634-
slurm_mem: 20GB
635+
slurm_mem: 64GB
635636

636637
- group: "Diagnostic EDMFX"
637638
steps:

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ NullBroadcasts = "0d71be07-595a-4f89-9529-4065a4ab43a6"
2929
RRTMGP = "a01a1ee8-cea4-48fc-987c-fc7878d79da1"
3030
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3131
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
32+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
3233
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3334
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3435
SurfaceFluxes = "49b00bb7-8bd4-4f2b-b78c-51cd0450215f"
@@ -42,7 +43,7 @@ ArgParse = "1"
4243
Artifacts = "1"
4344
AtmosphericProfilesLibrary = "0.1.7"
4445
ClimaComms = "0.6.8"
45-
ClimaCore = "0.14.34"
46+
ClimaCore = "0.14.33"
4647
ClimaDiagnostics = "0.2.12"
4748
ClimaParams = "0.10.32"
4849
ClimaTimeSteppers = "0.8.2"
@@ -62,6 +63,7 @@ NullBroadcasts = "0.1"
6263
RRTMGP = "0.21.2"
6364
Random = "1"
6465
SciMLBase = "2.12"
66+
SparseMatrixColorings = "0.4.20"
6567
StaticArrays = "1.7"
6668
Statistics = "1"
6769
SurfaceFluxes = "0.11, 0.12"

src/ClimaAtmos.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ include(
6161
joinpath("prognostic_equations", "implicit", "manual_sparse_jacobian.jl"),
6262
)
6363
include(joinpath("prognostic_equations", "implicit", "auto_dense_jacobian.jl"))
64+
include(joinpath("prognostic_equations", "implicit", "auto_sparse_jacobian.jl"))
6465
include(joinpath("prognostic_equations", "implicit", "autodiff_utils.jl"))
6566

6667
include(joinpath("prognostic_equations", "water_advection.jl"))

src/prognostic_equations/implicit/auto_dense_jacobian.jl

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
8989
device = ClimaComms.device(Y.c)
9090
column_indices = column_index_iterator(Y)
9191
scalar_names = scalar_field_names(Y)
92-
scalar_level_indices = scalar_level_index_pairs(Y)
93-
batch_size = max_simultaneous_derivatives(alg)
94-
batch_size_val = Val(batch_size)
92+
jacobian_axis_index_to_field_vector_index_map =
93+
enumerate(field_vector_index_iterator(Y))
94+
n_εs = max_simultaneous_derivatives(alg)
95+
n_εs_val = Val(n_εs)
9596

9697
p_dual_args = ntuple(Val(fieldcount(typeof(p)))) do cache_field_index
9798
cache_field_name = fieldname(typeof(p), cache_field_index)
@@ -105,29 +106,30 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
105106
end
106107
p_dual = AtmosCache(p_dual_args...)
107108

108-
batches = Iterators.partition(scalar_level_indices, batch_size)
109-
for batch_scalar_level_indices in ClimaComms.threadable(device, batches)
109+
batches =
110+
Iterators.partition(jacobian_axis_index_to_field_vector_index_map, n_εs)
111+
for indices_for_Y_axis in ClimaComms.threadable(device, batches)
110112
Y_dual .= Y
111113

112114
# Add a unique ε to Y for each scalar level index in this batch. With
113115
# Y_col and Yᴰ_col denoting the columns of Y and Y_dual at column_index,
114-
# set Yᴰ_col to Y_col + I[:, batch_scalar_level_indices] * εs, where I
115-
# is the identity matrix for Y_col (i.e., the value of ∂Y_col/∂Y_col),
116-
# εs is a vector of batch_size dual number components, and
117-
# batch_scalar_level_indices are the batch's indices into Y_col.
116+
# set Yᴰ_col to Y_col + I[:, indices_for_Y_axis] * εs, where I is the
117+
# identity matrix for Y_col (i.e., the value of ∂Y_col/∂Y_col), εs is a
118+
# vector of n_εs dual number components, and indices_for_Y_axis are the
119+
# batch's indices into Y_col.
118120
ClimaComms.@threaded device begin
119121
# On multithreaded devices, assign one thread to each combination of
120122
# spatial column index and scalar level index in this batch.
121123
for column_index in column_indices,
122124
(ε_index, (_, (scalar_index, level_index))) in
123-
enumerate(batch_scalar_level_indices)
125+
enumerate(indices_for_Y_axis)
124126

125-
Y_partials = ntuple(i -> i == ε_index ? 1 : 0, batch_size_val)
126-
Y_dual_increment = ForwardDiff.Dual{Jacobian}(0, Y_partials...)
127+
Y_partials = ntuple(==(ε_index), n_εs_val)
128+
Y_dual_εs_value = ForwardDiff.Dual{Jacobian}(0, Y_partials)
127129
unrolled_applyat(scalar_index, scalar_names) do name
128130
field = MatrixFields.get_field(Y_dual, name)
129131
@inbounds point(field, level_index, column_index...)[] +=
130-
Y_dual_increment
132+
Y_dual_εs_value
131133
end
132134
end
133135
end
@@ -141,19 +143,19 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
141143
# with col_matrix denoting the matrix at the corresponding matrix_index
142144
# in column_matrices, copy the coefficients of the εs in Yₜᴰ_col into
143145
# col_matrix, where the previous steps have set Yₜᴰ_col to
144-
# Yₜ_col + (∂Yₜ_col/∂Y_col)[:, batch_scalar_level_indices] * εs.
145-
# Specifically, set col_matrix[scalar_level_index1, scalar_level_index2]
146-
# to ∂Yₜ_col[scalar_level_index1]/∂Y_col[scalar_level_index2], obtaining
146+
# Yₜ_col + (∂Yₜ_col/∂Y_col)[:, indices_for_Y_axis] * εs. Specifically, set
147+
# col_matrix[scalar_level_index1, scalar_level_index2] to
148+
# ∂Yₜ_col[scalar_level_index1]/∂Y_col[scalar_level_index2], obtaining
147149
# this derivative from the coefficient of εs[ε_index] in
148150
# Yₜᴰ_col[scalar_level_index1], where ε_index is the index of
149-
# scalar_level_index2 in batch_scalar_level_indices. After all batches
150-
# have been processed, col_matrix is the full Jacobian ∂Yₜ_col/∂Y_col.
151+
# scalar_level_index2 in indices_for_Y_axis. After all batches have been
152+
# processed, col_matrix is the full Jacobian ∂Yₜ_col/∂Y_col.
151153
ClimaComms.@threaded device begin
152154
# On multithreaded devices, assign one thread to each combination of
153155
# spatial column index and scalar level index.
154156
for (matrix_index, column_index) in enumerate(column_indices),
155157
(scalar_level_index1, (scalar_index1, level_index1)) in
156-
scalar_level_indices
158+
jacobian_axis_index_to_field_vector_index_map
157159

158160
Yₜ_dual_value =
159161
unrolled_applyat(scalar_index1, scalar_names) do name
@@ -162,7 +164,7 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
162164
end
163165
Yₜ_partials = ForwardDiff.partials(Yₜ_dual_value)
164166
for (ε_index, (scalar_level_index2, _)) in
165-
enumerate(batch_scalar_level_indices)
167+
enumerate(indices_for_Y_axis)
166168
cartesian_index =
167169
(scalar_level_index1, scalar_level_index2, matrix_index)
168170
@inbounds column_matrices[cartesian_index...] =
@@ -193,14 +195,15 @@ function invert_jacobian!(::AutoDenseJacobian, cache, ΔY, R)
193195
device = ClimaComms.device(ΔY.c)
194196
column_indices = column_index_iterator(ΔY)
195197
scalar_names = scalar_field_names(ΔY)
196-
scalar_level_indices = scalar_level_index_pairs(ΔY)
198+
vector_index_to_field_vector_index_map =
199+
enumerate(field_vector_index_iterator(ΔY))
197200

198201
# Copy all scalar values from R into column_lu_vectors.
199202
ClimaComms.@threaded device begin
200203
# On multithreaded devices, assign one thread to each index into R.
201204
for (vector_index, column_index) in enumerate(column_indices),
202205
(scalar_level_index, (scalar_index, level_index)) in
203-
scalar_level_indices
206+
vector_index_to_field_vector_index_map
204207

205208
value = unrolled_applyat(scalar_index, scalar_names) do name
206209
field = MatrixFields.get_field(R, name)
@@ -219,7 +222,7 @@ function invert_jacobian!(::AutoDenseJacobian, cache, ΔY, R)
219222
# On multithreaded devices, assign one thread to each index into ΔY.
220223
for (vector_index, column_index) in enumerate(column_indices),
221224
(scalar_level_index, (scalar_index, level_index)) in
222-
scalar_level_indices
225+
vector_index_to_field_vector_index_map
223226

224227
@inbounds value =
225228
column_lu_vectors[scalar_level_index, vector_index]

0 commit comments

Comments
 (0)