Skip to content

Commit a27735e

Browse files
committed
Remove all restrictions
1 parent a7eb55e commit a27735e

File tree

3 files changed

+135
-187
lines changed

3 files changed

+135
-187
lines changed

NEWS.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7-
- ![][badge-🤖precisionΔ] `Remapper`s can now process multiple `Field`s at the same time if created with some `buffer_lenght > 1`.
8-
PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)) Machine-precision differences are expected.
7+
- ![][badge-🤖precisionΔ] ![][badge-💥breaking] `Remapper`s can now process
8+
multiple `Field`s at the same time if created with some `buffer_length > 1`.
9+
PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669))
10+
Machine-precision differences are expected. This change is breaking because
11+
remappers now return the same array type as the input field.
912

1013
v0.13.4
1114
-------

src/Remapping/distributed_remapping.jl

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -795,30 +795,21 @@ function _collect_and_return_interpolated_values!(
795795
remapper::Remapper,
796796
num_fields::Int,
797797
)
798-
output_array = ClimaComms.reduce(
798+
return ClimaComms.reduce(
799799
remapper.comms_ctx,
800800
remapper._interpolated_values[remapper.colons..., 1:num_fields],
801801
+,
802802
)
803-
804-
maybe_copy_to_cpu =
805-
ClimaComms.device(remapper.comms_ctx) isa ClimaComms.CUDADevice ?
806-
Array : identity
807-
808-
return ClimaComms.iamroot(remapper.comms_ctx) ?
809-
maybe_copy_to_cpu(output_array) : nothing
810803
end
811804

812805
function _collect_interpolated_values!(
813806
dest,
814807
remapper::Remapper,
815808
index_field_begin::Int,
816-
index_field_end::Int,
809+
index_field_end::Int;
810+
only_one_field,
817811
)
818812

819-
num_fields = 1 + index_field_end - index_field_begin
820-
only_one_field = num_fields == 1
821-
822813
if only_one_field
823814
ClimaComms.reduce!(
824815
remapper.comms_ctx,
@@ -829,18 +820,12 @@ function _collect_interpolated_values!(
829820
return nothing
830821
end
831822

832-
# CUDA.jl does not support views very well at the moment. We can only work with
833-
# num_fields = buffer_length
834-
num_fields == remapper.buffer_length ||
835-
error("Operation not currently supported")
823+
num_fields = 1 + index_field_end - index_field_begin
836824

837-
# MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular,
838-
# if the destination is on the CPU, but the source is on the GPU, the values
839-
# are automatically moved.
840825
ClimaComms.reduce!(
841826
remapper.comms_ctx,
842-
remapper._interpolated_values,
843-
dest,
827+
view(remapper._interpolated_values, remapper.colons..., 1:num_fields),
828+
view(dest, remapper.colons..., index_field_begin:index_field_end),
844829
+,
845830
)
846831

@@ -882,6 +867,9 @@ to be defined on the root process and to be `nothing` for the other processes.
882867
Note: `interpolate` allocates new arrays and has some internal type-instability,
883868
`interpolate!` is non-allocating and type-stable.
884869
870+
When using `interpolate!`, the `dest`ination array has to be of the array type
871+
associated with the device in use (e.g., `CuArray` for CUDA runs).
872+
885873
Example
886874
========
887875
@@ -975,6 +963,14 @@ function interpolate!(
975963
dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
976964
"Destination array is not compatible with remapper (size mismatch)",
977965
)
966+
967+
expected_array_type =
968+
ClimaComms.array_type(ClimaComms.device(remapper.comms_ctx))
969+
970+
found_type = nameof(typeof(dest))
971+
972+
dest isa expected_array_type ||
973+
error("dest is a $found_type, expected $expected_array_type")
978974
end
979975
index_field_begin, index_field_end =
980976
1, min(length(fields), remapper.buffer_length)
@@ -999,7 +995,8 @@ function interpolate!(
999995
dest,
1000996
remapper,
1001997
index_field_begin,
1002-
index_field_end,
998+
index_field_end;
999+
only_one_field,
10031000
)
10041001

10051002
index_field_end != length(fields) || break

0 commit comments

Comments
 (0)