Skip to content

Commit 985ef0b

Browse files
Add array2data and data2array functionality with transpose option.
1 parent cc9f859 commit 985ef0b

File tree

6 files changed

+228
-1
lines changed

6 files changed

+228
-1
lines changed

ext/cuda/data_layouts.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include("data_layouts_copyto.jl")
3232
include("data_layouts_fused_copyto.jl")
3333
include("data_layouts_mapreduce.jl")
3434
include("data_layouts_threadblock.jl")
35+
include("data_layouts_rrtmgp.jl")
3536

3637
# Adapt the callable `f` for the target `to` by delegating to `Adapt.adapt`.
function adapt_f(to, f::F) where {F}
    return Adapt.adapt(to, f)
end
3738
# For a type argument, ignore `to` and return a closure that forwards all of
# its arguments to `F`.
function adapt_f(to, ::Type{F}) where {F}
    return (args...) -> F(args...)
end

ext/cuda/data_layouts_rrtmgp.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import ClimaCore.DataLayouts: data2array_rrtmgp!, array2data_rrtmgp!
2+
3+
"""
    data2array_rrtmgp!(array::CUDA.CuArray, data::AbstractData, ::Val{trans})

Copy the single-field `data` into the device `array`, using one GPU thread per
column.

With `(Ni, Nj, Nk, Nv, Nh) = size(data)`, the column at `(i, j, k, h)` has the
linear index `c = i + (j - 1) * Ni + (k - 1) * Ni * Nj + (h - 1) * Ni * Nj * Nk`.
Level `v` of that column is written to `array[v, c]` when `trans == false` and
to `array[c, v]` when `trans == true`.

Throws a `DimensionMismatch` if `array` does not have the shape implied by
`data` and `trans`, and an `ArgumentError` if the parent array of `data` holds
more than `Ni * Nj * Nk * Nv * Nh` values (i.e. more than one field).
Returns `nothing`.
"""
function data2array_rrtmgp!(
    array::CUDA.CuArray,
    data::AbstractData,
    ::Val{trans},
) where {trans}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # The kernel indexes `array` under `@inbounds`, so the exact shape (not
    # only the element count) must be verified before launching.
    expected = trans ? (ncol, Nv) : (Nv, ncol)
    shape_ok =
        (size(array, 1), size(array, 2)) == expected &&
        length(array) == Nv * ncol
    if !shape_ok
        throw(
            DimensionMismatch(
                "array has size $(size(array)), but data of size " *
                "$(Base.size(data)) with trans = $trans requires $expected",
            ),
        )
    end
    if prod(size(parent(data))) != Nv * ncol # verify Nf == 1
        throw(ArgumentError("data must contain exactly one field (Nf == 1)"))
    end
    # Nothing to copy; skip the launch rather than request an empty grid.
    ncol == 0 && return nothing

    kernel = CUDA.@cuda launch = false data2array_rrtmgp_kernel!(
        array,
        data,
        Val(trans),
    )
    kernel_config = CUDA.launch_configuration(kernel.fun)
    nthreads, nblocks = linear_partition(ncol, kernel_config.threads)
    # Launch the kernel object compiled above instead of going through a
    # second `@cuda` invocation.
    kernel(array, data, Val(trans); threads = nthreads, blocks = nblocks)
    return nothing
end
27+
28+
# GPU kernel behind `data2array_rrtmgp!`: each thread copies all `Nv` levels of
# one column from `data` into `array`.
#
# The global thread index `gid` doubles as the linear column index; it is
# decomposed into `(i, j, k, h)` with `i` varying fastest. Threads whose `gid`
# exceeds the number of columns do nothing. `trans` selects the destination
# layout: `array[gid, v]` when true, `array[v, gid]` otherwise.
function data2array_rrtmgp_kernel!(
    array::AbstractArray,
    data::AbstractData,
    ::Val{trans},
) where {trans}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # obtain the column number processed by each thread
    gid = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
    if gid <= ncol
        h = cld(gid, Ni * Nj * Nk)
        idx = gid - (h - 1) * Ni * Nj * Nk
        k = cld(idx, Ni * Nj)
        idx = idx - (k - 1) * Ni * Nj
        j = cld(idx, Ni)
        i = idx - (j - 1) * Ni
        # Recombining (i, j, k, h) into a linear column index yields `gid`
        # again, so `gid` is used directly as the column index.
        @inbounds for v in 1:Nv
            cidx = CartesianIndex(i, j, k, v, h)
            if trans
                array[gid, v] = data[cidx]
            else
                array[v, gid] = data[cidx]
            end
        end
    end
    return nothing
