Skip to content

Commit 8dd4b91

Browse files
wip
1 parent 236262b commit 8dd4b91

33 files changed

+683
-444
lines changed

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import ClimaCore.RecursiveApply:
1818
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
1919
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
2020
import ClimaCore.DataLayouts: universal_size, UniversalSize
21+
import ClimaCore.DataLayouts: ArraySize
2122

2223
include(joinpath("cuda", "cuda_utils.jl"))
2324
include(joinpath("cuda", "data_layouts.jl"))

ext/cuda/data_layouts_copyto.jl

Lines changed: 33 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,23 @@ import ClimaCore.DataLayouts:
22
to_non_extruded_broadcasted, has_uniform_datalayouts
33
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
44

5-
function knl_copyto!(dest, src)
6-
7-
i = CUDA.threadIdx().x
8-
j = CUDA.threadIdx().y
9-
10-
h = CUDA.blockIdx().x
11-
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
12-
13-
if v <= size(dest, 4)
14-
I = CartesianIndex((i, j, 1, v, h))
15-
@inbounds dest[I] = src[I]
16-
end
17-
return nothing
18-
end
19-
20-
function knl_copyto_field_array!(dest, src, us)
21-
@inbounds begin
22-
tidx = thread_index()
23-
if tidx get_N(us)
24-
n = size(dest)
25-
I = kernel_indexes(tidx, n)
26-
dest[I] = src[I]
27-
end
28-
end
29-
return nothing
30-
end
31-
32-
function Base.copyto!(
33-
dest::IJFH{S, Nij, Nh},
34-
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
35-
::ToCUDA,
36-
) where {S, Nij, Nh}
37-
us = DataLayouts.UniversalSize(dest)
38-
if Nh > 0
39-
auto_launch!(
40-
knl_copyto_field_array!,
41-
(dest, bc, us),
42-
prod(DataLayouts.universal_size(us));
43-
auto = true,
44-
)
45-
end
46-
return dest
47-
end
5+
# function Base.copyto!(
6+
# dest::VIJFH{S, Nv, Nij, Nh},
7+
# bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
8+
# ::ToCUDA,
9+
# ) where {S, Nv, Nij, Nh}
10+
# if Nv > 0 && Nh > 0
11+
# us = DataLayouts.UniversalSize(dest)
12+
# n = prod(DataLayouts.universal_size(us))
13+
# if has_uniform_datalayouts(bc)
14+
# bc′ = to_non_extruded_broadcasted(bc)
15+
# auto_launch!(knl_copyto_linear!, (dest, bc′, us), n; auto = true)
16+
# else
17+
# auto_launch!(knl_copyto_cart!, (dest, bc, us), n; auto = true)
18+
# end
19+
# end
20+
# return dest
21+
# end
4822

4923
function knl_copyto_linear!(dest::AbstractData, bc, us)
5024
@inbounds begin
@@ -57,73 +31,30 @@ function knl_copyto_linear!(dest::AbstractData, bc, us)
5731
end
5832

5933
function knl_copyto_linear!(dest::DataF, bc, us)
34+
tidx = thread_index()
6035
@inbounds dest[] = bc[tidx]
6136
return nothing
6237
end
6338

64-
function knl_copyto_cart!(dest, src, us)
39+
function knl_copyto_flat!(dest::AbstractData, bc, us)
6540
@inbounds begin
6641
tidx = thread_index()
6742
if tidx get_N(us)
6843
n = size(dest)
6944
I = kernel_indexes(tidx, n)
70-
dest[I] = src[I]
45+
dest[I] = bc[I]
7146
end
7247
end
7348
return nothing
7449
end
7550

