@@ -89,9 +89,10 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
89
89
device = ClimaComms. device (Y. c)
90
90
column_indices = column_index_iterator (Y)
91
91
scalar_names = scalar_field_names (Y)
92
- scalar_level_indices = scalar_level_index_pairs (Y)
93
- batch_size = max_simultaneous_derivatives (alg)
94
- batch_size_val = Val (batch_size)
92
+ jacobian_axis_index_to_field_vector_index_map =
93
+ enumerate (field_vector_index_iterator (Y))
94
+ n_εs = max_simultaneous_derivatives (alg)
95
+ n_εs_val = Val (n_εs)
95
96
96
97
p_dual_args = ntuple (Val (fieldcount (typeof (p)))) do cache_field_index
97
98
cache_field_name = fieldname (typeof (p), cache_field_index)
@@ -105,29 +106,30 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
105
106
end
106
107
p_dual = AtmosCache (p_dual_args... )
107
108
108
- batches = Iterators. partition (scalar_level_indices, batch_size)
109
- for batch_scalar_level_indices in ClimaComms. threadable (device, batches)
109
+ batches =
110
+ Iterators. partition (jacobian_axis_index_to_field_vector_index_map, n_εs)
111
+ for indices_for_Y_axis in ClimaComms. threadable (device, batches)
110
112
Y_dual .= Y
111
113
112
114
# Add a unique ε to Y for each scalar level index in this batch. With
113
115
# Y_col and Yᴰ_col denoting the columns of Y and Y_dual at column_index,
114
- # set Yᴰ_col to Y_col + I[:, batch_scalar_level_indices ] * εs, where I
115
- # is the identity matrix for Y_col (i.e., the value of ∂Y_col/∂Y_col),
116
- # εs is a vector of batch_size dual number components, and
117
- # batch_scalar_level_indices are the batch's indices into Y_col.
116
+ # set Yᴰ_col to Y_col + I[:, indices_for_Y_axis] * εs, where I is the
117
+ # identity matrix for Y_col (i.e., the value of ∂Y_col/∂Y_col), εs is a
118
+ # vector of n_εs dual number components, and indices_for_Y_axis are the
119
+ # batch's indices into Y_col.
118
120
ClimaComms. @threaded device begin
119
121
# On multithreaded devices, assign one thread to each combination of
120
122
# spatial column index and scalar level index in this batch.
121
123
for column_index in column_indices,
122
124
(ε_index, (_, (scalar_index, level_index))) in
123
- enumerate (batch_scalar_level_indices )
125
+ enumerate (indices_for_Y_axis )
124
126
125
- Y_partials = ntuple (i -> i == ε_index ? 1 : 0 , batch_size_val )
126
- Y_dual_increment = ForwardDiff. Dual {Jacobian} (0 , Y_partials... )
127
+ Y_partials = ntuple (== ( ε_index), n_εs_val )
128
+ Y_dual_εs_value = ForwardDiff. Dual {Jacobian} (0 , Y_partials)
127
129
unrolled_applyat (scalar_index, scalar_names) do name
128
130
field = MatrixFields. get_field (Y_dual, name)
129
131
@inbounds point (field, level_index, column_index... )[] +=
130
- Y_dual_increment
132
+ Y_dual_εs_value
131
133
end
132
134
end
133
135
end
@@ -141,19 +143,19 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
141
143
# with col_matrix denoting the matrix at the corresponding matrix_index
142
144
# in column_matrices, copy the coefficients of the εs in Yₜᴰ_col into
143
145
# col_matrix, where the previous steps have set Yₜᴰ_col to
144
- # Yₜ_col + (∂Yₜ_col/∂Y_col)[:, batch_scalar_level_indices ] * εs.
145
- # Specifically, set col_matrix[scalar_level_index1, scalar_level_index2]
146
- # to ∂Yₜ_col[scalar_level_index1]/∂Y_col[scalar_level_index2], obtaining
146
+ # Yₜ_col + (∂Yₜ_col/∂Y_col)[:, indices_for_Y_axis] * εs. Specifically, set
147
+ # col_matrix[scalar_level_index1, scalar_level_index2] to
148
+ # ∂Yₜ_col[scalar_level_index1]/∂Y_col[scalar_level_index2], obtaining
147
149
# this derivative from the coefficient of εs[ε_index] in
148
150
# Yₜᴰ_col[scalar_level_index1], where ε_index is the index of
149
- # scalar_level_index2 in batch_scalar_level_indices . After all batches
150
- # have been processed, col_matrix is the full Jacobian ∂Yₜ_col/∂Y_col.
151
+ # scalar_level_index2 in indices_for_Y_axis. After all batches have been
152
+ # processed, col_matrix is the full Jacobian ∂Yₜ_col/∂Y_col.
151
153
ClimaComms. @threaded device begin
152
154
# On multithreaded devices, assign one thread to each combination of
153
155
# spatial column index and scalar level index.
154
156
for (matrix_index, column_index) in enumerate (column_indices),
155
157
(scalar_level_index1, (scalar_index1, level_index1)) in
156
- scalar_level_indices
158
+ jacobian_axis_index_to_field_vector_index_map
157
159
158
160
Yₜ_dual_value =
159
161
unrolled_applyat (scalar_index1, scalar_names) do name
@@ -162,7 +164,7 @@ function update_column_matrices!(alg::AutoDenseJacobian, cache, Y, p, dtγ, t)
162
164
end
163
165
Yₜ_partials = ForwardDiff. partials (Yₜ_dual_value)
164
166
for (ε_index, (scalar_level_index2, _)) in
165
- enumerate (batch_scalar_level_indices )
167
+ enumerate (indices_for_Y_axis )
166
168
cartesian_index =
167
169
(scalar_level_index1, scalar_level_index2, matrix_index)
168
170
@inbounds column_matrices[cartesian_index... ] =
@@ -193,14 +195,15 @@ function invert_jacobian!(::AutoDenseJacobian, cache, ΔY, R)
193
195
device = ClimaComms. device (ΔY. c)
194
196
column_indices = column_index_iterator (ΔY)
195
197
scalar_names = scalar_field_names (ΔY)
196
- scalar_level_indices = scalar_level_index_pairs (ΔY)
198
+ vector_index_to_field_vector_index_map =
199
+ enumerate (field_vector_index_iterator (ΔY))
197
200
198
201
# Copy all scalar values from R into column_lu_vectors.
199
202
ClimaComms. @threaded device begin
200
203
# On multithreaded devices, assign one thread to each index into R.
201
204
for (vector_index, column_index) in enumerate (column_indices),
202
205
(scalar_level_index, (scalar_index, level_index)) in
203
- scalar_level_indices
206
+ vector_index_to_field_vector_index_map
204
207
205
208
value = unrolled_applyat (scalar_index, scalar_names) do name
206
209
field = MatrixFields. get_field (R, name)
@@ -219,7 +222,7 @@ function invert_jacobian!(::AutoDenseJacobian, cache, ΔY, R)
219
222
# On multithreaded devices, assign one thread to each index into ΔY.
220
223
for (vector_index, column_index) in enumerate (column_indices),
221
224
(scalar_level_index, (scalar_index, level_index)) in
222
- scalar_level_indices
225
+ vector_index_to_field_vector_index_map
223
226
224
227
@inbounds value =
225
228
column_lu_vectors[scalar_level_index, vector_index]
0 commit comments