Skip to content

Commit ee2b83e

Browse files
Merge pull request #1920 from CliMA/ck/rm_non_cartesian_indexing_dl
Remove multiple integer indexing in DataLayouts
2 parents f8b7ad4 + 87d4fdb commit ee2b83e

32 files changed

+529
-516
lines changed

ext/cuda/limiters.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ClimaCore.Limiters:
55
apply_limiter!
66
import ClimaCore.Fields
77
import ClimaCore: DataLayouts, Spaces, Topologies, Fields
8+
import ClimaCore.DataLayouts: slab_index
89
using CUDA
910

1011
function config_threadblock(Nv, Nh)
@@ -63,7 +64,7 @@ function compute_element_bounds_kernel!(
6364
slab_ρ = slab(ρ, v, h)
6465
for j in 1:Nj
6566
for i in 1:Ni
66-
q = rdiv(slab_ρq[i, j], slab_ρ[i, j])
67+
q = rdiv(slab_ρq[slab_index(i, j)], slab_ρ[slab_index(i, j)])
6768
if i == 1 && j == 1
6869
q_min = q
6970
q_max = q
@@ -74,8 +75,8 @@ function compute_element_bounds_kernel!(
7475
end
7576
end
7677
slab_q_bounds = slab(q_bounds, v, h)
77-
slab_q_bounds[1] = q_min
78-
slab_q_bounds[2] = q_max
78+
slab_q_bounds[slab_index(1)] = q_min
79+
slab_q_bounds[slab_index(2)] = q_max
7980
end
8081
return nothing
8182
end
@@ -123,18 +124,18 @@ function compute_neighbor_bounds_local_kernel!(
123124
(v, h) = kernel_indexes(tidx, n).I
124125
(; q_bounds, q_bounds_nbr, ghost_buffer, rtol) = limiter
125126
slab_q_bounds = slab(q_bounds, v, h)
126-
q_min = slab_q_bounds[1]
127-
q_max = slab_q_bounds[2]
127+
q_min = slab_q_bounds[slab_index(1)]
128+
q_max = slab_q_bounds[slab_index(2)]
128129
for lne in
129130
local_neighbor_elem_offset[h]:(local_neighbor_elem_offset[h + 1] - 1)
130131
h_nbr = local_neighbor_elem[lne]
131132
slab_q_bounds = slab(q_bounds, v, h_nbr)
132-
q_min = rmin(q_min, slab_q_bounds[1])
133-
q_max = rmax(q_max, slab_q_bounds[2])
133+
q_min = rmin(q_min, slab_q_bounds[slab_index(1)])
134+
q_max = rmax(q_max, slab_q_bounds[slab_index(2)])
134135
end
135136
slab_q_bounds_nbr = slab(q_bounds_nbr, v, h)
136-
slab_q_bounds_nbr[1] = q_min
137-
slab_q_bounds_nbr[2] = q_max
137+
slab_q_bounds_nbr[slab_index(1)] = q_min
138+
slab_q_bounds_nbr[slab_index(2)] = q_max
138139
end
139140
return nothing
140141
end

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import ClimaCore.Fields
77
import ClimaCore.Spaces
88
import ClimaCore.Topologies
99
import ClimaCore.MatrixFields
10+
import ClimaCore.DataLayouts: vindex
1011
import ClimaCore.MatrixFields: single_field_solve!
1112
import ClimaCore.MatrixFields: _single_field_solve!
1213
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
@@ -71,7 +72,7 @@ function _single_field_solve!(
7172
b_data = Fields.field_values(b)
7273
Nv = DataLayouts.nlevels(x_data)
7374
@inbounds for v in 1:Nv
74-
x_data[v] = inv(A.λ) b_data[v]
75+
x_data[vindex(v)] = inv(A.λ) b_data[vindex(v)]
7576
end
7677
end
7778

@@ -98,6 +99,7 @@ function band_matrix_solve_local_mem!(
9899
Nv = DataLayouts.nlevels(x)
99100
Ux, U₊₁ = cache
100101
A₋₁, A₀, A₊₁ = Aⱼs
102+
vi = vindex
101103

102104
Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
103105
U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
@@ -107,16 +109,16 @@ function band_matrix_solve_local_mem!(
107109
A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
108110
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
109111
@inbounds for v in 1:Nv
110-
A₋₁_local[v] = A₋₁[v]
111-
A₀_local[v] = A₀[v]
112-
A₊₁_local[v] = A₊₁[v]
113-
b_local[v] = b[v]
112+
A₋₁_local[v] = A₋₁[vi(v)]
113+
A₀_local[v] = A₀[vi(v)]
114+
A₊₁_local[v] = A₊₁[vi(v)]
115+
b_local[v] = b[vi(v)]
114116
end
115117
cache_local = (Ux_local, U₊₁_local)
116-
Aⱼs_local = (A₋₁, A₀, A₊₁)
117-
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
118+
Aⱼs_local = (A₋₁_local, A₀_local, A₊₁_local)
119+
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local, identity)
118120
@inbounds for v in 1:Nv
119-
x[v] = x_local[v]
121+
x[vi(v)] = x_local[v]
120122
end
121123
return nothing
122124
end
@@ -128,6 +130,7 @@ function band_matrix_solve_local_mem!(
128130
Aⱼs,
129131
b,
130132
)
133+
vi = vindex
131134
Nv = DataLayouts.nlevels(x)
132135
Ux, U₊₁, U₊₂ = cache
133136
A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs
@@ -142,18 +145,18 @@ function band_matrix_solve_local_mem!(
142145
A₊₂_local = MArray{Tuple{Nv}, eltype(A₊₂)}(undef)
143146
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
144147
@inbounds for v in 1:Nv
145-
A₋₂_local[v] = A₋₂[v]
146-
A₋₁_local[v] = A₋₁[v]
147-
A₀_local[v] = A₀[v]
148-
A₊₁_local[v] = A₊₁[v]
149-
A₊₂_local[v] = A₊₂[v]
150-
b_local[v] = b[v]
148+
A₋₂_local[v] = A₋₂[vi(v)]
149+
A₋₁_local[v] = A₋₁[vi(v)]
150+
A₀_local[v] = A₀[vi(v)]
151+
A₊₁_local[v] = A₊₁[vi(v)]
152+
A₊₂_local[v] = A₊₂[vi(v)]
153+
b_local[v] = b[vi(v)]
151154
end
152155
cache_local = (Ux_local, U₊₁_local, U₊₂_local)
153-
Aⱼs_local = (A₋₂, A₋₁, A₀, A₊₁, A₊₂)
154-
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
156+
Aⱼs_local = (A₋₂_local, A₋₁_local, A₀_local, A₊₁_local, A₊₂_local)
157+
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local, identity)
155158
@inbounds for v in 1:Nv
156-
x[v] = x_local[v]
159+
x[vi(v)] = x_local[v]
157160
end
158161
return nothing
159162
end
@@ -168,7 +171,7 @@ function band_matrix_solve_local_mem!(
168171
Nv = DataLayouts.nlevels(x)
169172
(A₀,) = Aⱼs
170173
@inbounds for v in 1:Nv
171-
x[v] = inv(A₀[v]) b[v]
174+
x[vindex(v)] = inv(A₀[vindex(v)]) b[vindex(v)]
172175
end
173176
return nothing
174177
end

ext/cuda/remapping_distributed.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function set_interpolated_values_kernel!(
5959
totalThreadsZ = gridDim().z * blockDim().z
6060

6161
_, Nq = size(I1)
62-
62+
CI = CartesianIndex
6363
for i in hindex:totalThreadsX:num_horiz
6464
h = local_horiz_indices[i]
6565
for j in vindex:totalThreadsY:num_vert
@@ -73,8 +73,8 @@ function set_interpolated_values_kernel!(
7373
I1[i, t] *
7474
I2[i, s] *
7575
(
76-
A * field_values[k][t, s, nothing, v_lo, h] +
77-
B * field_values[k][t, s, nothing, v_hi, h]
76+
A * field_values[k][CI(t, s, 1, v_lo, h)] +
77+
B * field_values[k][CI(t, s, 1, v_hi, h)]
7878
)
7979
end
8080
end
@@ -107,7 +107,7 @@ function set_interpolated_values_kernel!(
107107
totalThreadsZ = gridDim().z * blockDim().z
108108

109109
_, Nq = size(I)
110-
110+
CI = CartesianIndex
111111
for i in hindex:totalThreadsX:num_horiz
112112
h = local_horiz_indices[i]
113113
for j in vindex:totalThreadsY:num_vert
@@ -121,10 +121,8 @@ function set_interpolated_values_kernel!(
121121
I[i, t] *
122122
I[i, s] *
123123
(
124-
A *
125-
field_values[k][t, nothing, nothing, v_lo, h] +
126-
B *
127-
field_values[k][t, nothing, nothing, v_hi, h]
124+
A * field_values[k][CI(t, 1, 1, v_lo, h)] +
125+
B * field_values[k][CI(t, 1, 1, v_hi, h)]
128126
)
129127
end
130128
end
@@ -199,7 +197,7 @@ function set_interpolated_values_kernel!(
199197
out[i, k] +=
200198
I1[i, t] *
201199
I2[i, s] *
202-
field_values[k][t, s, nothing, nothing, h]
200+
field_values[k][CartesianIndex(t, s, 1, 1, h)]
203201
end
204202
end
205203
end
@@ -232,8 +230,7 @@ function set_interpolated_values_kernel!(
232230
out[i, k] = 0
233231
for t in 1:Nq, s in 1:Nq
234232
out[i, k] +=
235-
I[i, i] *
236-
field_values[k][t, nothing, nothing, nothing, h]
233+
I[i, i] * field_values[k][CartesianIndex(t, 1, 1, 1, h)]
237234
end
238235
end
239236
end

lib/ClimaCorePlots/src/ClimaCorePlots.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import RecipesBase
44
import TriplotBase
55

66
import ClimaComms
7+
import ClimaCore.DataLayouts: slab_index
78
import ClimaCore:
89
ClimaCore,
910
DataLayouts,
@@ -308,7 +309,7 @@ function _slice_along(field, coord)
308309
hdata = ClimaCore.slab(hcoord_data, hidx)
309310
hnode_idx = 1
310311
for i in axes(hdata)[axis]
311-
pt = axis == 1 ? hdata[i, 1] : hdata[1, i]
312+
pt = axis == 1 ? hdata[slab_index(i, 1)] : hdata[slab_index(1, i)]
312313
axis_value = Geometry.component(pt, axis)
313314
coord_value = Geometry.component(coord, 1)
314315
if axis_value > coord_value
@@ -353,8 +354,9 @@ function _slice_along(field, coord)
353354
islab = ClimaCore.slab(ortho_data, v, i)
354355
# copy the nodal data
355356
for ni in 1:size(islab)[1]
356-
islab[ni] =
357-
axis == 1 ? ijslab[hnode_idx, ni] : ijslab[ni, hnode_idx]
357+
islab[slab_index(ni)] =
358+
axis == 1 ? ijslab[slab_index(hnode_idx, ni)] :
359+
ijslab[slab_index(ni, hnode_idx)]
358360
end
359361
end
360362
end

lib/ClimaCoreTempestRemap/src/netcdf.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import CommonDataModel
22
import ClimaCore: slab, column
3+
import ClimaCore.DataLayouts: slab_index
34

45
"""
56
def_time_coord(nc::NCDataset, length=Inf, eltype=Float64;
@@ -97,7 +98,7 @@ function def_space_coord(
9798
coords = Spaces.coordinates_data(space)
9899

99100
for (col, ((i, j), e)) in enumerate(nodes)
100-
coord = slab(coords, e)[i, j]
101+
coord = slab(coords, e)[slab_index(i, j)]
101102
X[col] = coord.x
102103
Y[col] = coord.y
103104
end
@@ -149,7 +150,7 @@ function def_space_coord(
149150
coords = Spaces.coordinates_data(space)
150151

151152
for (col, ((i, j), e)) in enumerate(nodes)
152-
coord = slab(coords, e)[i, j]
153+
coord = slab(coords, e)[slab_index(i, j)]
153154
lon[col] = coord.long
154155
lat[col] = coord.lat
155156
end
@@ -328,7 +329,7 @@ function Base.setindex!(
328329
end
329330
data = Fields.field_values(field)
330331
for (col, ((i, j), e)) in enumerate(nodes)
331-
var[col, extraidx...] = slab(data, e)[i, j]
332+
var[col, extraidx...] = slab(data, e)[slab_index(i, j)]
332333
end
333334
return var
334335
end

0 commit comments

Comments
 (0)