Skip to content

Commit a7eb55e

Browse files
committed
Move field index to last
1 parent 38ffafd commit a7eb55e

File tree

2 files changed

+130
-123
lines changed

2 files changed

+130
-123
lines changed

src/Remapping/distributed_remapping.jl

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,13 @@ struct Remapper{
181181

182182
# Scratch space where we save the process-local field value. We keep overwriting this to
183183
# avoid extra allocations. Ideally, we wouldn't need this and we would use views for
184-
# everything. This has dimensions (buffer_length, Nq) or (buffer_length, Nq, Nq)
184+
# everything. This has dimensions (Nq, ) or (Nq, Nq, )
185185
# depending if the horizontal space is 1D or 2D.
186186
_field_values::T9
187187

188188
# Storage area where the interpolated values are saved. This is meaningful only for the
189189
# root process and gets filled by a interpolate call. This has dimensions
190-
# (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords.
190+
# (H, V, buffer_length), where H is the size of target_hcoords and V of target_zcoords.
191191
# In other words, this is the expected output array.
192192
_interpolated_values::T10
193193

@@ -308,9 +308,9 @@ function Remapper(
308308
vert_interpolation_weights = nothing
309309
vert_bounding_indices = nothing
310310
local_interpolated_values =
311-
ArrayType(zeros(FT, (buffer_length, size(local_horiz_indices)...)))
311+
ArrayType(zeros(FT, (size(local_horiz_indices)..., buffer_length)))
312312
interpolated_values = ArrayType(
313-
zeros(FT, (buffer_length, size(local_target_hcoords_bitmask)...)),
313+
zeros(FT, (size(local_target_hcoords_bitmask)..., buffer_length)),
314314
)
315315
num_dims = num_hdims
316316
else
@@ -325,19 +325,19 @@ function Remapper(
325325
zeros(
326326
FT,
327327
(
328-
buffer_length,
329328
size(local_horiz_indices)...,
330329
length(target_zcoords),
330+
buffer_length,
331331
),
332332
),
333333
)
334334
interpolated_values = ArrayType(
335335
zeros(
336336
FT,
337337
(
338-
buffer_length,
339338
size(local_target_hcoords_bitmask)...,
340339
length(target_zcoords),
340+
buffer_length,
341341
),
342342
),
343343
)
@@ -433,7 +433,7 @@ function set_interpolated_values_cpu_kernel!(
433433
I2[out_index, j] *
434434
scratch_field_values[i, j]
435435
end
436-
out[field_index, out_index, vindex] = tmp
436+
out[out_index, vindex, field_index] = tmp
437437
end
438438
end
439439
end
@@ -470,9 +470,9 @@ function set_interpolated_values_kernel!(
470470
A, B = vert_interpolation_weights[j]
471471
for k in findex:totalThreadsZ:num_fields
472472
if i num_horiz && j num_vert && k num_fields
473-
out[k, i, j] = 0
473+
out[i, j, k] = 0
474474
for t in 1:Nq, s in 1:Nq
475-
out[k, i, j] +=
475+
out[i, j, k] +=
476476
I1[i, t] *
477477
I2[i, s] *
478478
(
@@ -509,7 +509,7 @@ function set_interpolated_values_kernel!(
509509
totalThreadsY = gridDim().y * blockDim().y
510510
totalThreadsZ = gridDim().z * blockDim().z
511511

512-
_, Nq = size(I1)
512+
_, Nq = size(I)
513513

514514
for i in hindex:totalThreadsX:num_horiz
515515
h = local_horiz_indices[i]
@@ -518,11 +518,11 @@ function set_interpolated_values_kernel!(
518518
A, B = vert_interpolation_weights[j]
519519
for k in findex:totalThreadsZ:num_fields
520520
if i num_horiz && j num_vert && k num_fields
521-
out[k, i, j] = 0
521+
out[i, j, k] = 0
522522
for t in 1:Nq
523-
out[k, i, j] +=
524-
I1[i, t] *
525-
I2[i, s] *
523+
out[i, j, k] +=
524+
I[i, t] *
525+
I[i, s] *
526526
(
527527
A *
528528
field_values[k][t, nothing, nothing, v_lo, h] +
@@ -577,7 +577,7 @@ function set_interpolated_values_cpu_kernel!(
577577
for i in 1:Nq
578578
tmp += I[out_index, i] * scratch_field_values[i]
579579
end
580-
out[field_index, out_index, vindex] = tmp
580+
out[out_index, vindex, field_index] = tmp
581581
end
582582
end
583583
end
@@ -652,17 +652,17 @@ function _set_interpolated_values!(
652652
for (field_index, field) in enumerate(fields)
653653
field_values = Fields.field_values(field)
654654
for (out_index, h) in enumerate(local_horiz_indices)
655-
out[field_index, out_index] = zero(FT)
655+
out[out_index, field_index] = zero(FT)
656656
if hdims == 2
657657
for j in 1:Nq, i in 1:Nq
658-
out[field_index, out_index] +=
658+
out[out_index, field_index] +=
659659
local_horiz_interpolation_weights[1][out_index, i] *
660660
local_horiz_interpolation_weights[2][out_index, j] *
661661
field_values[i, j, nothing, nothing, h]
662662
end
663663
elseif hdims == 1
664664
for i in 1:Nq
665-
out[field_index, out_index] +=
665+
out[out_index, field_index] +=
666666
local_horiz_interpolation_weights[1][out_index, i] *
667667
field_values[i, nothing, nothing, nothing, h]
668668
end
@@ -694,9 +694,9 @@ function set_interpolated_values_kernel!(
694694
h = local_horiz_indices[i]
695695
for k in findex:totalThreadsZ:num_fields
696696
if i num_horiz && k num_fields
697-
out[k, i] = 0
697+
out[i, k] = 0
698698
for t in 1:Nq, s in 1:Nq
699-
out[k, i] +=
699+
out[i, k] +=
700700
I1[i, t] *
701701
I2[i, s] *
702702
field_values[k][t, s, nothing, nothing, h]
@@ -729,9 +729,9 @@ function set_interpolated_values_kernel!(
729729
h = local_horiz_indices[i]
730730
for k in findex:totalThreadsZ:num_fields
731731
if i num_horiz && k num_fields
732-
out[k, i] = 0
732+
out[i, k] = 0
733733
for t in 1:Nq, s in 1:Nq
734-
out[k, i] +=
734+
out[i, k] +=
735735
I[i, i] *
736736
field_values[k][t, nothing, nothing, nothing, h]
737737
end
@@ -753,16 +753,20 @@ around according to MPI-ownership and the expected output shape.
753753
`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved.
754754
"""
755755
function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int)
756-
view(
757-
remapper._interpolated_values,
758-
1:num_fields,
759-
remapper.local_target_hcoords_bitmask,
760-
:,
761-
) .= view(
762-
remapper._local_interpolated_values,
763-
1:num_fields,
764-
remapper.colons...,
765-
)
756+
if isnothing(remapper.target_zcoords)
757+
view(
758+
remapper._interpolated_values,
759+
remapper.local_target_hcoords_bitmask,
760+
1:num_fields,
761+
) .= view(remapper._local_interpolated_values, :, 1:num_fields)
762+
else
763+
view(
764+
remapper._interpolated_values,
765+
remapper.local_target_hcoords_bitmask,
766+
:,
767+
1:num_fields,
768+
) .= view(remapper._local_interpolated_values, :, :, 1:num_fields)
769+
end
766770
end
767771

768772
"""
@@ -793,7 +797,7 @@ function _collect_and_return_interpolated_values!(
793797
)
794798
output_array = ClimaComms.reduce(
795799
remapper.comms_ctx,
796-
remapper._interpolated_values[1:num_fields, remapper.colons...],
800+
remapper._interpolated_values[remapper.colons..., 1:num_fields],
797801
+,
798802
)
799803

@@ -818,7 +822,7 @@ function _collect_interpolated_values!(
818822
if only_one_field
819823
ClimaComms.reduce!(
820824
remapper.comms_ctx,
821-
remapper._interpolated_values[1, remapper.colons...],
825+
remapper._interpolated_values[remapper.colons..., begin],
822826
dest,
823827
+,
824828
)
@@ -922,7 +926,9 @@ function interpolate(remapper::Remapper, fields)
922926
# buffer_length
923927
index_ranges = batched_ranges(length(fields), remapper.buffer_length)
924928

925-
interpolated_values = mapreduce(vcat, index_ranges) do range
929+
cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)
930+
931+
interpolated_values = mapreduce(cat_fn, index_ranges) do range
926932
num_fields = length(range)
927933

928934
# Reset interpolated_values. This is needed because we collect distributed results
@@ -938,13 +944,14 @@ function interpolate(remapper::Remapper, fields)
938944
# Finally, we have to send all the _interpolated_values to root and sum them up to
939945
# obtain the final answer. Only the root will contain something useful. This also
940946
# moves the data off the GPU
941-
return _collect_and_return_interpolated_values!(remapper, num_fields)
947+
ret = _collect_and_return_interpolated_values!(remapper, num_fields)
948+
return ret
942949
end
943950

944951
# Non-root processes
945952
isnothing(interpolated_values) && return nothing
946953

947-
return only_one_field ? interpolated_values[begin, remapper.colons...] :
954+
return only_one_field ? interpolated_values[remapper.colons..., begin] :
948955
interpolated_values
949956
end
950957

@@ -963,9 +970,9 @@ function interpolate!(
963970
if !isnothing(dest)
964971
# !isnothing(dest) means that this is the root process, in this case, the size have
965972
# to match (ignoring the buffer_length)
966-
dest_size = only_one_field ? size(dest) : size(dest)[2:end]
973+
dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]
967974

968-
dest_size == size(remapper._interpolated_values)[2:end] || error(
975+
dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
969976
"Destination array is not compatible with remapper (size mismatch)",
970977
)
971978
end

0 commit comments

Comments
 (0)