@@ -399,10 +399,15 @@ function _Remapper(
399
399
)
400
400
num_dims = num_hdims
401
401
else
402
+ cpu_space = if ClimaComms. device (space) isa AbstractCPUDevice
403
+ space
404
+ else
405
+ to_cpu (space)
406
+ end
402
407
vert_interpolation_weights =
403
- ArrayType (vertical_interpolation_weights (space , target_zcoords))
408
+ ArrayType (vertical_interpolation_weights (cpu_space , target_zcoords))
404
409
vert_bounding_indices =
405
- ArrayType (vertical_bounding_indices (space , target_zcoords))
410
+ ArrayType (vertical_bounding_indices (cpu_space , target_zcoords))
406
411
407
412
# We have to add one extra dimension with respect to the bitmask/local_horiz_indices
408
413
# because we are going to store the values for the columns
@@ -463,10 +468,16 @@ function _Remapper(
463
468
FT = Spaces. undertype (space)
464
469
ArrayType = ClimaComms. array_type (space)
465
470
471
+ cpu_space = if ClimaComms. device (space) isa AbstractCPUDevice
472
+ space
473
+ else
474
+ to_cpu (space)
475
+ end
476
+
466
477
vert_interpolation_weights =
467
- ArrayType (vertical_interpolation_weights (space , target_zcoords))
478
+ ArrayType (vertical_interpolation_weights (cpu_space , target_zcoords))
468
479
vert_bounding_indices =
469
- ArrayType (vertical_bounding_indices (space , target_zcoords))
480
+ ArrayType (vertical_bounding_indices (cpu_space , target_zcoords))
470
481
471
482
local_interpolated_values =
472
483
ArrayType (zeros (FT, (length (target_zcoords), buffer_length)))
0 commit comments