Skip to content

Commit 3d0c9ca

Browse files
Avoid gpu-incompatible remapper logic
1 parent 9d13496 commit 3d0c9ca

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/Remapping/Remapping.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import ..DataLayouts,
1616
..Hypsography
1717
import ClimaCore.Utilities: half
1818
import ClimaCore.Spaces: cuda_synchronize
19+
import ..to_cpu
1920

2021
using ..RecursiveApply
2122

src/Remapping/distributed_remapping.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,15 @@ function _Remapper(
399399
)
400400
num_dims = num_hdims
401401
else
402+
cpu_space = if ClimaComms.device(space) isa ClimaComms.AbstractCPUDevice
403+
space
404+
else
405+
to_cpu(space)
406+
end
402407
vert_interpolation_weights =
403-
ArrayType(vertical_interpolation_weights(space, target_zcoords))
408+
ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords))
404409
vert_bounding_indices =
405-
ArrayType(vertical_bounding_indices(space, target_zcoords))
410+
ArrayType(vertical_bounding_indices(cpu_space, target_zcoords))
406411

407412
# We have to add one extra dimension with respect to the bitmask/local_horiz_indices
408413
# because we are going to store the values for the columns
@@ -463,10 +468,16 @@ function _Remapper(
463468
FT = Spaces.undertype(space)
464469
ArrayType = ClimaComms.array_type(space)
465470

471+
cpu_space = if ClimaComms.device(space) isa ClimaComms.AbstractCPUDevice
472+
space
473+
else
474+
to_cpu(space)
475+
end
476+
466477
vert_interpolation_weights =
467-
ArrayType(vertical_interpolation_weights(space, target_zcoords))
478+
ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords))
468479
vert_bounding_indices =
469-
ArrayType(vertical_bounding_indices(space, target_zcoords))
480+
ArrayType(vertical_bounding_indices(cpu_space, target_zcoords))
470481

471482
local_interpolated_values =
472483
ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))

0 commit comments

Comments
 (0)