end
59+
60+
"""
    array2data_rrtmgp!(data::AbstractData, array::CUDA.CuArray, ::Val{trans})

Copy the device `array` into the single-field `data`, using one GPU thread per
column.

With `(Ni, Nj, Nk, Nv, Nh) = size(data)`, the column at `(i, j, k, h)` has the
linear index `c = i + (j - 1) * Ni + (k - 1) * Ni * Nj + (h - 1) * Ni * Nj * Nk`.
Level `v` of that column is read from `array[v, c]` when `trans == false` and
from `array[c, v]` when `trans == true`.

Throws a `DimensionMismatch` if `array` does not have the shape implied by
`data` and `trans`, and an `ArgumentError` if the parent array of `data` holds
more than `Ni * Nj * Nk * Nv * Nh` values (i.e. more than one field).
Returns `nothing`.
"""
function array2data_rrtmgp!(
    data::AbstractData,
    array::CUDA.CuArray,
    ::Val{trans},
) where {trans}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # The kernel indexes `array` under `@inbounds`, so the exact shape (not
    # only the element count) must be verified before launching.
    expected = trans ? (ncol, Nv) : (Nv, ncol)
    shape_ok =
        (size(array, 1), size(array, 2)) == expected &&
        length(array) == Nv * ncol
    if !shape_ok
        throw(
            DimensionMismatch(
                "array has size $(size(array)), but data of size " *
                "$(Base.size(data)) with trans = $trans requires $expected",
            ),
        )
    end
    if prod(size(parent(data))) != Nv * ncol # verify Nf == 1
        throw(ArgumentError("data must contain exactly one field (Nf == 1)"))
    end
    # Nothing to copy; skip the launch rather than request an empty grid.
    ncol == 0 && return nothing

    kernel = CUDA.@cuda launch = false array2data_rrtmgp_kernel!(
        data,
        array,
        Val(trans),
    )
    kernel_config = CUDA.launch_configuration(kernel.fun)
    nthreads, nblocks = linear_partition(ncol, kernel_config.threads)
    # Launch the kernel object compiled above instead of going through a
    # second `@cuda` invocation.
    kernel(data, array, Val(trans); threads = nthreads, blocks = nblocks)
    return nothing
end
84+
85+
# GPU kernel behind `array2data_rrtmgp!`: each thread copies all `Nv` levels of
# one column from `array` into `data`.
#
# The global thread index `gid` doubles as the linear column index; it is
# decomposed into `(i, j, k, h)` with `i` varying fastest. Threads whose `gid`
# exceeds the number of columns do nothing. `trans` selects the source layout:
# `array[gid, v]` when true, `array[v, gid]` otherwise.
function array2data_rrtmgp_kernel!(
    data::AbstractData,
    array::AbstractArray,
    ::Val{trans},
) where {trans}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # obtain the column number processed by each thread
    gid = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
    if gid <= ncol
        h = cld(gid, Ni * Nj * Nk)
        idx = gid - (h - 1) * Ni * Nj * Nk
        k = cld(idx, Ni * Nj)
        idx = idx - (k - 1) * Ni * Nj
        j = cld(idx, Ni)
        i = idx - (j - 1) * Ni
        # Recombining (i, j, k, h) into a linear column index yields `gid`
        # again, so `gid` is used directly as the column index.
        @inbounds for v in 1:Nv
            data[CartesianIndex(i, j, k, v, h)] =
                trans ? array[gid, v] : array[v, gid]
        end
    end
    return nothing
end

src/DataLayouts/DataLayouts.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,6 +2211,52 @@ array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
22112211
reshape(array, array_size(data)...),
22122212
)
22132213

