Skip to content

Commit ebff614

Browse files
committed
Make suggested changes to tests and field_name_dict.jl
1 parent 3f81735 commit ebff614

File tree

4 files changed

+174
-176
lines changed

4 files changed

+174
-176
lines changed

src/MatrixFields/field_name_dict.jl

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -182,23 +182,22 @@ function get_internal_entry(
182182
elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1])
183183
# multiplication case 2 or 4, second argument
184184
target_field_eltype = broadcasted_get_field_type(T, name_pair[1])
185-
if target_field_eltype <: Number
185+
if target_field_eltype == eltype(parent(entry))
186186
T_band = eltype(entry)
187187
singleton_datalayout =
188188
DataLayouts.singleton(Fields.field_values(entry))
189189
# BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
190-
scalar_band_type = BandMatrixRow{
191-
T_band.parameters[1],
192-
T_band.parameters[2],
193-
eltype(parent(entry)),
194-
}
190+
scalar_band_type = band_matrix_row_type(
191+
outer_diagonals(T_band)...,
192+
target_field_eltype,
193+
)
195194
field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
196195
scalar_field_offset = get_field_first_index_offset(
197196
name_pair[1],
198197
target_field_eltype,
199198
T,
200199
)
201-
band_element_size = Int(div(sizeof(T), sizeof(target_field_eltype)))
200+
band_element_size = div(sizeof(T), sizeof(target_field_eltype))
202201
parent_indices = DataLayouts.to_data_specific_field(
203202
singleton_datalayout,
204203
(
@@ -295,7 +294,7 @@ function get_field_first_index_offset(
295294
child_type = fieldtype(S, child_name)
296295
remaining_field_chain = drop_first(name)
297296
field_index =
298-
unrolled_filter(i -> fieldname(S, i) == child_name, 1:fieldcount(S))[1]
297+
UnrolledUtilities.unrolled_findfirst(isequal(child_name), fieldnames(S))
299298
return DataLayouts.fieldtypeoffset(T, S, field_index) +
300299
get_field_first_index_offset(remaining_field_chain, T, child_type)
301300
end
@@ -314,14 +313,18 @@ entries in the `FieldMatrix` `dict`.
314313
"""
315314
function get_scalar_keys(dict::FieldMatrix)
316315
keys_tuple = unrolled_flatmap(keys(dict).values) do key
317-
_, entry = unrolled_filter(pair -> key == pair[1], pairs(dict))[1]
316+
entry = values(dict)[UnrolledUtilities.unrolled_findfirst(
317+
isequal(key),
318+
keys(dict).values,
319+
)]
318320
entry =
319321
entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry
320322
unrolled_map(
321-
filtered_child_names(
322-
field -> eltype(field) <: Number,
323+
filtered_names(
324+
field ->
325+
(field isa UniformScaling) ||
326+
eltype(field) == eltype(parent(field)),
323327
entry,
324-
@name()
325328
),
326329
) do name
327330
(append_internal_name(key[1], name), key[2])
@@ -339,24 +342,27 @@ scalar components of entries of `field_matrix`.
339342
340343
# Example usage
341344
```julia
342-
struct foo{T1, T2}
343-
a::T
344-
b::T2
345-
end
346-
mat1 = fill(DiagonalMatrixRow(ClimaCore.Geometry.Covariant12Vector(1.0, 2.0)), space)
347-
mat2 = fill(DiagonalMatrixRow(foo(foo(1.0, 2.0), 3.0)), space)
345+
e¹² = Geometry.Covariant12Vector(1.6, 0.7)
346+
e₃ = Geometry.Contravariant3Vector(1.0)
347+
ᶜᶜmat3 = fill(TridiagonalMatrixRow(2.0, 3.2, 2.1), center_space)
348+
ᶜᶠmat2 = fill(BidiagonalMatrixRow(4.3, 1.7), center_space)
349+
ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,)
350+
ρχ_unit = (;ρq_liq = 1.0, ρq_ice = 1.0)
351+
ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2)
352+
353+
348354
A = MatrixFields.FieldMatrix(
349-
(@name(biz), @name(baz)) => mat1,
350-
(@name(bip), @name(bop)) => mat2,
355+
(@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃,
356+
(@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar,
351357
)
358+
352359
A_scalar = MatrixFields.scalar_fieldmatrix(A)
353360
keys(A_scalar)
354361
# Output:
355-
# (@name(biz.components.data.:(1)), @name(baz))
356-
# (@name(biz.components.data.:(2)), @name(baz))
357-
# (@name(bip.a.a), @name(bop))
358-
# (@name(bip.a.b), @name(bop))
359-
# (@name(bip.b), @name(bop))
362+
# (@name(c.ρχ.ρq_liq.parent.components.data.:(1)), @name(f.u₃))
363+
# (@name(c.ρχ.ρq_ice.parent.components.data.:(1)), @name(f.u₃))
364+
# (@name(c.uₕ.components.data.:(1)), @name(c.sgsʲs.:(1).ρa))
365+
# (@name(c.uₕ.components.data.:(2)), @name(c.sgsʲs.:(1).ρa))
360366
```
361367
"""
362368
function scalar_fieldmatrix(field_matrix::FieldMatrix)

test/MatrixFields/field_matrix_solvers.jl

Lines changed: 15 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -376,15 +376,10 @@ end
376376
center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit)
377377
center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit)
378378

379-
ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,)
380379
ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,)
381380
ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',)
382-
ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',)
383381
ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',)
384-
ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(, ρχ_unit)), ᶜᶜmat3)
385-
ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(, ρaχ_unit)), ᶜᶜmat3)
386382
ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(, ρχ_unit e₃')), ᶜᶠmat2)
387-
ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(, ρaχ_unit e₃')), ᶜᶠmat2)
388383
# We need to use Fix1 and Fix2 instead of defining anonymous functions in
389384
# order for the result of map to be inferrable.
390385

@@ -478,52 +473,21 @@ end
478473
n_iters = 6,
479474
),
480475
),
481-
A = MatrixFields.FieldMatrix(
482-
# GS-GS blocks:
483-
(@name(sfc), @name(sfc)) => I,
484-
(@name(c.ρ), @name(c.ρ)) => I,
485-
(@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3,
486-
(@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3,
487-
(@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3,
488-
(@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3,
489-
(@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃,
490-
(@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃,
491-
(@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃,
492-
(@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃,
493-
(@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar,
494-
(@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar,
495-
(@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃,
496-
# GS-SGS blocks:
497-
(@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3,
498-
(@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3,
499-
(@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3,
500-
(@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3,
501-
(@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3,
502-
(@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3,
503-
(@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3,
504-
(@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3,
505-
(@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar,
506-
(@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar,
507-
(@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃,
508-
(@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃,
509-
(@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃,
510-
(@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃,
511-
(@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar,
512-
(@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃,
513-
# SGS-SGS blocks:
514-
(@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I,
515-
(@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I,
516-
(@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I,
517-
(@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) =>
518-
ᶜᶠmat2_scalar_u₃,
519-
(@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) =>
520-
ᶜᶠmat2_scalar_u₃,
521-
(@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃,
522-
(@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) =>
523-
ᶠᶜmat2_u₃_scalar,
524-
(@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) =>
525-
ᶠᶜmat2_u₃_scalar,
526-
(@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃,
476+
A = dycore_prognostic_EDMF_FieldMatrix(;
477+
ᶜᶜmat1,
478+
ᶜᶠmat2,
479+
ᶠᶜmat2,
480+
ᶜᶜmat3,
481+
ᶠᶠmat3,
482+
e¹²,
483+
e³,
484+
e₃,
485+
ρχ_unit,
486+
ρaχ_unit,
487+
ᶜᶠmat2_ρχ_u₃,
488+
ᶠᶠmat3_u₃_u₃,
489+
ᶜᶠmat2_scalar_u₃,
490+
ᶠᶜmat2_u₃_scalar,
527491
),
528492
b = b_moist_dycore_prognostic_edmf_prognostic_surface,
529493
)

test/MatrixFields/matrix_field_test_utils.jl

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ import ClimaCore:
2121
Operators,
2222
Quadratures
2323
using ClimaCore.MatrixFields
24+
import ClimaCore.Utilities: half
25+
import ClimaCore.RecursiveApply:
26+
import LinearAlgebra: I, norm, ldiv!, mul!
27+
import ClimaCore.MatrixFields: @name
2428

2529
# Test that an expression is true and that it is also type-stable.
2630
macro test_all(expression)
@@ -32,7 +36,7 @@ macro test_all(expression)
3236
end
3337
end
3438

35-
# Compute the minimum time (in seconds) required to run an expression after it
39+
# Compute the minimum time (in seconds) required to run an expression after it
3640
# has been compiled. This macro is used instead of @benchmark from
3741
# BenchmarkTools.jl because the latter is extremely slow (it appears to keep
3842
# triggering recompilations and allocating a lot of memory in the process).
@@ -134,6 +138,85 @@ function test_field_broadcast(;
134138
end
135139
end
136140

141+
# Create a field matrix for a similar solve to ClimaAtmos's moist dycore + prognostic,
142+
# EDMF + prognostic surface temperature with implicit acoustic waves and SGS fluxes
143+
function dycore_prognostic_EDMF_FieldMatrix(;
144+
ᶜᶜmat1,
145+
ᶜᶠmat2,
146+
ᶠᶜmat2,
147+
ᶜᶜmat3,
148+
ᶠᶠmat3,
149+
e¹²,
150+
e³,
151+
e₃,
152+
ρχ_unit,
153+
ρaχ_unit,
154+
ᶜᶠmat2_ρχ_u₃,
155+
ᶠᶠmat3_u₃_u₃,
156+
ᶜᶠmat2_scalar_u₃,
157+
ᶠᶜmat2_u₃_scalar,
158+
)
159+
160+
ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,)
161+
ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',)
162+
ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(, ρχ_unit)), ᶜᶜmat3)
163+
ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(, ρaχ_unit)), ᶜᶜmat3)
164+
ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(, ρaχ_unit e₃')), ᶜᶠmat2)
165+
return MatrixFields.FieldMatrix(
166+
# GS-GS blocks:
167+
(@name(sfc), @name(sfc)) => I,
168+
(@name(c.ρ), @name(c.ρ)) => I,
169+
(@name(c.ρe_tot), @name(c.ρe_tot)) => deepcopy(ᶜᶜmat3),
170+
(@name(c.ρatke), @name(c.ρatke)) => deepcopy(ᶜᶜmat3),
171+
(@name(c.ρχ), @name(c.ρχ)) => deepcopy(ᶜᶜmat3),
172+
(@name(c.uₕ), @name(c.uₕ)) => deepcopy(ᶜᶜmat3),
173+
(@name(c.ρ), @name(f.u₃)) => deepcopy(ᶜᶠmat2_scalar_u₃),
174+
(@name(c.ρe_tot), @name(f.u₃)) => deepcopy(ᶜᶠmat2_scalar_u₃),
175+
(@name(c.ρatke), @name(f.u₃)) => deepcopy(ᶜᶠmat2_scalar_u₃),
176+
(@name(c.ρχ), @name(f.u₃)) => deepcopy(ᶜᶠmat2_ρχ_u₃),
177+
(@name(f.u₃), @name(c.ρ)) => deepcopy(ᶠᶜmat2_u₃_scalar),
178+
(@name(f.u₃), @name(c.ρe_tot)) => deepcopy(ᶠᶜmat2_u₃_scalar),
179+
(@name(f.u₃), @name(f.u₃)) => deepcopy(ᶠᶠmat3_u₃_u₃),
180+
# GS-SGS blocks:
181+
(@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => deepcopy(ᶜᶜmat3),
182+
(@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) =>
183+
deepcopy(ᶜᶜmat3),
184+
(@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) =>
185+
deepcopy(ᶜᶜmat3),
186+
(@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) =>
187+
deepcopy(ᶜᶜmat3),
188+
(@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) =>
189+
deepcopy(ᶜᶜmat3),
190+
(@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) =>
191+
deepcopy(ᶜᶜmat3),
192+
(@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => deepcopy(ᶜᶜmat3),
193+
(@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => deepcopy(ᶜᶜmat3),
194+
(@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => deepcopy(ᶜᶜmat3_ρχ_scalar),
195+
(@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => deepcopy(ᶜᶜmat3_uₕ_scalar),
196+
(@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => deepcopy(ᶜᶠmat2_scalar_u₃),
197+
(@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => deepcopy(ᶜᶠmat2_scalar_u₃),
198+
(@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => deepcopy(ᶜᶠmat2_ρχ_u₃),
199+
(@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => deepcopy(ᶜᶠmat2_uₕ_u₃),
200+
(@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => deepcopy(ᶠᶜmat2_u₃_scalar),
201+
(@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => deepcopy(ᶠᶠmat3_u₃_u₃),
202+
# SGS-SGS blocks:
203+
(@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I,
204+
(@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I,
205+
(@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I,
206+
(@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) =>
207+
deepcopy(ᶜᶠmat2_scalar_u₃),
208+
(@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) =>
209+
deepcopy(ᶜᶠmat2_scalar_u₃),
210+
(@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) =>
211+
deepcopy(ᶜᶠmat2_ρaχ_u₃),
212+
(@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) =>
213+
deepcopy(ᶠᶜmat2_u₃_scalar),
214+
(@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) =>
215+
deepcopy(ᶠᶜmat2_u₃_scalar),
216+
(@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) =>
217+
deepcopy(ᶠᶠmat3_u₃_u₃),
218+
)
219+
end
137220
# Generate extruded finite difference spaces for testing. Include topography
138221
# when possible.
139222
function test_spaces(::Type{FT}) where {FT}

0 commit comments

Comments
 (0)