Skip to content

Commit 778b91f

Browse files
Define CuArray on FieldArrays
1 parent 960d42d commit 778b91f

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

ext/cuda/data_layouts.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,17 @@ function Adapt.adapt_structure(
7878
Adapt.adapt(to, bc.axes),
7979
)
8080
end
81+
82+
import ClimaCore.DataLayouts as DL
83+
import CUDA
84+
function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
85+
arrays = ntuple(Val(DL.ncomponents(fa))) do f
86+
CUDA.CuArray(fa.arrays[f])
87+
end
88+
return DL.FieldArray{FD}(arrays)
89+
end
90+
91+
DL.field_array(
92+
array::CUDA.CuArray,
93+
as::ArraySize
94+
) = CUDA.CuArray(DL.field_array(Array(array), as))

ext/cuda/data_layouts_copyto.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ function cuda_copyto!(dest::AbstractData, bc)
6969
auto_launch!(
7070
knl_copyto_linear!,
7171
(dest, bc′, us),
72-
nitems;
72+
n;
7373
auto = true,
7474
)
7575
else
76-
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
76+
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
7777
end
7878
end
7979
return dest

src/DataLayouts/DataLayouts.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,14 +401,19 @@ end
401401

402402
Base.length(data::IJFH) = get_Nh(data)
403403

404-
Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)
405-
406404
@inline function slab(data::IJFH{S, Nij}, v::Integer, h::Integer) where {S, Nij}
407405
@boundscheck (v >= 1 && 1 <= h <= get_Nh(data)) ||
408406
throw(BoundsError(data, (v, h)))
409-
slab(data, h)
407+
fa = field_array(data)
408+
sub_arrays = ntuple(Val(ncomponents(fa))) do jf
409+
view(fa.arrays[jf], :, :, h)
410+
end
411+
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
412+
IJF{S, Nij, typeof(dataview)}(dataview)
410413
end
411414

415+
Base.@propagate_inbounds slab(data::IJFH, h::Integer) = slab(data, 1, h)
416+
412417
@inline function column(data::IJFH{S, Nij}, i, j, h) where {S, Nij}
413418
@boundscheck (1 <= j <= Nij && 1 <= i <= Nij && 1 <= h <= get_Nh(data)) ||
414419
throw(BoundsError(data, (i, j, h)))

0 commit comments

Comments
 (0)