Skip to content

Commit b54dad5

Browse files
Define CuArray on FieldArrays
1 parent 2d8bb29 commit b54dad5

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

ext/cuda/data_layouts.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,17 @@ function Adapt.adapt_structure(
7777
Adapt.adapt(to, bc.axes),
7878
)
7979
end
80+
81+
import ClimaCore.DataLayouts as DL
82+
import CUDA
83+
function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
84+
arrays = ntuple(Val(DL.ncomponents(fa))) do f
85+
CUDA.CuArray(fa.arrays[f])
86+
end
87+
return DL.FieldArray{FD}(arrays)
88+
end
89+
90+
DL.field_array(
91+
array::CUDA.CuArray,
92+
as::ArraySize
93+
) = CUDA.CuArray(DL.field_array(Array(array), as))

ext/cuda/data_layouts_copyto.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ function cuda_copyto!(dest::AbstractData, bc)
7070
auto_launch!(
7171
knl_copyto_linear!,
7272
(dest, bc′, us),
73-
nitems;
73+
n;
7474
auto = true,
7575
)
7676
else
77-
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
77+
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
7878
end
7979
end
8080
return dest

src/DataLayouts/DataLayouts.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,14 +401,19 @@ end
401401

402402
Base.length(data::IJFH) = get_Nh(data)
403403

404-
Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)
405-
406404
@inline function slab(data::IJFH{S, Nij}, v::Integer, h::Integer) where {S, Nij}
407405
@boundscheck (v >= 1 && 1 <= h <= get_Nh(data)) ||
408406
throw(BoundsError(data, (v, h)))
409-
slab(data, h)
407+
fa = field_array(data)
408+
sub_arrays = ntuple(Val(ncomponents(fa))) do jf
409+
view(fa.arrays[jf], :, :, h)
410+
end
411+
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
412+
IJF{S, Nij, typeof(dataview)}(dataview)
410413
end
411414

415+
Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)
416+
412417
@inline function column(data::IJFH{S, Nij}, i, j, h) where {S, Nij}
413418
@boundscheck (1 <= j <= Nij && 1 <= i <= Nij && 1 <= h <= get_Nh(data)) ||
414419
throw(BoundsError(data, (i, j, h)))

0 commit comments

Comments
 (0)