2214+
"""
    data2array_rrtmgp!(array::AbstractArray, data, ::Val{trans})

Copy the single-field `data` (a `VF`, `VIFH`, `VIHF`, `VIJFH` or `VIJHF`
layout) into `array`.

With `(Ni, Nj, Nk, Nv, Nh) = size(data)`, the column at `(i, j, k, h)` has the
linear index `c = i + (j - 1) * Ni + (k - 1) * Ni * Nj + (h - 1) * Ni * Nj * Nk`.
Level `v` of that column is written to `array[v, c]` when `trans == false` and
to `array[c, v]` when `trans == true`.

Throws a `DimensionMismatch` if `array` does not have the shape implied by
`data` and `trans`, and an `ArgumentError` if the parent array of `data` holds
more than `Ni * Nj * Nk * Nv * Nh` values (i.e. more than one field).
Returns `nothing`.
"""
function data2array_rrtmgp!(
    array::AbstractArray,
    data::Union{VF{S}, VIFH{S}, VIHF{S}, VIJFH{S}, VIJHF{S}},
    ::Val{trans},
) where {trans, S}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # The copy loop runs under `@inbounds`, so the exact shape (not only the
    # element count) must be verified first.
    expected = trans ? (ncol, Nv) : (Nv, ncol)
    shape_ok =
        (size(array, 1), size(array, 2)) == expected &&
        length(array) == Nv * ncol
    if !shape_ok
        throw(
            DimensionMismatch(
                "array has size $(size(array)), but data of size " *
                "$(Base.size(data)) with trans = $trans requires $expected",
            ),
        )
    end
    if prod(size(parent(data))) != Nv * ncol # verify Nf == 1
        throw(ArgumentError("data must contain exactly one field (Nf == 1)"))
    end

    # `CartesianIndices` iterates with `i` fastest, so the enumeration counter
    # is exactly the linear column index `c` described above.
    @inbounds for (colidx, col) in
                  enumerate(CartesianIndices((Ni, Nj, Nk, Nh)))
        i, j, k, h = Tuple(col)
        for v in 1:Nv
            value = data[CartesianIndex(i, j, k, v, h)]
            if trans
                array[colidx, v] = value
            else
                array[v, colidx] = value
            end
        end
    end
    return nothing
end
2237+
2238+
"""
    array2data_rrtmgp!(data, array::AbstractArray, ::Val{trans})

Copy `array` into the single-field `data` (a `VF`, `VIFH`, `VIHF`, `VIJFH` or
`VIJHF` layout).

With `(Ni, Nj, Nk, Nv, Nh) = size(data)`, the column at `(i, j, k, h)` has the
linear index `c = i + (j - 1) * Ni + (k - 1) * Ni * Nj + (h - 1) * Ni * Nj * Nk`.
Level `v` of that column is read from `array[v, c]` when `trans == false` and
from `array[c, v]` when `trans == true`.

Throws a `DimensionMismatch` if `array` does not have the shape implied by
`data` and `trans`, and an `ArgumentError` if the parent array of `data` holds
more than `Ni * Nj * Nk * Nv * Nh` values (i.e. more than one field).
Returns `nothing`.
"""
function array2data_rrtmgp!(
    data::Union{VF{S}, VIFH{S}, VIHF{S}, VIJFH{S}, VIJHF{S}},
    array::AbstractArray,
    ::Val{trans},
) where {trans, S}
    Ni, Nj, Nk, Nv, Nh = Base.size(data)
    ncol = Ni * Nj * Nk * Nh
    # The copy loop runs under `@inbounds`, so the exact shape (not only the
    # element count) must be verified first.
    expected = trans ? (ncol, Nv) : (Nv, ncol)
    shape_ok =
        (size(array, 1), size(array, 2)) == expected &&
        length(array) == Nv * ncol
    if !shape_ok
        throw(
            DimensionMismatch(
                "array has size $(size(array)), but data of size " *
                "$(Base.size(data)) with trans = $trans requires $expected",
            ),
        )
    end
    if prod(size(parent(data))) != Nv * ncol # verify Nf == 1
        throw(ArgumentError("data must contain exactly one field (Nf == 1)"))
    end

    # `CartesianIndices` iterates with `i` fastest, so the enumeration counter
    # is exactly the linear column index `c` described above.
    @inbounds for (colidx, col) in
                  enumerate(CartesianIndices((Ni, Nj, Nk, Nh)))
        i, j, k, h = Tuple(col)
        for v in 1:Nv
            data[CartesianIndex(i, j, k, v, h)] =
                trans ? array[colidx, v] : array[v, colidx]
        end
    end
    return nothing
