Commit d1b1f77

wip
1 parent cf5ff9e commit d1b1f77

5 files changed

+224
-51
lines changed

.buildkite/pipeline.yml

Lines changed: 6 additions & 6 deletions
@@ -874,13 +874,13 @@ steps:
874874
agents:
875875
slurm_gpus: 1
876876

877-
- label: "Unit: scalar_field_matrix (CPU)"
878-
key: cpu_scalar_field_matrix
879-
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_field_matrix.jl"
877+
- label: "Unit: field_matrix_indexing (CPU)"
878+
key: cpu_field_matrix_indexing
879+
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/field_matrix_indexing.jl"
880880

881-
- label: "Unit: scalar_field_matrix (GPU)"
882-
key: gpu_scalar_field_matrix
883-
command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_field_matrix.jl"
881+
- label: "Unit: field_matrix_indexing (GPU)"
882+
key: gpu_field_matrix_indexing
883+
command: "julia --color=yes --project=.buildkite test/MatrixFields/field_matrix_indexing.jl"
884884
env:
885885
CLIMACOMMS_DEVICE: "CUDA"
886886
agents:

docs/src/matrix_fields.md

Lines changed: 144 additions & 9 deletions
@@ -108,8 +108,8 @@ scalar_field_matrix
108108
A FieldMatrix entry can be:
109109

110110
- A `UniformScaling`, which contains a `Number`
111-
- A `DiagonalMatrixRow`, which can contain aything
112-
- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type.
111+
- A `DiagonalMatrixRow`, which can contain either a `Number` or a tensor (represented as a `Geometry.Axis2Tensor`)
112+
- A `ColumnwiseBandMatrixField`, where each value is a [`BandMatrixRow`](@ref) with entries of any type that can be represented using the field's base number type.
113113

114114
If an entry contains a composite type, the fields of that type can be extracted.
115115
This is also true for nested composite types.
@@ -164,10 +164,7 @@ An example of this is:
164164

165165
Now consider what happens indexing `A` with the key `(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`.
166166

167-
First, a function searches the keys of `A` for a key that `(@name(foo.bar.buz), @name(biz.bop.fud))`
168-
is a child of. In this example, `(@name(foo.bar.buz), @name(biz.bop.fud))` is a child of
169-
the key `(@name(name1), @name(name2))`, and
170-
`(@name(foo.bar.buz), @name(biz.bop.fud))` is referred to as the internal key.
167+
First, `getindex` finds a key in `A` that contains the key being indexed. In this example, `(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))` is contained within `(@name(name1), @name(name2))`, so `(@name(name1), @name(name2))` is called the "parent key" and `(@name(foo.bar.buz), @name(biz.bop.fud))` is referred to as the "internal key".
171168

172169
Next, the entry that `(@name(name1), @name(name2))` is paired with is recursively indexed
173170
by the internal key.
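
The key lookup described above can be sketched in plain Julia. This is only an illustration under stated assumptions: names are modeled as tuples of symbols rather than ClimaCore `FieldName`s, and `is_child`, `internal_name`, and `split_key` are hypothetical helpers, not the library's functions.

```julia
# Hypothetical model: a name is a tuple of symbols, a key is a pair of names.
is_child(child::Tuple, parent::Tuple) =
    length(child) >= length(parent) && child[1:length(parent)] == parent

# The part of `child` that remains after removing the `parent` prefix.
internal_name(child::Tuple, parent::Tuple) = child[(length(parent) + 1):end]

# Find the stored key that contains `key`, and return it with the internal key.
function split_key(stored_keys, key)
    for parent_key in stored_keys
        if all(map(is_child, key, parent_key))
            return parent_key, map(internal_name, key, parent_key)
        end
    end
    throw(KeyError(key))
end

stored_keys = [((:name1,), (:name2,))]
key = ((:name1, :foo, :bar, :buz), (:name2, :biz, :bop, :fud))
parent_key, internal_key = split_key(stored_keys, key)
```

Here `parent_key` plays the role of `(@name(name1), @name(name2))` and `internal_key` the role of `(@name(foo.bar.buz), @name(biz.bop.fud))`.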
@@ -181,9 +178,9 @@ works as follows:
181178
then extract the specified component, and recurse on it with the remaining `internal_name_pair`.
182179
3. If the element type of each band of `entry` is a `Geometry.AdjointAxisVector`, then recurse on the parent of the adjoint.
183180
4. If `internal_name_pair[1]` is not empty, and the first name in it is a field of the element type of each band of `entry`,
184-
extract that field from `entry`, and recurse on the it with the remaining names of `internal_name_pair[1]` and all of `internal_name_pair[2]`
185-
5. If `internal_name_pair[2]` is not empty, and the first name in it is a field of the element type of each row of `entry`,
186-
extract that field from `entry`, and recurse on the it with all of `internal_name_pair[1]` and the remaining names of `internal_name_pair[2]`
181+
extract that field from `entry`, and recurse into it with the remaining names of `internal_name_pair[1]` and all of `internal_name_pair[2]`
182+
5. If `internal_name_pair[2]` is not empty, and the first name in it is a field of the element type of each band of `entry`,
183+
extract that field from `entry`, and recurse into it with all of `internal_name_pair[1]` and the remaining names of `internal_name_pair[2]`
187184
6. At this point, if none of the previous cases are true, both `internal_name_pair[1]` and `internal_name_pair[2]` should be
188185
non-empty, and it is assumed that `entry` is being used to implicitly represent some tensor structure. If the first name in
189186
`internal_name_pair[1]` is equivalent to `internal_name_pair[2]`, then both the first names are dropped, and entry is recursed onto.
@@ -194,3 +191,141 @@ the following situations:
194191

195192
1. The internal key indexes to a type different than the basetype of the entry
196193
2. The internal key indexes to a zero-ed value
194+
3. The internal key slices an `AxisTensor`
195+
196+
### Implicit Tensor Structure Optimization
197+
198+
```@setup 2
199+
using ClimaCore.CommonSpaces
200+
using ClimaCore.Geometry
201+
using ClimaCore.Fields
202+
import ClimaCore: MatrixFields
203+
import ClimaCore.MatrixFields: @name
204+
FT = Float64
205+
space = ColumnSpace(FT ;
206+
z_elem = 6,
207+
z_min = 0,
208+
z_max = 10,
209+
staggering = CellCenter()
210+
)
211+
```
212+
213+
If a `FieldMatrix` is used to represent a Jacobian, entries with certain structures
214+
can be stored in an optimized manner.
215+
216+
The optimization assumes that if indexing into an entry of scalars, the user intends the
217+
entry to have an implicit tensor structure, with the scalar values representing a scaling of the
218+
tensor identity. If the first name in each half of the name pair is the same, the pair indexes onto the diagonal,
219+
and the scalar value is returned. Otherwise, the pair indexes off the diagonal, and a zero value
220+
is returned.
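
The rule above can be illustrated with a small sketch in plain Julia. It does not use ClimaCore: `ImplicitTensorEntry` and `index_implicit` are hypothetical names, names are modeled as tuples, and the zero result is computed eagerly instead of through a lazy broadcast.

```julia
# Hypothetical stand-in for an entry that stores one scalar per vertical level.
struct ImplicitTensorEntry{T}
    scalars::Vector{T}  # one scaling factor per level
end

# Index with a (row name, column name) pair: matching leading names lie on the
# diagonal of the implied tensor, mismatched leading names lie off it.
function index_implicit(entry::ImplicitTensorEntry, row::Tuple, col::Tuple)
    isempty(row) && isempty(col) && return entry.scalars
    (isempty(row) || isempty(col)) && throw(KeyError((row, col)))
    if first(row) == first(col)
        return index_implicit(entry, Base.tail(row), Base.tail(col))
    else
        return zero(entry.scalars)  # off diagonal: zeros of the same shape
    end
end

entry = ImplicitTensorEntry([2.0, 3.0, 5.0])
on_diag = index_implicit(entry, (1,), (1,))   # the stored scalars
off_diag = index_implicit(entry, (2,), (1,))  # all zeros
```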
221+
222+
The following sections refer to the `Field`s
223+
$f$ and $g$, which both have values of type `Covariant12Vector` and are defined on a column domain, which is discretized with $N_v$ layers.
224+
The notation $f_{n}[i]$ where $0 < n \leq N_v$ and $i \in \{1, 2\}$ refers to the $i$-th component of the element of $f$
225+
at the $n$-th vertical level. $g$ is indexed similarly. Although $f$ and $g$ have values of type
226+
`Covariant12Vector`, this optimization works for any two `Field`s of `AxisVector`s.
227+
228+
```@example 2
229+
f = map(x -> rand(Geometry.Covariant12Vector{Float64}), Fields.local_geometry_field(space))
230+
g = map(x -> rand(Geometry.Covariant12Vector{Float64}), Fields.local_geometry_field(space))
231+
```
232+
233+
#### Uniform Scaling Case
234+
235+
If $\frac{\partial f_n[i]}{\partial g_n[j]} = [i = j] * k$ for some scalar $k$, then the
236+
non-optimized entry would be represented by a diagonal matrix whose values are $k$ times the identity 2D tensor. If $k = 2$, then
237+
238+
```@example 2
239+
identity_axis2tensor = Geometry.Covariant12Vector(FT(1), FT(0)) * # hide
240+
Geometry.Contravariant12Vector(FT(1), FT(0))' + # hide
241+
Geometry.Covariant12Vector(FT(0), FT(1)) * # hide
242+
Geometry.Contravariant12Vector(FT(0), FT(1))' # hide
243+
k = 2
244+
∂f_∂g = fill(MatrixFields.DiagonalMatrixRow(k * identity_axis2tensor), space)
245+
```
246+
247+
Individual components can be indexed into:
248+
249+
```@example 2
250+
J = MatrixFields.FieldMatrix((@name(f), @name(g))=> ∂f_∂g)
251+
J[(@name(f.components.data.:(1)), @name(g.components.data.:(1)))]
252+
```
253+
254+
```@example 2
255+
J[(@name(f.components.data.:(2)), @name(g.components.data.:(1)))]
256+
```
257+
258+
The first example above indexes into $\frac{\partial f_n[1]}{\partial g_n[1]}$ and the second into $\frac{\partial f_n[2]}{\partial g_n[1]}$, where $0 < n \leq N_v$.
259+
260+
The entry can
261+
also be represented with a single `DiagonalMatrixRow`, as follows:
262+
263+
```@example 2
264+
∂f_∂g_optimized = MatrixFields.DiagonalMatrixRow(k * identity_axis2tensor)
265+
```
266+
267+
`∂f_∂g_optimized` is a single `DiagonalMatrixRow`, which represents a diagonal matrix with the
268+
same tensor along the diagonal. In this case, that tensor is $k$ multiplied by the identity matrix, and that can be
269+
represented implicitly by the scalar `k` (that is, `k * I`), as follows:
270+
271+
```@example 2
272+
∂f_∂g_more_optimized = MatrixFields.DiagonalMatrixRow(k)
273+
```
274+
275+
Individual components of `∂f_∂g_optimized` can be indexed in the same way as `∂f_∂g`.
276+
277+
```@example 2
278+
J_unoptimized = MatrixFields.FieldMatrix((@name(f), @name(g)) => ∂f_∂g)
279+
J_unoptimized[(@name(f.components.data.:(1)), @name(g.components.data.:(1)))]
280+
```
281+
282+
```@example 2
283+
J_more_optimized = MatrixFields.FieldMatrix((@name(f), @name(g)) => ∂f_∂g_more_optimized)
284+
J_more_optimized[(@name(f.components.data.:(1)), @name(g.components.data.:(1)))]
285+
```
286+
287+
```@example 2
288+
J_more_optimized[(@name(f.components.data.:(1)), @name(g.components.data.:(2)))]
289+
```
290+
291+
`∂f_∂g` stores $2 * 2 * N_v$ floats in memory, `∂f_∂g_optimized` stores $2 * 2$ floats, and
292+
`∂f_∂g_more_optimized` stores only one float.
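
As a rough worked example of these counts, the snippet below assumes $N_v = 6$ (matching `z_elem = 6` in the setup block) and `Float64` values, and counts only the raw floats, ignoring any container overhead. The variable names are illustrative.

```julia
N_v = 6                        # number of layers, as in the setup block
float_bytes = sizeof(Float64)  # 8 bytes per value

full_floats = 2 * 2 * N_v      # one 2×2 tensor per layer
single_tensor_floats = 2 * 2   # one 2×2 tensor shared by all layers
scalar_floats = 1              # one scaling factor

# Raw storage in bytes for each representation.
storage_bytes = (full_floats, single_tensor_floats, scalar_floats) .* float_bytes
```

Under these assumptions the three representations take 192, 32, and 8 bytes respectively.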
293+
294+
#### Vertically Varying Case
295+
296+
The implicit tensor optimization can also be used when
297+
$\frac{\partial f_n[i]}{\partial g_n[j]} = [i = j] * h(f_n, g_n)$.
298+
299+
In this case, a full `ColumnwiseBandMatrixField` must be used.
300+
301+
```@example 2
302+
∂f_∂g_optimized = map(x -> MatrixFields.DiagonalMatrixRow(rand(Float64)), ∂f_∂g)
303+
```
304+
305+
```@example 2
306+
J_optimized = MatrixFields.FieldMatrix((@name(f), @name(g)) => ∂f_∂g_optimized)
307+
J_optimized[(@name(f.components.data.:(1)), @name(g.components.data.:(1)))]
308+
```
309+
310+
```@example 2
311+
Base.Broadcast.materialize(J_optimized[(@name(f.components.data.:(2)), @name(g.components.data.:(1)))])
312+
```
313+
314+
#### Bandwidth > 1 Case
315+
316+
The implicit tensor optimization can also be used when
317+
$\frac{\partial f_n[i]}{\partial g[j]} = [i = j] * h(f_n, g_{n-b_1}, ..., g_{n+b_2})$ where
318+
$b_1$ and $b_2$ are the lower and upper bandwidth. Say $b_1 = b_2 = 1$. Then
319+
320+
```@example 2
321+
∂f_∂g_optimized = map(x -> MatrixFields.TridiagonalMatrixRow(rand(Float64), rand(Float64), rand(Float64)), ∂f_∂g)
322+
```
323+
324+
```@example 2
325+
J_optimized = MatrixFields.FieldMatrix((@name(f), @name(g)) => ∂f_∂g_optimized)
326+
J_optimized[(@name(f.components.data.:(1)), @name(g.components.data.:(1)))]
327+
```
328+
329+
```@example 2
330+
Base.Broadcast.materialize(J_optimized[(@name(f.components.data.:(2)), @name(g.components.data.:(1)))])
331+
```

src/MatrixFields/field_name_dict.jl

Lines changed: 43 additions & 30 deletions
@@ -213,11 +213,11 @@ function get_internal_entry(
213213
name_pair == (@name(), @name()) && return entry
214214
S = eltype(eltype(entry))
215215
T = eltype(parent(entry))
216-
(start_offset, target_type, apply_zero) =
216+
(start_offset, target_type, index_method) =
217217
field_offset_and_type(name_pair, T, S, name_pair)
218-
if target_type <: eltype(parent(entry)) && !apply_zero
219-
band_element_size =
220-
DataLayouts.typesize(eltype(parent(entry)), eltype(eltype(entry)))
218+
if isa(index_method, Val{:view})
219+
@assert target_type <: T
220+
band_element_size = DataLayouts.typesize(T, S)
221221
singleton_datalayout = DataLayouts.singleton(Fields.field_values(entry))
222222
scalar_band_type =
223223
band_matrix_row_type(outer_diagonals(eltype(entry))..., target_type)
@@ -234,14 +234,15 @@ function get_internal_entry(
234234
scalar_data,
235235
)
236236
return Fields.Field(values, axes(entry))
237-
elseif apply_zero && start_offset == 0
237+
elseif isa(index_method, Val{:broadcasted_zero})
238+
# implicit tensor structure optimization, off diagonal
238239
zero_value = zero(target_type)
239240
return Base.broadcasted(entry) do matrix_row
240241
map(x -> zero_value, matrix_row)
241242
end
242-
elseif target_type == S && start_offset == 0
243+
elseif target_type == S && isa(index_method, Val{:view_of_blocks})
243244
return entry
244-
else
245+
else # fallback to broadcasted indexing on each element, currently no support for view_of_blocks
245246
return Base.broadcasted(entry) do matrix_row
246247
map(matrix_row) do matrix_row_entry
247248
get_internal_entry(matrix_row_entry, name_pair)
@@ -306,35 +307,49 @@ end
306307
field_offset_and_type(name_pair::FieldNamePair, ::Type{T}, ::Type{S}, full_key::FieldNamePair)
307308
308309
Returns the offset of the field with name `name_pair` in an object of type `S` in
309-
multiples of `sizeof(T)` and the type of the field with name `name_pair`.
310+
multiples of `sizeof(T)`, the type of the field with name `name_pair`, and a `Val` indicating
311+
what method can index a ClimaCore `Field` of `S` with `name_pair`.
312+
313+
The third return value is one of the following:
314+
- `Val(:view)`: indexing with a view is possible
315+
- `Val(:view_of_blocks)`: indexing with a view of blocks is possible (this is not implemented)
316+
- `Val(:broadcasted_fallback)`: indexing with a view is not possible
317+
- `Val(:broadcasted_zero)`: indexing with a view is not possible, and the `name_pair` indexes
318+
off diagonal with implicit tensor structure optimization (see MatrixFields docs)
310319
311-
When `S` is a `Geometry.Axis2Tensor`, the name pair must index into a scalar of
312-
the tensor or be empty. In other words, the name pair cannot index into a slice.
320+
When `S` is a `Geometry.Axis2Tensor` and the name pair indexes to a slice of
321+
the tensor, `Val(:broadcasted_fallback)` is returned, because a slice cannot be indexed with a view.
313322
314323
If neither element of `name_pair` is `@name()`, the first name in the pair is indexed with
315324
first, and then the second name is used to index the result of the first.
325+
326+
This is an internal function designed to be used with `get_internal_entry(::ColumnwiseBandMatrixField)`.
316327
"""
317328
function field_offset_and_type(
318329
name_pair::FieldNamePair,
319330
::Type{T},
320331
::Type{S},
321332
full_key::FieldNamePair,
322333
) where {S, T}
323-
name_pair == (@name(), @name()) && return (0, S, false) # base case
324-
if S <: Geometry.Axis2Tensor &&
325-
all(n -> is_child_name(n, @name(components.data)), name_pair)# special case to calculate index
326-
(name_pair[1] == @name() || name_pair[2] == @name()) &&
327-
throw(KeyError(full_key))
334+
if name_pair == (@name(), @name()) # recursion base case
335+
# if S <: T, then it is possible to construct a strided view in the indexing function
336+
return (0, S, S <: T ? Val(:view) : Val(:view_of_blocks))
337+
elseif S <: Geometry.Axis2Tensor &&
338+
all(n -> is_child_name(n, @name(components.data)), name_pair) # special case to calculate index
328339
internal_row_name =
329340
extract_internal_name(name_pair[1], @name(components.data))
330341
internal_col_name =
331342
extract_internal_name(name_pair[2], @name(components.data))
332343
row_index = extract_first(internal_row_name)
333344
col_index = extract_first(internal_col_name)
345+
if ((row_index isa Number) && (col_index isa Colon)) ||
346+
((row_index isa Colon) && (col_index isa Number))
347+
return (0, S, Val{:broadcasted_fallback}()) # slice case, trigger the broadcasted fallback
348+
end
334349
((row_index isa Number) && (col_index isa Number)) ||
335-
throw(KeyError(full_key)) # slicing not supported
350+
throw(KeyError(full_key))
336351
(n_rows, n_cols) = map(length, axes(S))
337-
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
352+
(remaining_offset, end_type, index_method) = field_offset_and_type(
338353
(drop_first(internal_row_name), drop_first(internal_col_name)),
339354
T,
340355
eltype(S),
@@ -345,20 +360,19 @@ function field_offset_and_type(
345360
return (
346361
(n_rows * (col_index - 1) + row_index - 1) + remaining_offset,
347362
end_type,
348-
apply_zero,
363+
index_method,
349364
)
350-
elseif S <: Geometry.AdjointAxisVector
365+
elseif S <: Geometry.AdjointAxisVector # bypass adjoint because indexing parent is equivalent
351366
return field_offset_and_type(name_pair, T, fieldtype(S, 1), full_key)
352367
elseif name_pair[1] != @name() &&
353-
extract_first(name_pair[1]) in fieldnames(S)
354-
368+
extract_first(name_pair[1]) in fieldnames(S) # index with first part of name_pair[1]
355369
remaining_field_chain = (drop_first(name_pair[1]), name_pair[2])
356370
child_type = fieldtype(S, extract_first(name_pair[1]))
357371
field_index = unrolled_filter(
358372
i -> fieldname(S, i) == extract_first(name_pair[1]),
359373
1:fieldcount(S),
360374
)[1]
361-
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
375+
(remaining_offset, end_type, index_method) = field_offset_and_type(
362376
remaining_field_chain,
363377
T,
364378
child_type,
@@ -367,18 +381,17 @@ function field_offset_and_type(
367381
return (
368382
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
369383
end_type,
370-
apply_zero,
384+
index_method,
371385
)
372386
elseif name_pair[2] != @name() &&
373-
extract_first(name_pair[2]) in fieldnames(S)
374-
387+
extract_first(name_pair[2]) in fieldnames(S) # index with first part of name_pair[2]
375388
remaining_field_chain = name_pair[1], drop_first(name_pair[2])
376389
child_type = fieldtype(S, extract_first(name_pair[2]))
377390
field_index = unrolled_filter(
378391
i -> fieldname(S, i) == extract_first(name_pair[2]),
379392
1:fieldcount(S),
380393
)[1]
381-
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
394+
(remaining_offset, end_type, index_method) = field_offset_and_type(
382395
remaining_field_chain,
383396
T,
384397
child_type,
@@ -387,10 +400,10 @@ function field_offset_and_type(
387400
return (
388401
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
389402
end_type,
390-
apply_zero,
403+
index_method,
391404
)
392-
elseif !any(isequal(@name()), name_pair) # implicit tensor structure
393-
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
405+
elseif !any(isequal(@name()), name_pair) # implicit tensor structure optimization
406+
(remaining_offset, end_type, index_method) = field_offset_and_type(
394407
(drop_first(name_pair[1]), drop_first(name_pair[2])),
395408
T,
396409
S,
@@ -400,7 +413,7 @@ function field_offset_and_type(
400413
remaining_offset,
401414
end_type,
402415
extract_first(name_pair[1]) == extract_first(name_pair[2]) ?
403-
apply_zero : true,
416+
index_method : Val(:broadcasted_zero), # zero if off diagonal
404417
)
405418
else
406419
throw(KeyError(full_key))
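
The change from a Boolean `apply_zero` flag to a `Val`-typed `index_method` can be illustrated with a minimal, self-contained sketch. It is not the ClimaCore implementation: `describe` and `combine` are hypothetical functions, and the strings are only labels for the four symbols listed in the docstring.

```julia
# Hypothetical labels for each indexing method named in the docstring.
describe(::Val{:view}) = "strided view into the parent data"
describe(::Val{:view_of_blocks}) = "view of blocks (documented as not implemented)"
describe(::Val{:broadcasted_fallback}) = "lazy broadcast over each matrix row"
describe(::Val{:broadcasted_zero}) = "lazy broadcast that yields zeros"

# Mimics the implicit tensor branch: keep the method found by the recursion when
# the leading names match, otherwise switch to the zero method.
combine(index_method::Val, row_name, col_name) =
    row_name == col_name ? index_method : Val(:broadcasted_zero)

on_diagonal = combine(Val(:view), :a, :a)
off_diagonal = combine(Val(:view), :a, :b)
```

Encoding the method as a `Val` lets callers branch with `isa` checks or dispatch, as `get_internal_entry` does in this commit, instead of combining a flag with offset comparisons.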
