Skip to content

Add scalar_fieldmatrix #2289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,18 @@ steps:
agents:
slurm_gpus: 1

- label: "Unit: scalar_fieldmatrix (CPU)"
key: cpu_scalar_fieldmatrix
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl"

- label: "Unit: mscalar_fieldmatrix (GPU)"
key: gpu_scalar_fieldmatrix
command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl"
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1

- group: "Unit: MatrixFields - broadcasting (CPU)"
steps:

Expand Down
80 changes: 80 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ preconditioner_cache
check_preconditioner
lazy_or_concrete_preconditioner
apply_preconditioner
get_scalar_keys
get_field_first_index_offset
broadcasted_get_field_type
inner_type_ignore_adjoint
```

## Utilities
Expand All @@ -98,4 +102,80 @@ column_field2array
column_field2array_view
field2arrays
field2arrays_view
scalar_fieldmatrix
```

## Indexing a FieldMatrix

A FieldMatrix entry can be:

- An `UniformScaling`, which contains a `Number`
- A `DiagonalMatrixRow`, which can contain aything
- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type.

If an entry contains a composite type, the fields of that type can be extracted.
This is also true for nested composite types.

For example:

```@example 1
using ClimaCore.CommonSpaces # hide
import ClimaCore: MatrixFields, Quadratures # hide
import ClimaCore.MatrixFields: @name # hide
space = Box3DSpace(; # hide
z_elem = 3, # hide
x_min = 0, # hide
x_max = 1, # hide
y_min = 0, # hide
y_max = 1, # hide
z_min = 0, # hide
z_max = 10, # hide
periodic_x = false, # hide
periodic_y = false, # hide
n_quad_points = 1, # hide
quad = Quadratures.GL{1}(), # hide
x_elem = 1, # hide
y_elem = 2, # hide
staggering = CellCenter() # hide
) # hide
nt_entry_field = fill(MatrixFields.DiagonalMatrixRow((; foo = 1.0, bar = 2.0)), space)
nt_fieldmatrix = MatrixFields.FieldMatrix((@name(a), @name(b)) => nt_entry_field)
nt_fieldmatrix[(@name(a), @name(b))]
```

The internal values of the named tuples can be extracted with

```@example 1
nt_fieldmatrix[(@name(a.foo), @name(b))]
```

and

```@example 1
nt_fieldmatrix[(@name(a.bar), @name(b))]
```

If the key `(@name(name1), @name(name2))` corresponds to an entry, then
`(@name(foo.bar.buz), @name(biz.bop.fud))` would be the internal key for the key
`(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`.

Currently, internal values cannot be extracted in all situations. Extracting interal values
works when:

- The second name in the internal key is empty, and the first name in the internal key accesses internal values for the type of element contained in each row of the entry. This does not work when the element type of each row is a 2d tensor.

- The first name in the internal key is empty, and the type of element contained in each row of the entry is an `AxisVector` or the adjoint of an `AxisVector`. In this case, the second name must access inernal values for the type of `AxisVector` contained in each row.

- The element type of each row in the entry is a 2d tensor, and the internal key is of the form `(@name(components.data.:(1)), @name(components.data.:(2)))`, but possibly with different numbers to index into the 2d tensor

- The element type of each row in the entry is some number of nested `Tuple`s and `NamedTuple`s, and the first name in the internal key accesses an `AxisVector` or the adjoint of an `AxisVector` from the outer `Tuple`/`NamedTuple`, and the second name in the inernal key accesses a component of the `AxisVector`

If the `FieldMatrix` represents a Jacobian, then extracting internal values works when an entry represents:

- The partial derrivative of an `AxisVector`, `Tuple`, or `NamedTuple` with respect to a scalar.

- The partial derrivative of a scalar with respect to an `AxisVector`.

- The partial derrivative of a `Tuple`, or `NamedTuple` with respect to an `AxisVector`. In this case, the first name of the internal key must index into the tuple and result in a scalar.

- The partial derrivative of an `AxisVector` with respect to an `AxisVector`. In this case, the partial derrivative of a component of the first `AxisVector` with respect to a component of the second `AxisVector` can be extracted, but not an entire `AxisVector` with respect to a component, or a component with respect to an entire `AxisVector`
3 changes: 3 additions & 0 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =

const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}

const AxisVectorOrAdj{T, A, S} =
Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}}

Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) =
getindex(components(va), i)
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
Expand Down
1 change: 1 addition & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: ⊠, ⊞, ⊟
import ..DataLayouts
import ..DataLayouts: AbstractData
import ..DataLayouts: vindex
import ..Geometry
Expand Down
18 changes: 18 additions & 0 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain)
drop_first(::FieldName{name_chain}) where {name_chain} =
FieldName(Base.tail(name_chain)...)

extract_last(::FieldName{name_chain}) where {name_chain} =
name_chain[length(name_chain)]

has_field(x, ::FieldName{()}) = true
has_field(x, name::FieldName) =
extract_first(name) in propertynames(x) &&
Expand All @@ -59,6 +62,18 @@ get_field(x, ::FieldName{()}) = x
get_field(x, name::FieldName) =
get_field(getproperty(x, extract_first(name)), drop_first(name))

"""
broadcasted_get_field_type(::Type{X}, name::FieldName)

Returns the type of the field accessed by `name` in the type `X`.
"""
broadcasted_get_field_type(::Type{X}, ::FieldName{()}) where {X} = X
broadcasted_get_field_type(::Type{X}, name::FieldName) where {X} =
broadcasted_get_field_type(
fieldtype(X, extract_first(name)),
drop_first(name),
)

broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true
broadcasted_has_field(::Type{X}, name::FieldName) where {X} =
extract_first(name) in fieldnames(X) &&
Expand Down Expand Up @@ -199,4 +214,7 @@ if hasfield(Method, :recursion_relation)
for m in methods(get_subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(broadcasted_get_field_type)
m.recursion_relation = dont_limit
end
end
Loading
Loading