From 13633ca555dd4435bfc5e70af7d678a09987862b Mon Sep 17 00:00:00 2001 From: imreddyTeja Date: Mon, 31 Mar 2025 14:44:11 -0700 Subject: [PATCH] Add scalar_fieldmatrix Add a function to convert a FieldMatrix where each matrix entry has an eltype of some struct into a FieldMatrix where each entry has an eltype of a scalar. Add additional tests for scalar_matrixfields Use @test_all in tests Make suggested changes to tests and field_name_dict.jl Revert unrolled_findfirst Clean up field matrix tests and add support for DiagonalMatrixRows CamelCase struct name Clean up tests and get_scalar_keys wip backup Minimal working with allocs WIP1 WIP more allocs fix Assorted cleanup Fix dx/dx case reduce code duplication; fix example Add gpu test further cleanup, extend diagonalrow fix names test and comments Add docs docs bugfix remvoe bad refs fix docs formatting --- .buildkite/pipeline.yml | 12 + docs/src/matrix_fields.md | 80 ++++ src/Geometry/axistensors.jl | 3 + src/MatrixFields/MatrixFields.jl | 1 + src/MatrixFields/field_name.jl | 18 + src/MatrixFields/field_name_dict.jl | 418 ++++++++++++++++++- test/MatrixFields/field_matrix_solvers.jl | 58 +-- test/MatrixFields/field_names.jl | 25 +- test/MatrixFields/matrix_field_test_utils.jl | 209 +++++++++- test/MatrixFields/scalar_fieldmatrix.jl | 157 +++++++ 10 files changed, 898 insertions(+), 83 deletions(-) create mode 100644 test/MatrixFields/scalar_fieldmatrix.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 93bdf9b49b..1c19c6d973 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -861,6 +861,18 @@ steps: agents: slurm_gpus: 1 + - label: "Unit: scalar_fieldmatrix (CPU)" + key: cpu_scalar_fieldmatrix + command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl" + + - label: "Unit: mscalar_fieldmatrix (GPU)" + key: gpu_scalar_fieldmatrix + command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl" + env: + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_gpus: 1 + - group: "Unit: MatrixFields - broadcasting (CPU)" steps: diff --git a/docs/src/matrix_fields.md b/docs/src/matrix_fields.md index 4c89aa765d..d1a48d48c3 100644 --- a/docs/src/matrix_fields.md +++ b/docs/src/matrix_fields.md @@ -89,6 +89,10 @@ preconditioner_cache check_preconditioner lazy_or_concrete_preconditioner apply_preconditioner +get_scalar_keys +get_field_first_index_offset +broadcasted_get_field_type +inner_type_ignore_adjoint ``` ## Utilities @@ -98,4 +102,80 @@ column_field2array column_field2array_view field2arrays field2arrays_view +scalar_fieldmatrix ``` + +## Indexing a FieldMatrix + +A FieldMatrix entry can be: + +- An `UniformScaling`, which contains a `Number` +- A `DiagonalMatrixRow`, which can contain aything +- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type. + +If an entry contains a composite type, the fields of that type can be extracted. +This is also true for nested composite types. + +For example: + +```@example 1 +using ClimaCore.CommonSpaces # hide +import ClimaCore: MatrixFields, Quadratures # hide +import ClimaCore.MatrixFields: @name # hide +space = Box3DSpace(; # hide + z_elem = 3, # hide + x_min = 0, # hide + x_max = 1, # hide + y_min = 0, # hide + y_max = 1, # hide + z_min = 0, # hide + z_max = 10, # hide + periodic_x = false, # hide + periodic_y = false, # hide + n_quad_points = 1, # hide + quad = Quadratures.GL{1}(), # hide + x_elem = 1, # hide + y_elem = 2, # hide + staggering = CellCenter() # hide + ) # hide +nt_entry_field = fill(MatrixFields.DiagonalMatrixRow((; foo = 1.0, bar = 2.0)), space) +nt_fieldmatrix = MatrixFields.FieldMatrix((@name(a), @name(b)) => nt_entry_field) +nt_fieldmatrix[(@name(a), @name(b))] +``` + +The internal values of the named tuples can be extracted with + +```@example 1 +nt_fieldmatrix[(@name(a.foo), @name(b))] +``` + +and + +```@example 1 +nt_fieldmatrix[(@name(a.bar), @name(b))] +``` + +If the key `(@name(name1), @name(name2))` corresponds to an entry, then +`(@name(foo.bar.buz), @name(biz.bop.fud))` would be the internal key for the key +`(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`. + +Currently, internal values cannot be extracted in all situations. Extracting interal values +works when: + +- The second name in the internal key is empty, and the first name in the internal key accesses internal values for the type of element contained in each row of the entry. This does not work when the element type of each row is a 2d tensor. + +- The first name in the internal key is empty, and the type of element contained in each row of the entry is an `AxisVector` or the adjoint of an `AxisVector`. In this case, the second name must access inernal values for the type of `AxisVector` contained in each row. + +- The element type of each row in the entry is a 2d tensor, and the internal key is of the form `(@name(components.data.:(1)), @name(components.data.:(2)))`, but possibly with different numbers to index into the 2d tensor + +- The element type of each row in the entry is some number of nested `Tuple`s and `NamedTuple`s, and the first name in the internal key accesses an `AxisVector` or the adjoint of an `AxisVector` from the outer `Tuple`/`NamedTuple`, and the second name in the inernal key accesses a component of the `AxisVector` + +If the `FieldMatrix` represents a Jacobian, then extracting internal values works when an entry represents: + +- The partial derrivative of an `AxisVector`, `Tuple`, or `NamedTuple` with respect to a scalar. + +- The partial derrivative of a scalar with respect to an `AxisVector`. + +- The partial derrivative of a `Tuple`, or `NamedTuple` with respect to an `AxisVector`. In this case, the first name of the internal key must index into the tuple and result in a scalar. + +- The partial derrivative of an `AxisVector` with respect to an `AxisVector`. In this case, the partial derrivative of a component of the first `AxisVector` with respect to a component of the second `AxisVector` can be extracted, but not an entire `AxisVector` with respect to a component, or a component with respect to an entire `AxisVector` diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index 534e628380..dd2eb22e51 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -308,6 +308,9 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} = const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}} +const AxisVectorOrAdj{T, A, S} = + Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}} + Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) = getindex(components(va), i) Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) = diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index e2acdc67c1..43853e0096 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv import ..RecursiveApply: ⊠, ⊞, ⊟ +import ..DataLayouts import ..DataLayouts: AbstractData import ..DataLayouts: vindex import ..Geometry diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index bd05f844c1..714772e5f2 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -50,6 +50,9 @@ extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain) drop_first(::FieldName{name_chain}) where {name_chain} = FieldName(Base.tail(name_chain)...) +extract_last(::FieldName{name_chain}) where {name_chain} = + name_chain[length(name_chain)] + has_field(x, ::FieldName{()}) = true has_field(x, name::FieldName) = extract_first(name) in propertynames(x) && @@ -59,6 +62,18 @@ get_field(x, ::FieldName{()}) = x get_field(x, name::FieldName) = get_field(getproperty(x, extract_first(name)), drop_first(name)) +""" + broadcasted_get_field_type(::Type{X}, name::FieldName) + +Returns the type of the field accessed by `name` in the type `X`. +""" +broadcasted_get_field_type(::Type{X}, ::FieldName{()}) where {X} = X +broadcasted_get_field_type(::Type{X}, name::FieldName) where {X} = + broadcasted_get_field_type( + fieldtype(X, extract_first(name)), + drop_first(name), + ) + broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true broadcasted_has_field(::Type{X}, name::FieldName) where {X} = extract_first(name) in fieldnames(X) && @@ -199,4 +214,7 @@ if hasfield(Method, :recursion_relation) for m in methods(get_subtree_at_name) m.recursion_relation = dont_limit end + for m in methods(broadcasted_get_field_type) + m.recursion_relation = dont_limit + end end diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 2d98bb3ff3..15f6482924 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -152,18 +152,97 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = ( get_internal_entry(entry, name::FieldName, key_error) = get_field(entry, name) get_internal_entry(entry, name_pair::FieldNamePair, key_error) = name_pair == (@name(), @name()) ? entry : throw(key_error) -get_internal_entry( - entry::ScalingFieldMatrixEntry, +get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair, key_error) = + if name_pair[1] == name_pair[2] + entry + elseif is_overlapping_name(name_pair[1], name_pair[2]) + throw(key_error) + else + zero(entry) + end +function get_internal_entry( + entry::DiagonalMatrixRow, name_pair::FieldNamePair, key_error, -) = - if name_pair[1] == name_pair[2] +) + T = eltype(entry) + if name_pair == (@name(), @name()) + if T <: Adjoint{Real, Real} # Adjoint of a real is itself + DiagonalMatrixRow(entry.entries.:(1).parent) + else + entry + end + elseif name_pair[1] == name_pair[2] && + !broadcasted_has_field(T, name_pair[1]) && + !is_child_name(name_pair[2], @name(components.data)) entry + elseif name_pair[2] == @name() && + broadcasted_has_field(eltype(entry), name_pair[1]) + # TODO: dvec/dvec + @assert !(T <: Geometry.Axis2TensorOrAdj) + DiagonalMatrixRow( + broadcasted_get_field(entry.entries.:(1), name_pair[1]), + ) + elseif name_pair[1] == @name() && + broadcasted_has_field(eltype(entry), name_pair[2]) && + @assert T <: Geometry.AxisVectorOrAdj + DiagonalMatrixRow( + broadcasted_get_field(entry.entries.:(1), name_pair[2]), + ) + # TODO: dtuple/dvec and dvec/dvec and generic entries + + T <: Adjoint ? append_internal_name(@name(parent), name_pair[2]) : + name_pair[2] + elseif is_child_value( + name_pair, + (@name(components.data), @name(components.data)), + ) && T <: Geometry.Axis2Tensor + row_index = extract_last(name_pair[1]) + col_index = extract_last(name_pair[2]) + (n_rows, n_cols) = map(length, axes(T)) + @assert row_index <= n_rows && col_index <= n_cols + + DiagonalMatrixRow( + broadcasted_get_field( + entry.entries.:(1), + append_internal_name( + @name(components.data), + MatrixFields.FieldName( + n_rows * (col_index - 1) + row_index, + ), + ), + ), + ) + elseif broadcasted_has_field(T, name_pair[1]) && + is_child_name(name_pair[2], @name(components.data)) + @assert T <: Union{NamedTuple, Tuple} + @assert broadcasted_get_field_type(T, name_pair[1]) <: + Geometry.AxisVectorOrAdj + if broadcasted_get_field_type(T, name_pair[1]) <: Adjoint + DiagonalMatrixRow( + broadcasted_get_field( + entry.entries.:(1), + append_internal_name( + name_pair[1], + append_internal_name(@name(parent), name_pair[2]), + ), + ), + ) + else + DiagonalMatrixRow( + broadcasted_get_field( + entry.entries.:(1), + append_internal_name(name_pair[1], name_pair[2]), + ), + ) + end + elseif is_overlapping_name(name_pair[1], name_pair[2]) throw(key_error) else zero(entry) end +end function get_internal_entry( entry::ColumnwiseBandMatrixField, name_pair::FieldNamePair, @@ -173,24 +252,108 @@ function get_internal_entry( # See note above matrix_product_keys in field_name_set.jl for more details. T = eltype(eltype(entry)) if name_pair == (@name(), @name()) - entry - elseif name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - @assert T <: Geometry.SingleValue && - !broadcasted_has_field(T, name_pair[1]) - entry - elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1]) - # multiplication case 2 or 4, second argument - Base.broadcasted(entry) do matrix_row - map(matrix_row) do matrix_row_entry - broadcasted_get_field(matrix_row_entry, name_pair[1]) - end - end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. + if T <: Adjoint{Real, Real} # Adjoint of a real is itself + entry.parent + else + entry + end + elseif name_pair[1] == name_pair[2] && + !broadcasted_has_field(T, name_pair[1]) && + !is_child_name(name_pair[2], @name(components.data)) # don't enter case if attempting to access a component + entry # multiplication case 3 or 4, first argument else - throw(key_error) + start_index = if name_pair[1] == @name() || name_pair[2] == @name() + target_chain = if name_pair[1] == @name() + # dscalar/dvec + # TODO: dtuple/dvec and dvec/dvec and generic entries + @assert T <: Geometry.AxisVectorOrAdj + T <: Adjoint ? + append_internal_name(@name(parent), name_pair[2]) : + name_pair[2] + else + # dvec/dscalar, dtuple/dscalar + # TODO: dtuple/dvec will return a broadcasted object, dvec/dvec incorrect + # can't index into a 2d tensor with only one name + @assert !(T <: Geometry.Axis2TensorOrAdj) + name_pair[1] + end + if broadcasted_has_field(T, target_chain) + target_field_eltype = + broadcasted_get_field_type(T, target_chain) + if target_field_eltype == eltype(parent(entry)) + 1 + get_field_first_index_offset( + target_chain, + target_field_eltype, + T, + ) + else + return Base.broadcasted(entry) do matrix_row + map(matrix_row) do matrix_row_entry + broadcasted_get_field(matrix_row_entry, target_chain) + end + end + end + else + throw(key_error) + end + elseif name_pair[2] != @name() && name_pair[1] != @name() + if is_child_value( + name_pair, + (@name(components.data), @name(components.data)), + ) + # dvec/dvec + @assert T <: Geometry.Axis2Tensor + row_index = extract_last(name_pair[1]) + col_index = extract_last(name_pair[2]) + (n_rows, n_cols) = map(length, axes(T)) + @assert row_index <= n_rows && col_index <= n_cols + n_rows * (col_index - 1) + row_index + elseif broadcasted_has_field(T, name_pair[1]) && + is_child_name(name_pair[2], @name(components.data)) + + # dNTuple/dvec + @assert T <: Union{NamedTuple, Tuple} && + broadcasted_has_field(T, name_pair[1]) + @assert broadcasted_get_field_type(T, name_pair[1]) <: + Geometry.AxisVectorOrAdj + get_field_first_index_offset( + name_pair[1], + broadcasted_get_field_type(T, name_pair[1]), + T, + ) + extract_last(name_pair[2]) + else + throw(key_error) + end + else + throw(key_error) + end + band_element_size = div(sizeof(T), sizeof(eltype(parent(entry)))) + T_band = eltype(entry) + singleton_datalayout = DataLayouts.singleton(Fields.field_values(entry)) + # BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype + scalar_band_type = band_matrix_row_type( + outer_diagonals(T_band)..., + eltype(parent(entry)), + ) + field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry)) + parent_indices = DataLayouts.to_data_specific_field( + singleton_datalayout, + (:, :, start_index:band_element_size:field_dim_size, :, :), + ) + + scalar_data = view(parent(entry), parent_indices...) + + values = DataLayouts.union_all(singleton_datalayout){ + scalar_band_type, + Base.tail(DataLayouts.type_params(Fields.field_values(entry)))..., + }( + scalar_data, + ) + Fields.Field(values, axes(entry)) end end + # Similar behavior to indexing an array with a slice. function Base.getindex(dict::FieldNameDict, new_keys::FieldNameSet) common_keys = intersect(keys(dict), new_keys) @@ -237,6 +400,221 @@ function Base.one(matrix::FieldMatrix) return FieldNameDict(inferred_diagonal_keys, entries) end +""" + get_field_first_index_offset(name::FieldName, ::Type{T}, ::Type{S}) + +Returns the offset of the the field with name `name` in an object of type `S` +in multiples of `sizeof(T)`. +""" +function get_field_first_index_offset( + name::FieldName, + ::Type{T}, + ::Type{S}, +) where {T, S} + if name == @name() + return 0 + end + child_name = extract_first(name) + child_type = fieldtype(S, child_name) + remaining_field_chain = drop_first(name) + field_index = + unrolled_filter(i -> fieldname(S, i) == child_name, 1:fieldcount(S))[1] + return DataLayouts.fieldtypeoffset(T, S, field_index) + + get_field_first_index_offset(remaining_field_chain, T, child_type) +end +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(get_field_first_index_offset) + m.recursion_relation = dont_limit + end +end + + + +""" + inner_type_ignore_adjoint(x) + +If x is a Field, return the type of the elements inside it, ignoring Adjoint wrappers. +Otherwise return the type of x, also ignoring Adjoint wrappers. +""" +inner_type_ignore_adjoint(x::Fields.Field) = + inner_type_ignore_adjoint(Fields.field_values(x)) +inner_type_ignore_adjoint(x::DataLayouts.AbstractData{<:Adjoint}) = + eltype(x.parent) +inner_type_ignore_adjoint(x::DataLayouts.AbstractData) = eltype(x) +inner_type_ignore_adjoint(x::Adjoint) = typeof(x.parent) +inner_type_ignore_adjoint(x) = typeof(x) + +append_component_name(name::FieldName, component::Int) = append_internal_name( + name, + append_internal_name( + @name(components.data), + MatrixFields.FieldName(component), + ), +) + +""" + get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector) + +Returns a `FieldMatrixKeys` object that contains the keys of all the scalar +entries in the `FieldMatrix` `dict`. +""" +function get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector) + keys_tuple = unrolled_flatmap(keys(dict).values) do key + target_eltype = eltype(Y) + entry = dict[unrolled_filter(isequal(key), keys(dict).values)[1]] + if entry isa UniformScaling # uniformscalings can only contain numbers + (key,) + elseif entry isa ColumnwiseBandMatrixField || + entry isa DiagonalMatrixRow + first_band = entry.entries.:(1) + # if entry isa DiagonalMatrixRow, then first_band is not a field + # adjoints can be ignored, assuming that the only adjoint a FieldMatrix will contain + # is the adjoint of a vector or a number, as the indices do not change, and they must contain Reals + T_first_band = inner_type_ignore_adjoint(first_band) + + if T_first_band <: target_eltype # The case when entry contains numbers or adjoint of numbers + (key,) + else + row_key = get_field(Y, key[1]) + col_key = get_field(Y, key[2]) + row_key_type = eltype(row_key) + col_key_type = eltype(col_key) + + row_key_type <: Geometry.SingleValue || + col_key_type <: Geometry.SingleValue || + error("cannot get scalar keys for key $key") + + if col_key_type <: Geometry.AxisTensor && # dvec/dvec + row_key_type <: Geometry.AxisTensor + T_first_band <: Geometry.Axis2Tensor{target_eltype} || + error( + "expected key $key to be a 2-tensor of $target_eltype and found a $T_first_band", + ) + axis_components = map(length, axes(T_first_band)) + unrolled_map( + MatrixFields.unrolled_product( + 1:axis_components[1], + 1:axis_components[2], + ), + ) do (component1, component2) + ( + append_component_name(key[1], component1), + append_component_name(key[2], component2), + ) + end + elseif ( + row_key_type <: Number && + col_key_type <: Geometry.AxisTensor + ) || ( + col_key_type <: Number && + row_key_type <: Geometry.AxisTensor + ) # dscalar/dvec or dvec/dscalar result in same types + T_first_band <: Geometry.AxisVector{target_eltype} || + error( + "expected key $key to be a vector of $target_eltype and found a $T_first_band", + ) + ncomponents = length(axes(T_first_band)[1]) + unrolled_map(1:ncomponents) do component + if row_key_type <: Geometry.AxisTensor # dvec/dscalar + (append_component_name(key[1], component), key[2]) + else # dscalar/dvec + (key[1], append_component_name(key[2], component)) + end + end + + elseif row_key_type <: Union{NamedTuple, Tuple} && + col_key_type <: Geometry.AxisVector #dtuple/dvec + unrolled_flatmap( + filtered_names( + x -> + inner_type_ignore_adjoint(x) <: + Geometry.SingleValue, + first_band, + ), + ) do dependent_name + T_inner = inner_type_ignore_adjoint( + get_field(first_band, dependent_name), + ) + T_inner <: Geometry.AxisVector{target_eltype} || + error( + "expected key $key to be a tuple of vectors and found a tuple of $T_inner", + ) + ncomponents1 = length(axes(T_inner)[1]) + unrolled_map(1:ncomponents1) do component + ( + append_internal_name(key[1], dependent_name), + append_component_name(key[2], component), + ) + end + end + elseif row_key_type <: Union{NamedTuple, Tuple} && + col_key_type <: Number # dtuple/dscalar + unrolled_map( + filtered_names( + x -> eltype(x) <: target_eltype, + row_key, + ), + ) do dependent_name + (append_internal_name(key[1], dependent_name), key[2]) + + end + end + end + else + error("Cannot get scalar keys for key $key") + end + + end + return FieldMatrixKeys(keys_tuple) +end + + +""" + scalar_fieldmatrix(field_matrix::FieldMatrix) + +Constructs a `FieldNameDict` where the keys and entries are views +of the entries of `field_matrix`, which corresponding to the +scalar components of entries of `field_matrix`. + +# Example usage +```julia +e¹² = Geometry.Covariant12Vector(1.6, 0.7) +e₃ = Geometry.Contravariant3Vector(1.0) +e³ = Geometry.Covariant3Vector(1) +ᶜᶜmat3 = fill(TridiagonalMatrixRow(2.0, 3.2, 2.1), center_space) +ᶜᶠmat2 = fill(BidiagonalMatrixRow(4.3, 1.7), center_space) +ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) +ρχ_unit = (;ρq_liq = 1.0, ρq_ice = 1.0) +ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) +ᶜvec = random_field(Float64, center_space) +ᶠvec = random_field(Float64, face_space) + b = Fields.FieldVector(; + c = ᶜvec .* ((; ρχ = ρχ_unit, uₕ = e¹², sgsʲs = ((; ρa = 1),)),), + f = ᶠvec .* ((; u₃ = e³),), +) +A = MatrixFields.FieldMatrix( + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, +) + +A_scalar = MatrixFields.scalar_fieldmatrix(A, b) +keys(A_scalar) +# Output: +# (@name(c.ρχ.ρq_liq), @name(f.u₃.components.data.:(1))) +# (@name(c.ρχ.ρq_ice), @name(f.u₃.components.data.:(1))) +# (@name(c.uₕ.components.data.:(1)), @name(c.sgsʲs.:(1).ρa)) +# (@name(c.uₕ.components.data.:(2)), @name(c.sgsʲs.:(1).ρa)) +``` +""" +function scalar_fieldmatrix(field_matrix::FieldMatrix, Y::Fields.FieldVector) + scalar_keys = get_scalar_keys(field_matrix, Y) + entries = unrolled_map(scalar_keys.values) do key + field_matrix[key] + end + return FieldNameDict(scalar_keys, entries) +end + replace_name_tree(dict::FieldNameDict, name_tree) = FieldNameDict(replace_name_tree(keys(dict), name_tree), values(dict)) @@ -546,8 +924,8 @@ function Base.Broadcast.broadcasted( ) product_value = scaling_value(entry1) * scaling_value(entry2) product_value isa Number ? - UniformScaling(product_value) : - DiagonalMatrixRow(product_value) + (UniformScaling(product_value),) : + (DiagonalMatrixRow(product_value),) elseif entry1 isa ScalingFieldMatrixEntry Base.Broadcast.broadcasted(*, (scaling_value(entry1),), entry2) elseif entry2 isa ScalingFieldMatrixEntry diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index 524d5b8b9e..4f65861fd5 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -376,15 +376,10 @@ end center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit) center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit) - ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,) ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',) - ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',) ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',) - ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit)), ᶜᶜmat3) - ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit)), ᶜᶜmat3) ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) - ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit ⊠ e₃')), ᶜᶠmat2) # We need to use Fix1 and Fix2 instead of defining anonymous functions in # order for the result of map to be inferrable. @@ -464,7 +459,10 @@ end ), b = b_moist_dycore_diagnostic_edmf, ) - + ( + A_moist_dycore_prognostic_edmf_prognostic_surface, + b_moist_dycore_prognostic_edmf_prognostic_surface, + ) = dycore_prognostic_EDMF_FieldMatrix(FT) test_field_matrix_solver(; test_name = "similar solve to ClimaAtmos's moist dycore + prognostic \ EDMF + prognostic surface temperature with implicit \ @@ -478,53 +476,7 @@ end n_iters = 6, ), ), - A = MatrixFields.FieldMatrix( - # GS-GS blocks: - (@name(sfc), @name(sfc)) => I, - (@name(c.ρ), @name(c.ρ)) => I, - (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, - (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, - (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, - (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, - (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, - (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, - # GS-SGS blocks: - (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, - (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, - (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, - (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, - (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, - (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - # SGS-SGS blocks: - (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, - (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, - (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, - (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => - ᶜᶠmat2_scalar_u₃, - (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => - ᶜᶠmat2_scalar_u₃, - (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, - (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => - ᶠᶜmat2_u₃_scalar, - (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => - ᶠᶜmat2_u₃_scalar, - (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - ), + A = A_moist_dycore_prognostic_edmf_prognostic_surface, b = b_moist_dycore_prognostic_edmf_prognostic_surface, ) end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 0253dd57bc..6b8f51790f 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -770,9 +770,9 @@ end (@name(a), @name(a)) => -I_CT3XC3, ) - for (vector, matrix, I_foo, I_a) in ( - (vector_of_scalars, matrix_of_scalars, I, I), - (vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3), + for (vector, matrix, I_foo, I_a, is_scalar_test) in ( + (vector_of_scalars, matrix_of_scalars, I, I, true), + (vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3, false), ) @test_all MatrixFields.field_vector_view(vector) == MatrixFields.FieldVectorView( @@ -834,18 +834,25 @@ end @test_throws KeyError matrix[@name(a), @name(a.c)] @test_throws KeyError matrix[@name(a.c), @name(a)] - @test_throws KeyError matrix[@name(foo), @name(foo._value)] - @test_throws KeyError matrix[@name(foo._value), @name(foo)] + @test_throws AssertionError matrix[@name(foo), @name(foo._value)] + if is_scalar_test + @test_throws KeyError matrix[@name(foo._value), @name(foo)] + else + @test_throws AssertionError matrix[@name(foo._value), @name(foo)] + end @test_all matrix[@name(a), @name(a)] == -I_a @test_all matrix[@name(a.c), @name(a.c)] == -I_a @test_all matrix[@name(a.c), @name(a.b)] == zero(I_a) @test_all matrix[@name(foo._value), @name(foo._value)] == matrix[@name(foo), @name(foo)] - - @test_all matrix[@name(foo._value), @name(a.b)] isa - Base.AbstractBroadcasted - @test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == map( + entry = matrix[@name(foo._value), @name(a.b)] + @test_all entry isa ( + is_scalar_test ? MatrixFields.ColumnwiseBandMatrixField : + Base.AbstractBroadcasted + ) + entry = is_scalar_test ? entry : Base.materialize(entry) + @test entry == map( row -> map(foo -> foo.value, row), matrix[@name(foo), @name(a.b)], ) diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl index 9e65be8830..acaa5a04b2 100644 --- a/test/MatrixFields/matrix_field_test_utils.jl +++ b/test/MatrixFields/matrix_field_test_utils.jl @@ -21,6 +21,10 @@ import ClimaCore: Operators, Quadratures using ClimaCore.MatrixFields +import ClimaCore.Utilities: half +import ClimaCore.RecursiveApply: ⊠ +import LinearAlgebra: I, norm, ldiv!, mul! +import ClimaCore.MatrixFields: @name # Test that an expression is true and that it is also type-stable. macro test_all(expression) @@ -32,7 +36,7 @@ macro test_all(expression) end end -# Compute the minimum time (in seconds) required to run an expression after it +# Compute the minimum time (in seconds) required to run an expression after it # has been compiled. This macro is used instead of @benchmark from # BenchmarkTools.jl because the latter is extremely slow (it appears to keep # triggering recompilations and allocating a lot of memory in the process). @@ -134,6 +138,209 @@ function test_field_broadcast(; end end +# Create a field matrix for a similar solve to ClimaAtmos's moist dycore + prognostic, +# EDMF + prognostic surface temperature with implicit acoustic waves and SGS fluxes +# also returns corresponding FieldVector +function dycore_prognostic_EDMF_FieldMatrix( + ::Type{FT}, + center_space = nothing, + face_space = nothing, +) where {FT} + seed!(1) # For reproducibility with random fields + if isnothing(center_space) || isnothing(face_space) + center_space, face_space = test_spaces(FT) + end + surface_space = Spaces.level(face_space, half) + surface_space = Spaces.level(face_space, half) + sfc_vec = random_field(FT, surface_space) + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + λ = 10 + ᶜᶜmat1 = random_field(DiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶜᶠmat2 = random_field(BidiagonalMatrixRow{FT}, center_space) ./ λ + ᶠᶜmat2 = random_field(BidiagonalMatrixRow{FT}, face_space) ./ λ + ᶜᶜmat3 = random_field(TridiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat3 = random_field(TridiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + # Geometry.Covariant123Vector(1, 2, 3) * Geometry.Covariant12Vector(1, 2)' + e¹² = Geometry.Covariant12Vector(1, 1) + e₁₂ = Geometry.Contravariant12Vector(1, 1) + e³ = Geometry.Covariant3Vector(1) + e₃ = Geometry.Contravariant3Vector(1) + + ρχ_unit = (; ρq_tot = 1, ρq_liq = 1, ρq_ice = 1, ρq_rai = 1, ρq_sno = 1) + ρaχ_unit = + (; ρaq_tot = 1, ρaq_liq = 1, ρaq_ice = 1, ρaq_rai = 1, ρaq_sno = 1) + + + ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,) + ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',) + ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',) + ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) + ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) + ᶜᶜmat3_uₕ_uₕ = + ᶜᶜmat3 .* ( + Geometry.Covariant12Vector(1, 0) * + Geometry.Contravariant12Vector(1, 0)' + + Geometry.Covariant12Vector(0, 1) * + Geometry.Contravariant12Vector(0, 1)', + ) + ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',) + ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit)), ᶜᶜmat3) + ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit)), ᶜᶜmat3) + ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit ⊠ e₃')), ᶜᶠmat2) + + dry_center_gs_unit = (; ρ = 1, ρe_tot = 1, uₕ = e¹²) + center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit) + center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit) + + b = Fields.FieldVector(; + sfc = sfc_vec .* ((; T = 1),), + c = ᶜvec .* ((; center_gs_unit..., sgsʲs = (center_sgsʲ_unit,)),), + f = ᶠvec .* ((; u₃ = e³, sgsʲs = ((; u₃ = e³),)),), + ) + A = MatrixFields.FieldMatrix( + # GS-GS blocks: + (@name(sfc), @name(sfc)) => I, + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3_uₕ_uₕ, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + # GS-SGS blocks: + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, + (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, + (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + # SGS-SGS blocks: + (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + ) + return A, b +end + +function scaling_only_dycore_prognostic_EDMF_FieldMatrix( + ::Type{FT}, + center_space = nothing, + face_space = nothing, +) where {FT} + seed!(1) # For reproducibility with random fields + if isnothing(center_space) || isnothing(face_space) + center_space, face_space = test_spaces(FT) + end + surface_space = Spaces.level(face_space, half) + surface_space = Spaces.level(face_space, half) + sfc_vec = random_field(FT, surface_space) + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + λ = 10 + # Geometry.Covariant123Vector(1, 2, 3) * Geometry.Covariant12Vector(1, 2)' + e¹² = Geometry.Covariant12Vector(FT(1), FT(1)) + e₁₂ = Geometry.Contravariant12Vector(FT(1), FT(1)) + e³ = Geometry.Covariant3Vector(FT(1)) + e₃ = Geometry.Contravariant3Vector(FT(1)) + + ρχ_unit = (; + ρq_tot = FT(1), + ρq_liq = FT(1), + ρq_ice = FT(1), + ρq_rai = FT(1), + ρq_sno = FT(1), + ) + ρaχ_unit = (; + ρaq_tot = FT(1), + ρaq_liq = FT(1), + ρaq_ice = FT(1), + ρaq_rai = FT(1), + ρaq_sno = FT(1), + ) + + + + ᶠᶠu₃_u₃ = DiagonalMatrixRow(e³ * e₃') + ᶜᶜuₕ_scalar = DiagonalMatrixRow(e¹²) + ᶜᶜuₕ_uₕ = DiagonalMatrixRow( + Geometry.Covariant12Vector(FT(1), FT(0)) * + Geometry.Contravariant12Vector(FT(1), FT(0))' + + Geometry.Covariant12Vector(FT(0), FT(1)) * + Geometry.Contravariant12Vector(FT(0), FT(1))', + ) + ᶜᶜρχ_scalar = DiagonalMatrixRow(ρχ_unit) + ᶜᶜρaχ_scalar = DiagonalMatrixRow(ρaχ_unit) + + dry_center_gs_unit = (; ρ = FT(1), ρe_tot = FT(1), uₕ = e¹²) + center_gs_unit = (; dry_center_gs_unit..., ρatke = FT(1), ρχ = ρχ_unit) + center_sgsʲ_unit = (; ρa = FT(1), ρae_tot = FT(1), ρaχ = ρaχ_unit) + + b = Fields.FieldVector(; + sfc = sfc_vec .* ((; T = 1),), + c = ᶜvec .* ((; center_gs_unit..., sgsʲs = (center_sgsʲ_unit,)),), + f = ᶠvec .* ((; u₃ = e³, sgsʲs = ((; u₃ = e³),)),), + ) + A = MatrixFields.FieldMatrix( + # GS-GS blocks: + (@name(sfc), @name(sfc)) => I, + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜuₕ_uₕ, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠu₃_u₃, + # GS-SGS blocks: + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜρχ_scalar, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜuₕ_scalar, + (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠu₃_u₃, + # SGS-SGS blocks: + (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠu₃_u₃, + ) + return A, b +end + # Generate extruded finite difference spaces for testing. Include topography # when possible. function test_spaces(::Type{FT}) where {FT} diff --git a/test/MatrixFields/scalar_fieldmatrix.jl b/test/MatrixFields/scalar_fieldmatrix.jl new file mode 100644 index 0000000000..a1d5f07fe2 --- /dev/null +++ b/test/MatrixFields/scalar_fieldmatrix.jl @@ -0,0 +1,157 @@ +using Test +using JET + +import ClimaCore: + Geometry, Domains, Meshes, Spaces, Fields, MatrixFields, CommonSpaces +import ClimaCore.Utilities: half +import ClimaComms +import ClimaCore.MatrixFields: @name +ClimaComms.@import_required_backends +include("matrix_field_test_utils.jl") + +@testset "get_field_first_index_offset" begin + FT = Float64 + struct Singleton{T} + x::T + end + struct TwoFields{T1, T2} + x::T1 + y::T2 + end + function test_get_field_first_index_offset( + name, + ::Type{T}, + ::Type{S}, + expected_offset, + ) where {T, S} + @test_all MatrixFields.get_field_first_index_offset(name, T, S) == + expected_offset + end + test_get_field_first_index_offset( + @name(x), + FT, + Singleton{Singleton{Singleton{Singleton{FT}}}}, + 0, + ) + test_get_field_first_index_offset( + @name(x.x.x.x), + FT, + Singleton{Singleton{Singleton{Singleton{FT}}}}, + 0, + ) + test_get_field_first_index_offset( + @name(y.x), + FT, + TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}}, + 2, + ) + test_get_field_first_index_offset( + @name(y.y), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + ) + test_get_field_first_index_offset( + @name(y.y), + Float32, + TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}}, + 6, + ) + test_get_field_first_index_offset( + @name(y.y.x), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + ) + test_get_field_first_index_offset( + @name(y.y.y.x), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 4, + ) +end + +@testset "broadcasted_get_field_type" begin + FT = Float64 + struct Singleton{T} + x::T + end + struct TwoFields{T1, T2} + x::T1 + y::T2 + end + function test_broadcasted_get_field_type( + name, + ::Type{T}, + expected_type, + ) where {T} + @test_all MatrixFields.broadcasted_get_field_type(T, name) == + expected_type + end + test_broadcasted_get_field_type( + @name(x), + Singleton{Singleton{Singleton{Singleton{FT}}}}, + Singleton{Singleton{Singleton{FT}}}, + ) + test_broadcasted_get_field_type( + @name(x.x.x), + Singleton{Singleton{Singleton{Singleton{FT}}}}, + Singleton{FT}, + ) + test_broadcasted_get_field_type( + @name(y.x), + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + FT, + ) + test_broadcasted_get_field_type( + @name(y.y.y), + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + Singleton{FT}, + ) +end + +@testset "fieldmatrix to scalar fieldmatrix unit tests" begin + FT = Float64 + A, b = dycore_prognostic_EDMF_FieldMatrix(FT) + for (A, b) in ( + dycore_prognostic_EDMF_FieldMatrix(FT), + scaling_only_dycore_prognostic_EDMF_FieldMatrix(FT), + ) + @test all( + entry -> + entry isa MatrixFields.UniformScaling || + eltype(eltype(entry)) <: FT, + MatrixFields.scalar_fieldmatrix(A, b).entries, + ) + test_get(A, entry, key) = A[key] === entry + for (key, entry) in MatrixFields.scalar_fieldmatrix(A, b) + @test test_get(A, entry, key) + @test (@allocated test_get(A, entry, key)) == 0 + @test_opt test_get(A, entry, key) + end + + function scalar_fieldmatrix_wrapper(field_matrix_of_tensors, b) + A_scalar = + MatrixFields.scalar_fieldmatrix(field_matrix_of_tensors, b) + return nothing + end + scalar_fieldmatrix_wrapper(A, b) + @test (@allocated scalar_fieldmatrix_wrapper(A, b)) == 0 + @test_opt MatrixFields.scalar_fieldmatrix(A, b) + end +end