76-
function Base.copyto!(
77-
dest::VIJFH{S, Nv, Nij, Nh},
78-
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
79-
::ToCUDA,
80-
) where {S, Nv, Nij, Nh}
81-
if Nv > 0 && Nh > 0
82-
us = DataLayouts.UniversalSize(dest)
83-
n = prod(DataLayouts.universal_size(us))
84-
if has_uniform_datalayouts(bc)
85-
bc′ = to_non_extruded_broadcasted(bc)
86-
auto_launch!(knl_copyto_linear!, (dest, bc′, us), n; auto = true)
87-
else
88-
auto_launch!(knl_copyto_cart!, (dest, bc, us), n; auto = true)
89-
end
90-
end
91-
return dest
92-
end
93-
94-
function Base.copyto!(
95-
dest::VF{S, Nv},
96-
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
97-
::ToCUDA,
98-
) where {S, Nv}
99-
if Nv > 0
100-
auto_launch!(
101-
knl_copyto!,
102-
(dest, bc);
103-
threads_s = (1, 1),
104-
blocks_s = (1, Nv),
105-
)
106-
end
107-
return dest
108-
end
109-
110-
function Base.copyto!(
111-
dest::DataF{S},
112-
bc::DataLayouts.BroadcastedUnionDataF{S},
113-
::ToCUDA,
114-
) where {S}
115-
auto_launch!(knl_copyto!, (dest, bc); threads_s = (1, 1), blocks_s = (1, 1))
116-
return dest
117-
end
118-
119-
import ClimaCore.DataLayouts: isascalar
120-
function knl_copyto_flat!(dest::AbstractData, bc, us)
51+
function knl_copyto_flat!(dest::DataF, bc, us)
12152
@inbounds begin
12253
tidx = thread_index()
12354
if tidx get_N(us)
12455
n = size(dest)
12556
I = kernel_indexes(tidx, n)
126-
dest[I] = bc[I]
57+
dest[] = bc[I]
12758
end
12859
end
12960
return nothing
@@ -132,22 +63,27 @@ end
13263
function cuda_copyto!(dest::AbstractData, bc)
13364
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
13465
us = DataLayouts.UniversalSize(dest)
66+
n = prod(DataLayouts.universal_size(us))
13567
if Nv > 0 && Nh > 0
136-
nitems = prod(DataLayouts.universal_size(dest))
137-
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
68+
if has_uniform_datalayouts(bc)
69+
bc′ = to_non_extruded_broadcasted(bc)
70+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), nitems; auto = true)
71+
else
72+
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
73+
end
13874
end
13975
return dest
14076
end
14177

14278
# TODO: can we use CUDA's luanch configuration for all data layouts?
14379
# Currently, it seems to have a slight performance degradation.
14480
#! format: off
145-
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
81+
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
14682
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
14783
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
14884
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
14985
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
150-
# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
151-
# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
152-
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
86+
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
87+
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
88+
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
15389
#! format: on

ext/cuda/data_layouts_fill.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
function knl_fill_flat!(dest::AbstractData, val, us)
2-
# @inbounds begin
3-
# tidx = thread_index()
4-
# if tidx ≤ get_N(us)
5-
# n = size(dest)
6-
# I = kernel_indexes(tidx, n)
7-
# @inbounds dest[I] = val
8-
# end
9-
# end
2+
@inbounds begin
3+
tidx = thread_index()
4+
if tidx get_N(us)
5+
n = size(dest)
6+
I = kernel_indexes(tidx, n)
7+
@inbounds dest[I] = val
8+
end
9+
end
10+
return nothing
11+
end
12+
13+
function knl_fill_flat!(dest::DataF, val, us)
14+
@inbounds dest[] = val
1015
return nothing
1116
end
1217

src/DataLayouts/DataLayouts.jl

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,18 @@ function Base.show(io::IO, data::AbstractData)
129129
(rows, cols) = displaysize(io)
130130
println(io, summary(data))
131131
print(io, " "^indent_width)
132+
# @show similar(parent_array_type(data))
133+
# fa = map(x -> vec(x), field_arrays(data))
132134
print(
133135
IOContext(
134136
io,
135137
:compact => true,
136138
:limit => true,
137139
:displaysize => (rows, cols - indent_width),
138140
),
139-
map(x -> vec(x), field_arrays(data)),
141+
# collect(field_array(data)),
142+
parent(data),
143+
# map(x -> vec(x), field_arrays(data)),
140144
)
141145
return io
142146
end
@@ -619,10 +623,7 @@ function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
619623
array = FieldArray{field_dim(IJF)}(ntuple(f->MArray{Tuple{Nij, Nij}, T, 2, Nij * Nij}(undef), Nf))
620624
IJF{S, Nij}(array)
621625
end
622-
function SArray(ijf::IJF{S, Nij, FieldArray{FD, N, T}}) where {S, Nij, FD, N, T <: MArray}
623-
IJF{S, Nij}(SArray(field_array(ijf)))
624-
end
625-
function SArray(ijf::IJF{S, Nij, <:MArray}) where {S, Nij}
626+
function SArray(ijf::IJF{S, Nij, <:FieldArray}) where {S, Nij}
626627
IJF{S, Nij}(SArray(field_array(ijf)))
627628
end
628629