end
2259+
22142260
"""
22152261
device_dispatch(array::AbstractArray)
22162262

test/DataLayouts/data2dx.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@ using Revise; include(joinpath("test", "DataLayouts", "data2dx.jl"))
55
using Test
66
using ClimaComms
77
using ClimaCore.DataLayouts
8-
import ClimaCore.DataLayouts: VF, IJFH, VIJFH, slab, column, slab_index, vindex
8+
import ClimaCore.DataLayouts:
9+
VF,
10+
IJFH,
11+
VIJFH,
12+
slab,
13+
column,
14+
slab_index,
15+
vindex,
16+
data2array_rrtmgp!,
17+
array2data_rrtmgp!
918

1019
device = ClimaComms.device()
1120
ArrayType = ClimaComms.array_type(device)

test/DataLayouts/datarrtmgp.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#=
2+
julia --project=test
3+
using Revise; include(joinpath("test", "DataLayouts", "datarrtmgp.jl"))
4+
=#
5+
using Test
6+
using ClimaComms
7+
ClimaComms.@import_required_backends
8+
using ClimaCore.DataLayouts
9+
import ClimaCore.DataLayouts:
10+
VF,
11+
IJFH,
12+
VIJFH,
13+
slab,
14+
column,
15+
slab_index,
16+
vindex,
17+
data2array_rrtmgp!,
18+
array2data_rrtmgp!
19+
20+
device = ClimaComms.device()
21+
ArrayType = ClimaComms.array_type(device)
22+
23+
# Round-trip checks for `data2array_rrtmgp!` / `array2data_rrtmgp!` on a
# single-field VIJFH layout, in both the plain and the transposed array layout.
@testset "VIJFH data2array_rrtmgp!" begin
    Nv = 10 # number of vertical levels
    Nij = 4 # Nij × Nij nodal points per element
    Nh = 10 # number of elements

    nl = Nv # number of levels/layers in the array
    ncol = Nij * Nij * Nh # number of columns in the array

    for FT in (Float32, Float64)
        # Parent array filled with 1, 2, 3, ... in memory order.
        data1 = VIJFH{FT, Nv, Nij}(
            ArrayType(
                reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, Nij, Nij, 1, Nh),
            ),
        )

        array1 = ArrayType{FT}(undef, nl, ncol)
        array1t = ArrayType{FT}(undef, ncol, nl)

        data2array_rrtmgp!(array1, data1, Val(false))
        data2array_rrtmgp!(array1t, data1, Val(true))

        # Both layouts must agree with each other ...
        @test array1 == ArrayType(transpose(array1t))
        # ... and with the expected values: this is the inverse of the
        # `array2data_rrtmgp!` round trip checked below, so the untransposed
        # array is the parent data reshaped to (levels, columns).
        @test Array(array1) == reshape(Array(parent(data1)), nl, ncol)

        array2 = ArrayType(reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, :))

        data2 = VIJFH{FT, Nv, Nij}(ArrayType{FT}(undef, Nv, Nij, Nij, 1, Nh))

        array2data_rrtmgp!(data2, array2, Val(false))
        @test parent(data2) == parent(data1)

        array2t = ArrayType(transpose(Array(array2)))

        # Poison the destination so a stale result cannot pass the next check.
        parent(data2) .= NaN
        array2data_rrtmgp!(data2, array2t, Val(true))
        @test parent(data2) == parent(data1)
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ UnitTest("DataLayouts 1D" ,"DataLayouts/data1d.jl"),
2222
UnitTest("DataLayouts 2D" ,"DataLayouts/data2d.jl"),
2323
UnitTest("DataLayouts 1dx" ,"DataLayouts/data1dx.jl"),
2424
UnitTest("DataLayouts 2dx" ,"DataLayouts/data2dx.jl"),
25+
UnitTest("DataLayouts RRTMGP" ,"DataLayouts/datarrtmgp.jl"),
2526
UnitTest("DataLayouts mapreduce" ,"DataLayouts/unit_mapreduce.jl"),
2627
UnitTest("Geometry" ,"Geometry/geometry.jl"),
2728
UnitTest("rmul_with_projection" ,"Geometry/rmul_with_projection.jl"),

0 commit comments

Comments
 (0)