Skip to content

Commit 342b1c0

Browse files
authored
Make single-multiple field solvers mask-aware (#2270)
Make reductions mask-aware
1 parent 71e1117 commit 342b1c0

File tree

6 files changed

+90
-45
lines changed

6 files changed

+90
-45
lines changed

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ NVTX.@annotate function multiple_field_solve!(
2121
Nnames = length(names)
2222
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
2323
sscache = Operators.strip_space(cache)
24+
mask = Spaces.get_mask(axes(x1))
2425
ssx = Operators.strip_space(x)
2526
ssA = Operators.strip_space(A)
2627
ssb = Operators.strip_space(b)
@@ -33,7 +34,7 @@ NVTX.@annotate function multiple_field_solve!(
3334
device = ClimaComms.device(x[first(names)])
3435

3536
us = UniversalSize(Fields.field_values(x1))
36-
args = (device, caches, xs, As, bs, x1, us, Val(Nnames))
37+
args = (device, caches, xs, As, bs, x1, us, mask, Val(Nnames))
3738

3839
nitems = Ni * Nj * Nh * Nnames
3940
threads = threads_via_occupancy(multiple_field_solve_kernel!, args)
@@ -86,10 +87,12 @@ function multiple_field_solve_kernel!(
8687
bs,
8788
x1,
8889
us::UniversalSize,
90+
mask,
8991
::Val{Nnames},
9092
) where {Nnames}
9193
@inbounds begin
9294
(I, iname) = multiple_field_solve_universal_index(us)
95+
DataLayouts.should_compute(mask, I) || return nothing
9396
if multiple_field_solve_is_valid_index(I, us)
9497
(i, j, _, _, h) = I.I
9598
generated_single_field_solve!(

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, rmap, rzero, rdiv
1616
function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
1717
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
1818
us = UniversalSize(Fields.field_values(A))
19-
args = (device, cache, x, A, b, us)
19+
mask = Spaces.get_mask(axes(x))
20+
args = (device, cache, x, A, b, us, mask)
2021
threads = threads_via_occupancy(single_field_solve_kernel!, args)
2122
nitems = Ni * Nj * Nh
2223
n_max_threads = min(threads, nitems)
@@ -30,8 +31,9 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
3031
call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
3132
end
3233

33-
function single_field_solve_kernel!(device, cache, x, A, b, us)
34+
function single_field_solve_kernel!(device, cache, x, A, b, us, mask)
3435
I = columnwise_universal_index(us)
36+
DataLayouts.should_compute(mask, I) || return nothing
3537
if columnwise_is_valid_index(I, us)
3638
(i, j, _, _, h) = I.I
3739
_single_field_solve!(

ext/cuda/operators_integral.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ function column_reduce_device!(
2020
) where {F, T}
2121
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
2222
us = UniversalSize(Fields.field_values(output))
23+
mask = Spaces.get_mask(space)
24+
if !(mask isa DataLayouts.NoMask) && space isa Spaces.FiniteDifferenceSpace
25+
error("Masks not supported for FiniteDifferenceSpace")
26+
end
2327
args = (
2428
single_column_reduce!,
2529
f,
@@ -29,6 +33,7 @@ function column_reduce_device!(
2933
init,
3034
space,
3135
us,
36+
mask,
3237
)
3338
nitems = Ni * Nj * Nh
3439
threads = threads_via_occupancy(bycolumn_kernel!, args)
@@ -57,6 +62,10 @@ function column_accumulate_device!(
5762
space,
5863
) where {F, T}
5964
out_fv = Fields.field_values(output)
65+
mask = Spaces.get_mask(space)
66+
if !(mask isa DataLayouts.NoMask) && space isa Spaces.FiniteDifferenceSpace
67+
error("Masks not supported for FiniteDifferenceSpace")
68+
end
6069
us = UniversalSize(out_fv)
6170
args = (
6271
single_column_accumulate!,
@@ -67,6 +76,7 @@ function column_accumulate_device!(
6776
init,
6877
space,
6978
us,
79+
mask,
7080
)
7181
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
7282
nitems = Ni * Nj * Nh
@@ -81,7 +91,7 @@ function column_accumulate_device!(
8191
)
8292
end
8393

84-
bycolumn_kernel!(
94+
function bycolumn_kernel!(
8595
single_column_function!::S,
8696
f::F,
8797
transform::T,
@@ -90,11 +100,13 @@ bycolumn_kernel!(
90100
init,
91101
space,
92102
us::DataLayouts.UniversalSize,
93-
) where {S, F, T} =
103+
mask,
104+
) where {S, F, T}
94105
if space isa Spaces.FiniteDifferenceSpace
95106
single_column_function!(f, transform, output, input, init, space)
96107
else
97108
I = columnwise_universal_index(us)
109+
DataLayouts.should_compute(mask, I) || return nothing
98110
if columnwise_is_valid_index(I, us)
99111
(i, j, _, _, h) = I.I
100112
single_column_function!(
@@ -107,3 +119,4 @@ bycolumn_kernel!(
107119
)
108120
end
109121
end
122+
end

src/Fields/indices.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,9 @@ function byslab(
244244
end
245245
end
246246
end
247+
248+
universal_index(colidx::Fields.ColumnIndex{2}) =
249+
CartesianIndex(colidx.ij[1], colidx.ij[2], 1, 1, colidx.h)
250+
251+
universal_index(colidx::Fields.ColumnIndex{1}) =
252+
CartesianIndex(colidx.ij[1], 1, 1, 1, colidx.h)

src/MatrixFields/single_field_solver.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ..DataLayouts
12
dual_type(::Type{A}) where {A} = typeof(Geometry.dual(A.instance))
23

34
inv_return_type(::Type{X}) where {X} = error(
@@ -54,11 +55,10 @@ function single_field_solve_diag_matrix_row!(
5455
A::ColumnwiseBandMatrixField,
5556
b,
5657
)
57-
Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
58-
b_vals = Fields.field_values(b)
59-
x_vals = Fields.field_values(x)
60-
(A₀,) = Aⱼs
61-
@. x_vals = inv(A₀) b_vals
58+
# Use fields here, and not field values, so that this operation is
59+
# mask-aware.
60+
A₀ = A.entries.:1
61+
@. x = inv(A₀) b
6262
end
6363
single_field_solve!(_, x, A::ScalingFieldMatrixEntry, b) =
6464
x .= (inv(scaling_value(A)),) .* b
@@ -82,17 +82,22 @@ function _single_field_solve!(
8282
b,
8383
)
8484
space = axes(x)
85+
mask = Spaces.get_mask(space)
8586
if space isa Spaces.FiniteDifferenceSpace
87+
@assert mask isa DataLayouts.NoMask
8688
_single_field_solve_col!(device, cache, x, A, b)
8789
else
8890
Fields.bycolumn(space) do colidx
89-
_single_field_solve_col!(
90-
device,
91-
cache[colidx],
92-
x[colidx],
93-
A[colidx],
94-
b[colidx],
95-
)
91+
I = Fields.universal_index(colidx)
92+
if DataLayouts.should_compute(mask, I)
93+
_single_field_solve_col!(
94+
device,
95+
cache[colidx],
96+
x[colidx],
97+
A[colidx],
98+
b[colidx],
99+
)
100+
end
96101
end
97102
end
98103
end

src/Operators/integrals.jl

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ bottom of each column and moving upward, and the result of the final iteration
124124
is passed to the `transform` function before being stored in `output`. If `init`
125125
is specified, it is used as the initial value of the iteration; otherwise, the
126126
value at the bottom of each column in `input` is used as the initial value.
127-
127+
128128
With `first_level` and `last_level` denoting the indices of the boundary levels
129129
of `input`, the reduction in each column can be summarized as follows:
130130
- If `init` is unspecified,
@@ -156,27 +156,35 @@ function column_reduce!(
156156
column_reduce_device!(device, f, transform, output, input, init, space)
157157
end
158158

159-
column_reduce_device!(
159+
function column_reduce_device!(
160160
::ClimaComms.AbstractCPUDevice,
161161
f::F,
162162
transform::T,
163163
output,
164164
input,
165165
init,
166166
space,
167-
) where {F, T} =
168-
space isa Spaces.FiniteDifferenceSpace ?
169-
single_column_reduce!(f, transform, output, input, init, space) :
170-
Fields.bycolumn(space) do colidx
171-
single_column_reduce!(
172-
f,
173-
transform,
174-
output[colidx],
175-
input[colidx],
176-
init,
177-
space[colidx],
178-
)
167+
) where {F, T}
168+
mask = Spaces.get_mask(space)
169+
if space isa Spaces.FiniteDifferenceSpace
170+
@assert mask isa DataLayouts.NoMask
171+
single_column_reduce!(f, transform, output, input, init, space)
172+
else
173+
Fields.bycolumn(space) do colidx
174+
I = Fields.universal_index(colidx)
175+
if DataLayouts.should_compute(mask, I)
176+
single_column_reduce!(
177+
f,
178+
transform,
179+
output[colidx],
180+
input[colidx],
181+
init,
182+
space[colidx],
183+
)
184+
end
185+
end
179186
end
187+
end
180188

181189
# On GPUs, input and output go through strip_space to become _input and _output.
182190
function single_column_reduce!(
@@ -214,7 +222,7 @@ from the bottom of each column and moving upward, and the result of each
214222
iteration is passed to the `transform` function before being stored in `output`.
215223
The `init` value is optional for center-to-center, face-to-face, and
216224
face-to-center accumulation, but it is required for center-to-face accumulation.
217-
225+
218226
With `first_level` and `last_level` denoting the indices of the boundary levels
219227
of `input`, the accumulation in each column can be summarized as follows:
220228
- For center-to-center and face-to-face accumulation with `init` unspecified,
@@ -276,27 +284,35 @@ function column_accumulate!(
276284
column_accumulate_device!(device, f, transform, output, input, init, space)
277285
end
278286

279-
column_accumulate_device!(
287+
function column_accumulate_device!(
280288
::ClimaComms.AbstractCPUDevice,
281289
f::F,
282290
transform::T,
283291
output,
284292
input,
285293
init,
286294
space,
287-
) where {F, T} =
288-
space isa Spaces.FiniteDifferenceSpace ?
289-
single_column_accumulate!(f, transform, output, input, init, space) :
290-
Fields.bycolumn(space) do colidx
291-
single_column_accumulate!(
292-
f,
293-
transform,
294-
output[colidx],
295-
input[colidx],
296-
init,
297-
space[colidx],
298-
)
295+
) where {F, T}
296+
mask = Spaces.get_mask(space)
297+
if space isa Spaces.FiniteDifferenceSpace
298+
@assert mask isa DataLayouts.NoMask
299+
single_column_accumulate!(f, transform, output, input, init, space)
300+
else
301+
Fields.bycolumn(space) do colidx
302+
I = Fields.universal_index(colidx)
303+
if DataLayouts.should_compute(mask, I)
304+
single_column_accumulate!(
305+
f,
306+
transform,
307+
output[colidx],
308+
input[colidx],
309+
init,
310+
space[colidx],
311+
)
312+
end
313+
end
299314
end
315+
end
300316

301317
# On GPUs, input and output go through strip_space to become _input and _output.
302318
function single_column_accumulate!(

0 commit comments

Comments
 (0)