@@ -681,15 +682,15 @@ end
681682
function IF{S, Ni}(::Type{MArray}, ::Type{T}) where {S, Ni, T}
682683
Nf = typesize(T, S)
683684
# array = MArray{Tuple{Ni, Nf}, T, 2, Ni * Nf}(undef)
684-
array = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
685-
IF{S, Ni}(array)
685+
fa = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
686+
IF{S, Ni}(fa)
686687
end
687-
function SArray(data::IF{S, Ni, <:FieldArray{<:Any, <:Any, T}}) where {S, Ni, T <: MArray}
688-
IF{S, Ni}(SArray(field_array(data)))
689-
end
690-
function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
688+
function SArray(data::IF{S, Ni, <:FieldArray}) where {S, Ni}
691689
IF{S, Ni}(SArray(field_array(data)))
692690
end
691+
# function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
692+
# IF{S, Ni}(SArray(field_array(data)))
693+
# end
693694

694695
@inline function column(data::IF{S, Ni}, i) where {S, Ni}
695696
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
@@ -816,14 +817,16 @@ Base.length(data::VIJFH) = get_Nv(data) * get_Nh(data)
816817
@boundscheck (1 <= v <= Nv && 1 <= h <= Nh) ||
817818
throw(BoundsError(data, (v, h)))
818819
Nf = ncomponents(data)
819-
dataview = @inbounds view(
820-
array,
821-
v,
822-
Base.Slice(Base.OneTo(Nij)),
823-
Base.Slice(Base.OneTo(Nij)),
824-
Base.Slice(Base.OneTo(Nf)),
825-
h,
826-
)
820+
sub_arrays = @inbounds ntuple(Nf) do f
821+
view(
822+
array.arrays[f],
823+
v,
824+
Base.Slice(Base.OneTo(Nij)),
825+
Base.Slice(Base.OneTo(Nij)),
826+
h,
827+
)
828+
end
829+
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
827830
IJF{S, Nij}(dataview)
828831
end
829832

@@ -1113,11 +1116,15 @@ type parameters.
11131116
@inline field_dim(::Type{<:VIJFH}) = 4
11141117
@inline field_dim(::Type{<:VIFH}) = 3
11151118

1116-
@inline to_data_specific_field_array(::IJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[2], I.I[5])
1117-
@inline to_data_specific_field_array(::IFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[5])
1118-
@inline to_data_specific_field_array(::VIJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[2], I.I[5])
1119-
@inline to_data_specific_field_array(::VIFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[5])
1120-
@inline to_data_specific_field_array(::DataSlab1D, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[1], I.I[5])
1119+
@inline to_data_specific_field_array(data::AbstractData, I::CartesianIndex) =
1120+
CartesianIndex(_to_data_specific_field_array(data, I.I))
1121+
@inline _to_data_specific_field_array(::VF, I::Tuple) = (I[4],)
1122+
@inline _to_data_specific_field_array(::IF, I::Tuple) = (I[1],)
1123+
@inline _to_data_specific_field_array(::IJF, I::Tuple) = (I[1], I[2])
1124+
@inline _to_data_specific_field_array(::IJFH, I::Tuple) = (I[1], I[2], I[5])
1125+
@inline _to_data_specific_field_array(::IFH, I::Tuple) = (I[1], I[5])
1126+
@inline _to_data_specific_field_array(::VIJFH, I::Tuple) = (I[4], I[1], I[2], I[5])
1127+
@inline _to_data_specific_field_array(::VIFH, I::Tuple) = (I[4], I[1], I[5])
11211128

11221129
@inline to_data_specific(data::AbstractData, I::CartesianIndex) =
11231130
CartesianIndex(_to_data_specific(data, I.I))
@@ -1349,7 +1356,7 @@ field_array(data::AbstractData{S}) where {S} = parent(data)
13491356
parent(data),
13501357
eltype(data),
13511358
Val(field_dim(data)),
1352-
to_data_specific(data, I),
1359+
to_data_specific_field_array(data, I),
13531360
)
13541361
end
13551362

@@ -1363,7 +1370,7 @@ end
13631370
parent(data),
13641371
convert(eltype(data), val),
13651372
Val(field_dim(data)),
1366-
to_data_specific(data, I),
1373+
to_data_specific_field_array(data, I),
13671374
)
13681375
end
13691376

0 commit comments

Comments
 (0)