@@ -795,30 +795,21 @@ function _collect_and_return_interpolated_values!(
795
795
remapper:: Remapper ,
796
796
num_fields:: Int ,
797
797
)
798
- output_array = ClimaComms. reduce (
798
+ return ClimaComms. reduce (
799
799
remapper. comms_ctx,
800
800
remapper. _interpolated_values[remapper. colons... , 1 : num_fields],
801
801
+ ,
802
802
)
803
-
804
- maybe_copy_to_cpu =
805
- ClimaComms. device (remapper. comms_ctx) isa ClimaComms. CUDADevice ?
806
- Array : identity
807
-
808
- return ClimaComms. iamroot (remapper. comms_ctx) ?
809
- maybe_copy_to_cpu (output_array) : nothing
810
803
end
811
804
812
805
function _collect_interpolated_values! (
813
806
dest,
814
807
remapper:: Remapper ,
815
808
index_field_begin:: Int ,
816
- index_field_end:: Int ,
809
+ index_field_end:: Int ;
810
+ only_one_field,
817
811
)
818
812
819
- num_fields = 1 + index_field_end - index_field_begin
820
- only_one_field = num_fields == 1
821
-
822
813
if only_one_field
823
814
ClimaComms. reduce! (
824
815
remapper. comms_ctx,
@@ -829,18 +820,12 @@ function _collect_interpolated_values!(
829
820
return nothing
830
821
end
831
822
832
- # CUDA.jl does not support views very well at the moment. We can only work with
833
- # num_fields = buffer_length
834
- num_fields == remapper. buffer_length ||
835
- error (" Operation not currently supported" )
823
+ num_fields = 1 + index_field_end - index_field_begin
836
824
837
- # MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular,
838
- # if the destination is on the CPU, but the source is on the GPU, the values
839
- # are automatically moved.
840
825
ClimaComms. reduce! (
841
826
remapper. comms_ctx,
842
- remapper. _interpolated_values,
843
- dest,
827
+ view ( remapper. _interpolated_values, remapper . colons ... , 1 : num_fields) ,
828
+ view ( dest, remapper . colons ... , index_field_begin : index_field_end) ,
844
829
+ ,
845
830
)
846
831
@@ -882,6 +867,9 @@ to be defined on the root process and to be `nothing` for the other processes.
882
867
Note: `interpolate` allocates new arrays and has some internal type-instability;
883
868
`interpolate!` is non-allocating and type-stable.
884
869
870
+ When using `interpolate!`, `dest` must be an array of the type associated with the
871
+ device in use (e.g., a `CuArray` for CUDA runs).
872
+
885
873
Example
886
874
========
887
875
@@ -975,6 +963,14 @@ function interpolate!(
975
963
dest_size == size (remapper. _interpolated_values)[1 : (end - 1 )] || error (
976
964
" Destination array is not compatible with remapper (size mismatch)" ,
977
965
)
966
+
967
+ expected_array_type =
968
+ ClimaComms. array_type (ClimaComms. device (remapper. comms_ctx))
969
+
970
+ found_type = nameof (typeof (dest))
971
+
972
+ dest isa expected_array_type ||
973
+ error (" dest is a $found_type , expected $expected_array_type " )
978
974
end
979
975
index_field_begin, index_field_end =
980
976
1 , min (length (fields), remapper. buffer_length)
@@ -999,7 +995,8 @@ function interpolate!(
999
995
dest,
1000
996
remapper,
1001
997
index_field_begin,
1002
- index_field_end,
998
+ index_field_end;
999
+ only_one_field,
1003
1000
)
1004
1001
1005
1002
index_field_end != length (fields) || break
0 commit comments