@@ -181,13 +181,13 @@ struct Remapper{
181
181
182
182
# Scratch space where we save the process-local field value. We keep overwriting this to
183
183
# avoid extra allocations. Ideally, we wouldn't need this and we would use views for
184
- # everything. This has dimensions (buffer_length, Nq ) or (buffer_length , Nq, Nq )
184
+ # everything. This has dimensions (Nq, ) or (Nq , Nq, )
185
185
# depending if the horizontal space is 1D or 2D.
186
186
_field_values:: T9
187
187
188
188
# Storage area where the interpolated values are saved. This is meaningful only for the
189
189
# root process and gets filled by a interpolate call. This has dimensions
190
- # (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords.
190
+ # (H, V, buffer_length ), where H is the size of target_hcoords and V of target_zcoords.
191
191
# In other words, this is the expected output array.
192
192
_interpolated_values:: T10
193
193
@@ -308,9 +308,9 @@ function Remapper(
308
308
vert_interpolation_weights = nothing
309
309
vert_bounding_indices = nothing
310
310
local_interpolated_values =
311
- ArrayType (zeros (FT, (buffer_length, size (local_horiz_indices)... )))
311
+ ArrayType (zeros (FT, (size (local_horiz_indices)... , buffer_length )))
312
312
interpolated_values = ArrayType (
313
- zeros (FT, (buffer_length, size (local_target_hcoords_bitmask)... )),
313
+ zeros (FT, (size (local_target_hcoords_bitmask)... , buffer_length )),
314
314
)
315
315
num_dims = num_hdims
316
316
else
@@ -325,19 +325,19 @@ function Remapper(
325
325
zeros (
326
326
FT,
327
327
(
328
- buffer_length,
329
328
size (local_horiz_indices)... ,
330
329
length (target_zcoords),
330
+ buffer_length,
331
331
),
332
332
),
333
333
)
334
334
interpolated_values = ArrayType (
335
335
zeros (
336
336
FT,
337
337
(
338
- buffer_length,
339
338
size (local_target_hcoords_bitmask)... ,
340
339
length (target_zcoords),
340
+ buffer_length,
341
341
),
342
342
),
343
343
)
@@ -433,7 +433,7 @@ function set_interpolated_values_cpu_kernel!(
433
433
I2[out_index, j] *
434
434
scratch_field_values[i, j]
435
435
end
436
- out[field_index, out_index, vindex] = tmp
436
+ out[out_index, vindex, field_index ] = tmp
437
437
end
438
438
end
439
439
end
@@ -470,9 +470,9 @@ function set_interpolated_values_kernel!(
470
470
A, B = vert_interpolation_weights[j]
471
471
for k in findex: totalThreadsZ: num_fields
472
472
if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
473
- out[k, i, j] = 0
473
+ out[i, j, k ] = 0
474
474
for t in 1 : Nq, s in 1 : Nq
475
- out[k, i, j] +=
475
+ out[i, j, k ] +=
476
476
I1[i, t] *
477
477
I2[i, s] *
478
478
(
@@ -509,7 +509,7 @@ function set_interpolated_values_kernel!(
509
509
totalThreadsY = gridDim (). y * blockDim (). y
510
510
totalThreadsZ = gridDim (). z * blockDim (). z
511
511
512
- _, Nq = size (I1 )
512
+ _, Nq = size (I )
513
513
514
514
for i in hindex: totalThreadsX: num_horiz
515
515
h = local_horiz_indices[i]
@@ -518,11 +518,11 @@ function set_interpolated_values_kernel!(
518
518
A, B = vert_interpolation_weights[j]
519
519
for k in findex: totalThreadsZ: num_fields
520
520
if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
521
- out[k, i, j] = 0
521
+ out[i, j, k ] = 0
522
522
for t in 1 : Nq
523
- out[k, i, j] +=
524
- I1 [i, t] *
525
- I2 [i, s] *
523
+ out[i, j, k ] +=
524
+ I [i, t] *
525
+ I [i, s] *
526
526
(
527
527
A *
528
528
field_values[k][t, nothing , nothing , v_lo, h] +
@@ -577,7 +577,7 @@ function set_interpolated_values_cpu_kernel!(
577
577
for i in 1 : Nq
578
578
tmp += I[out_index, i] * scratch_field_values[i]
579
579
end
580
- out[field_index, out_index, vindex] = tmp
580
+ out[out_index, vindex, field_index ] = tmp
581
581
end
582
582
end
583
583
end
@@ -652,17 +652,17 @@ function _set_interpolated_values!(
652
652
for (field_index, field) in enumerate (fields)
653
653
field_values = Fields. field_values (field)
654
654
for (out_index, h) in enumerate (local_horiz_indices)
655
- out[field_index, out_index ] = zero (FT)
655
+ out[out_index, field_index ] = zero (FT)
656
656
if hdims == 2
657
657
for j in 1 : Nq, i in 1 : Nq
658
- out[field_index, out_index ] +=
658
+ out[out_index, field_index ] +=
659
659
local_horiz_interpolation_weights[1 ][out_index, i] *
660
660
local_horiz_interpolation_weights[2 ][out_index, j] *
661
661
field_values[i, j, nothing , nothing , h]
662
662
end
663
663
elseif hdims == 1
664
664
for i in 1 : Nq
665
- out[field_index, out_index ] +=
665
+ out[out_index, field_index ] +=
666
666
local_horiz_interpolation_weights[1 ][out_index, i] *
667
667
field_values[i, nothing , nothing , nothing , h]
668
668
end
@@ -694,9 +694,9 @@ function set_interpolated_values_kernel!(
694
694
h = local_horiz_indices[i]
695
695
for k in findex: totalThreadsZ: num_fields
696
696
if i ≤ num_horiz && k ≤ num_fields
697
- out[k, i ] = 0
697
+ out[i, k ] = 0
698
698
for t in 1 : Nq, s in 1 : Nq
699
- out[k, i ] +=
699
+ out[i, k ] +=
700
700
I1[i, t] *
701
701
I2[i, s] *
702
702
field_values[k][t, s, nothing , nothing , h]
@@ -729,9 +729,9 @@ function set_interpolated_values_kernel!(
729
729
h = local_horiz_indices[i]
730
730
for k in findex: totalThreadsZ: num_fields
731
731
if i ≤ num_horiz && k ≤ num_fields
732
- out[k, i ] = 0
732
+ out[i, k ] = 0
733
733
for t in 1 : Nq, s in 1 : Nq
734
- out[k, i ] +=
734
+ out[i, k ] +=
735
735
I[i, i] *
736
736
field_values[k][t, nothing , nothing , nothing , h]
737
737
end
@@ -753,16 +753,20 @@ around according to MPI-ownership and the expected output shape.
753
753
`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved.
754
754
"""
755
755
function _apply_mpi_bitmask! (remapper:: Remapper , num_fields:: Int )
756
- view (
757
- remapper. _interpolated_values,
758
- 1 : num_fields,
759
- remapper. local_target_hcoords_bitmask,
760
- :,
761
- ) .= view (
762
- remapper. _local_interpolated_values,
763
- 1 : num_fields,
764
- remapper. colons... ,
765
- )
756
+ if isnothing (remapper. target_zcoords)
757
+ view (
758
+ remapper. _interpolated_values,
759
+ remapper. local_target_hcoords_bitmask,
760
+ 1 : num_fields,
761
+ ) .= view (remapper. _local_interpolated_values, :, 1 : num_fields)
762
+ else
763
+ view (
764
+ remapper. _interpolated_values,
765
+ remapper. local_target_hcoords_bitmask,
766
+ :,
767
+ 1 : num_fields,
768
+ ) .= view (remapper. _local_interpolated_values, :, :, 1 : num_fields)
769
+ end
766
770
end
767
771
768
772
"""
@@ -793,7 +797,7 @@ function _collect_and_return_interpolated_values!(
793
797
)
794
798
output_array = ClimaComms. reduce (
795
799
remapper. comms_ctx,
796
- remapper. _interpolated_values[1 : num_fields, remapper. colons... ],
800
+ remapper. _interpolated_values[remapper. colons... , 1 : num_fields ],
797
801
+ ,
798
802
)
799
803
@@ -818,7 +822,7 @@ function _collect_interpolated_values!(
818
822
if only_one_field
819
823
ClimaComms. reduce! (
820
824
remapper. comms_ctx,
821
- remapper. _interpolated_values[1 , remapper. colons... ],
825
+ remapper. _interpolated_values[remapper. colons... , begin ],
822
826
dest,
823
827
+ ,
824
828
)
@@ -922,7 +926,9 @@ function interpolate(remapper::Remapper, fields)
922
926
# buffer_length
923
927
index_ranges = batched_ranges (length (fields), remapper. buffer_length)
924
928
925
- interpolated_values = mapreduce (vcat, index_ranges) do range
929
+ cat_fn = (l... ) -> cat (l... , dims = length (remapper. colons) + 1 )
930
+
931
+ interpolated_values = mapreduce (cat_fn, index_ranges) do range
926
932
num_fields = length (range)
927
933
928
934
# Reset interpolated_values. This is needed because we collect distributed results
@@ -938,13 +944,14 @@ function interpolate(remapper::Remapper, fields)
938
944
# Finally, we have to send all the _interpolated_values to root and sum them up to
939
945
# obtain the final answer. Only the root will contain something useful. This also
940
946
# moves the data off the GPU
941
- return _collect_and_return_interpolated_values! (remapper, num_fields)
947
+ ret = _collect_and_return_interpolated_values! (remapper, num_fields)
948
+ return ret
942
949
end
943
950
944
951
# Non-root processes
945
952
isnothing (interpolated_values) && return nothing
946
953
947
- return only_one_field ? interpolated_values[begin , remapper. colons... ] :
954
+ return only_one_field ? interpolated_values[remapper. colons... , begin ] :
948
955
interpolated_values
949
956
end
950
957
@@ -963,9 +970,9 @@ function interpolate!(
963
970
if ! isnothing (dest)
964
971
# !isnothing(dest) means that this is the root process, in this case, the size have
965
972
# to match (ignoring the buffer_length)
966
- dest_size = only_one_field ? size (dest) : size (dest)[2 : end ]
973
+ dest_size = only_one_field ? size (dest) : size (dest)[1 : ( end - 1 ) ]
967
974
968
- dest_size == size (remapper. _interpolated_values)[2 : end ] || error (
975
+ dest_size == size (remapper. _interpolated_values)[1 : ( end - 1 ) ] || error (
969
976
" Destination array is not compatible with remapper (size mismatch)" ,
970
977
)
971
978
end
0 commit comments