Skip to content

Commit 6d01d95

Browse files
committed
Add scalar_field_matrix
Add scalar_fieldmatrix Add a function to convert a FieldMatrix where each matrix entry has an eltype of some struct into a FieldMatrix where each entry has an eltype of a scalar. Add additional tests for scalar_matrixfields Use @test_all in tests Make suggested changes to tests and field_name_dict.jl Revert unrolled_findfirst Clean up field matrix tests and add support for DiagonalMatrixRows CamelCase struct name Clean up tests and get_scalar_keys wip backup Minimal working with allocs WIP1 WIP more allocs fix Assorted cleanup Fix dx/dx case reduce code duplication; fix example Add gpu test further cleanup, extend diagonalrow fix names test and comments Add docs docs bugfix remvoe bad refs fix docs formatting WIP Y fields pre-switch to type space should work fix broken tests bugfix fix implicit tensor rep tests WIPP1 working state Improve readability at cost of concise code update docs further cleanup propgate full key vs keyerror propogate name_tree scalar_fielmatrix to scalar_field_matrix
1 parent 5f968c8 commit 6d01d95

File tree

9 files changed

+845
-94
lines changed

9 files changed

+845
-94
lines changed

.buildkite/pipeline.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,18 @@ steps:
874874
agents:
875875
slurm_gpus: 1
876876

877+
- label: "Unit: scalar_field_matrix (CPU)"
878+
key: cpu_scalar_field_matrix
879+
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_field_matrix.jl"
880+
881+
- label: "Unit: scalar_field_matrix (GPU)"
882+
key: gpu_scalar_field_matrix
883+
command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_field_matrix.jl"
884+
env:
885+
CLIMACOMMS_DEVICE: "CUDA"
886+
agents:
887+
slurm_gpus: 1
888+
877889
- group: "Unit: MatrixFields - broadcasting (CPU)"
878890
steps:
879891

docs/src/matrix_fields.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ preconditioner_cache
8989
check_preconditioner
9090
lazy_or_concrete_preconditioner
9191
apply_preconditioner
92+
get_scalar_keys
93+
field_offset_and_type
9294
```
9395

9496
## Utilities
@@ -98,4 +100,97 @@ column_field2array
98100
column_field2array_view
99101
field2arrays
100102
field2arrays_view
103+
scalar_field_matrix
101104
```
105+
106+
## Indexing a FieldMatrix
107+
108+
A FieldMatrix entry can be:
109+
110+
- An `UniformScaling`, which contains a `Number`
111+
- A `DiagonalMatrixRow`, which can contain aything
112+
- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type.
113+
114+
If an entry contains a composite type, the fields of that type can be extracted.
115+
This is also true for nested composite types.
116+
117+
For example:
118+
119+
```@example 1
120+
using ClimaCore.CommonSpaces # hide
121+
import ClimaCore: MatrixFields, Quadratures # hide
122+
import ClimaCore.MatrixFields: @name # hide
123+
space = Box3DSpace(; # hide
124+
z_elem = 3, # hide
125+
x_min = 0, # hide
126+
x_max = 1, # hide
127+
y_min = 0, # hide
128+
y_max = 1, # hide
129+
z_min = 0, # hide
130+
z_max = 10, # hide
131+
periodic_x = false, # hide
132+
periodic_y = false, # hide
133+
n_quad_points = 1, # hide
134+
quad = Quadratures.GL{1}(), # hide
135+
x_elem = 1, # hide
136+
y_elem = 2, # hide
137+
staggering = CellCenter() # hide
138+
) # hide
139+
nt_entry_field = fill(MatrixFields.DiagonalMatrixRow((; foo = 1.0, bar = 2.0)), space)
140+
nt_fieldmatrix = MatrixFields.FieldMatrix((@name(a), @name(b)) => nt_entry_field)
141+
nt_fieldmatrix[(@name(a), @name(b))]
142+
```
143+
144+
The internal values of the named tuples can be extracted with
145+
146+
```@example 1
147+
nt_fieldmatrix[(@name(a.foo), @name(b))]
148+
```
149+
150+
and
151+
152+
```@example 1
153+
nt_fieldmatrix[(@name(a.bar), @name(b))]
154+
```
155+
156+
### Further Indexing Details
157+
158+
Let key `(@name(name1), @name(name2))` correspond to entry `sample_entry` in `FieldMatrix` `A`.
159+
An example of this is:
160+
161+
```julia
162+
A = MatrixFields.FieldMatrix((@name(name1), @name(name2)) => sample_entry)
163+
```
164+
165+
Now consider what happens indexing `A` with the key `(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`.
166+
167+
First, a function searches the keys of `A` for a key that `(@name(foo.bar.buz), @name(biz.bop.fud))`
168+
is a child of. In this example, `(@name(foo.bar.buz), @name(biz.bop.fud))` is a child of
169+
the key `(@name(name1), @name(name2))`, and
170+
`(@name(foo.bar.buz), @name(biz.bop.fud))` is referred to as the internal key.
171+
172+
Next, the entry that `(@name(name1), @name(name2))` is paired with is recursively indexed
173+
by the internal key.
174+
175+
The recursive indexing of an internal entry given some entry `entry` and internal_key `internal_name_pair`
176+
works as follows:
177+
178+
1. If the `internal_name_pair` is blank, return `entry`
179+
2. If the element type of each band of `entry` is an `Axis2Tensor`, and `internal_name_pair` is of the form
180+
`(@name(components.data.1...), @name(components.data.2...))` (potentially with different numbers),
181+
then extract the specified component, and recurse on it with the remaining `internal_name_pair`.
182+
3. If the element type of each band of `entry` is a `Geometry.AdjointAxisVector`, then recurse on the parent of the adjoint.
183+
4. If `internal_name_pair[1]` is not empty, and the first name in it is a field of the element type of each band of `entry`,
184+
extract that field from `entry`, and recurse on the it with the remaining names of `internal_name_pair[1]` and all of `internal_name_pair[2]`
185+
5. If `internal_name_pair[2]` is not empty, and the first name in it is a field of the element type of each row of `entry`,
186+
extract that field from `entry`, and recurse on the it with all of `internal_name_pair[1]` and the remaining names of `internal_name_pair[2]`
187+
6. At this point, if none of the previous cases are true, both `internal_name_pair[1]` and `internal_name_pair[2]` should be
188+
non-empty, and it is assumed that `entry` is being used to implicitly represent some tensor structure. If the first name in
189+
`internal_name_pair[1]` is equivalent to `internal_name_pair[2]`, then both the first names are dropped, and entry is recursed onto.
190+
If the first names are different, both the first names are dropped, and the zero of entry is recursed onto.
191+
192+
When the entry is a `ColumnWiseBandMatrixField`, indexing it will return a broadcasted object in
193+
the following situations:
194+
195+
1. The internal key indexes to a type different than the basetype of the entry
196+
2. The internal key indexes to a zero-ed value

src/Geometry/axistensors.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =
283283

284284
const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}
285285

286+
const AxisVectorOrAdj{T, A, S} =
287+
Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}}
288+
286289
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) =
287290
getindex(components(va), i)
288291
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =

src/MatrixFields/MatrixFields.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half
5858
import ..RecursiveApply:
5959
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
6060
import ..RecursiveApply: , ,
61+
import ..DataLayouts
6162
import ..DataLayouts: AbstractData
6263
import ..DataLayouts: vindex
6364
import ..Geometry

0 commit comments

Comments
 (0)