diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl
index b0eba5260c..7e70667ab2 100644
--- a/ext/cuda/data_layouts.jl
+++ b/ext/cuda/data_layouts.jl
@@ -32,6 +32,7 @@ include("data_layouts_copyto.jl")
 include("data_layouts_fused_copyto.jl")
 include("data_layouts_mapreduce.jl")
 include("data_layouts_threadblock.jl")
+include("data_layouts_rrtmgp.jl")
 
 adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
 adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)
diff --git a/ext/cuda/data_layouts_rrtmgp.jl b/ext/cuda/data_layouts_rrtmgp.jl
new file mode 100644
index 0000000000..1acdc6aaf6
--- /dev/null
+++ b/ext/cuda/data_layouts_rrtmgp.jl
@@ -0,0 +1,208 @@
+import ClimaCore.DataLayouts: data2array_rrtmgp!, array2data_rrtmgp!
+
+"""
+    data2array_rrtmgp!(array::CUDA.CuArray, data, ::Val{trans})
+
+GPU method: copy the single-field `data` into the 2D `array`, one thread per
+column. Columns are numbered with `i` varying fastest, then `j`, `k` and `h`;
+column `c` is written to `array[:, c]`, or to `array[c, :]` when `trans` is `true`.
+"""
+function data2array_rrtmgp!(
+    array::CUDA.CuArray,
+    data::D,
+    ::Val{trans},
+) where {trans, D <: Union{VF, VIFH, VIHF, VIJFH, VIJHF}}
+    (ncol, nl) = trans ? size(array) : reverse(size(array))
+    Ni, Nj, Nk, Nv, Nh = Base.size(data)
+    # check each dimension, not only the product: the kernel indexes `array`
+    # under `@inbounds`
+    @assert nl == Nv && ncol == Ni * Nj * Nk * Nh
+    @assert prod(size(parent(data))) == Ni * Nj * Nk * Nv * Nh # verify Nf == 1
+    kernel = CUDA.@cuda launch = false data2array_rrtmgp_kernel!(
+        array,
+        data,
+        Val(trans),
+    )
+    kernel_config = CUDA.launch_configuration(kernel.fun)
+    nitems = Ni * Nj * Nk * Nh # one thread per column
+    nthreads, nblocks = linear_partition(nitems, kernel_config.threads)
+    CUDA.@cuda threads = nthreads blocks = nblocks data2array_rrtmgp_kernel!(
+        array,
+        data,
+        Val(trans),
+    )
+    return nothing
+end
+
+# Kernel of `data2array_rrtmgp!`: each thread copies the `Nv` levels of one column.
+function data2array_rrtmgp_kernel!(
+    array::AbstractArray,
+    data::AbstractData,
+    ::Val{trans},
+) where {trans}
+    Ni, Nj, Nk, Nv, Nh = Base.size(data)
+    ncol = Ni * Nj * Nk * Nh
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ ncol
+        # recover (i, j, k, h) from the linear column number
+        h = cld(gidx, Ni * Nj * Nk)
+        idx = gidx - (h - 1) * Ni * Nj * Nk
+        k = cld(idx, Ni * Nj)
+        idx = idx - (k - 1) * Ni * Nj
+        j = cld(idx, Ni)
+        i = idx - (j - 1) * Ni
+        @inbounds begin
+            for v in 1:Nv
+                cidx = CartesianIndex(i, j, k, v, h)
+                trans ? (array[gidx, v] = data[cidx][1]) :
+                (array[v, gidx] = data[cidx][1])
+            end
+        end
+    end
+    return nothing
+end
+
+# Select the `array2data_rrtmgp!` kernel matching the layout of `data`.
+_get_kernel_function(::VIJFH) = array2data_rrtmgp_VIJFH_kernel!
+_get_kernel_function(::VIJHF) = array2data_rrtmgp_VIJHF_kernel!
+_get_kernel_function(::VIFH) = array2data_rrtmgp_VIFH_kernel!
+_get_kernel_function(::VIHF) = array2data_rrtmgp_VIHF_kernel!
+_get_kernel_function(::VF) = array2data_rrtmgp_VF_kernel!
+
+"""
+    array2data_rrtmgp!(data, array::CUDA.CuArray, ::Val{trans})
+
+GPU method: copy the 2D `array` into the parent array of the single-field
+`data`, one thread per column. Column `c` is read from `array[:, c]`, or from
+`array[c, :]` when `trans` is `true`.
+"""
+function array2data_rrtmgp!(
+    data::D,
+    array::CUDA.CuArray,
+    ::Val{trans},
+) where {trans, D <: Union{VF, VIFH, VIHF, VIJFH, VIJHF}}
+    (ncol, nl) = trans ? size(array) : reverse(size(array))
+    Ni, Nj, _, Nv, Nh = Base.size(data)
+    # check each dimension, not only the product: the kernels index `array`
+    # under `@inbounds`
+    @assert nl == Nv && ncol == Ni * Nj * Nh
+    @assert prod(size(parent(data))) == Ni * Nj * Nv * Nh # verify Nf == 1
+
+    kernelfun! = _get_kernel_function(data)
+
+    kernel =
+        CUDA.@cuda launch = false kernelfun!(parent(data), array, Val(trans))
+    kernel_config = CUDA.launch_configuration(kernel.fun)
+    nitems = Ni * Nj * Nh # one thread per column
+    nthreads, nblocks = linear_partition(nitems, kernel_config.threads)
+    CUDA.@cuda threads = nthreads blocks = nblocks kernelfun!(
+        parent(data),
+        array,
+        Val(trans),
+    )
+    return nothing
+end
+
+# `array2data_rrtmgp!` kernel for a VIJFH parent array of size (Nv, Ni, Nj, 1, Nh).
+function array2data_rrtmgp_VIJFH_kernel!(
+    parentdata::AbstractArray,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans}
+    Nv, Ni, Nj, _, Nh = size(parentdata)
+    ncol = Ni * Nj * Nh
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ ncol
+        h = cld(gidx, Ni * Nj)
+        j = cld(gidx - (h - 1) * Ni * Nj, Ni)
+        i = gidx - (h - 1) * Ni * Nj - (j - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, j, 1, h] =
+                trans ? array[gidx, v] : array[v, gidx]
+        end
+    end
+    return nothing
+end
+
+# `array2data_rrtmgp!` kernel for a VIJHF parent array of size (Nv, Ni, Nj, Nh, 1).
+function array2data_rrtmgp_VIJHF_kernel!(
+    parentdata::AbstractArray,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans}
+    Nv, Ni, Nj, Nh, _ = size(parentdata)
+    ncol = Ni * Nj * Nh
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ ncol
+        h = cld(gidx, Ni * Nj)
+        j = cld(gidx - (h - 1) * Ni * Nj, Ni)
+        i = gidx - (h - 1) * Ni * Nj - (j - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, j, h, 1] =
+                trans ? array[gidx, v] : array[v, gidx]
+        end
+    end
+    return nothing
+end
+
+# `array2data_rrtmgp!` kernel for a VIFH parent array of size (Nv, Ni, 1, Nh).
+function array2data_rrtmgp_VIFH_kernel!(
+    parentdata::AbstractArray,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans}
+    Nv, Ni, _, Nh = size(parentdata)
+    ncol = Ni * Nh
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ ncol
+        h = cld(gidx, Ni)
+        i = gidx - (h - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, 1, h] =
+                trans ? array[gidx, v] : array[v, gidx]
+        end
+    end
+    return nothing
+end
+
+# `array2data_rrtmgp!` kernel for a VIHF parent array of size (Nv, Ni, Nh, 1).
+function array2data_rrtmgp_VIHF_kernel!(
+    parentdata::AbstractArray,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans}
+    Nv, Ni, Nh, _ = size(parentdata)
+    ncol = Ni * Nh
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ ncol
+        h = cld(gidx, Ni)
+        i = gidx - (h - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, h, 1] =
+                trans ? array[gidx, v] : array[v, gidx]
+        end
+    end
+    return nothing
+end
+
+# `array2data_rrtmgp!` kernel for a VF parent array of size (Nv, 1): one column.
+function array2data_rrtmgp_VF_kernel!(
+    parentdata::AbstractArray,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans}
+    Nv, _ = size(parentdata)
+    # obtain the column number processed by each thread
+    gidx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+    if gidx ≤ 1
+        for v in 1:Nv
+            @inbounds parentdata[v, 1] = trans ? array[gidx, v] : array[v, gidx]
+        end
+    end
+    return nothing
+end
diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl
index 64a0e008a3..ba887b0aba 100644
--- a/src/DataLayouts/DataLayouts.jl
+++ b/src/DataLayouts/DataLayouts.jl
@@ -2211,6 +2211,132 @@ array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
         reshape(array, array_size(data)...),
     )
 
+"""
+    data2array_rrtmgp!(array, data, ::Val{trans})
+
+Copy the single-field `data` into the 2D `array`. Columns are numbered with `i`
+varying fastest, then `j`, `k` and `h`; column `c` is written to `array[:, c]`,
+or to `array[c, :]` when `trans` is `true`.
+"""
+function data2array_rrtmgp!(
+    array::AbstractArray,
+    data::D,
+    ::Val{trans},
+) where {trans, D <: Union{VF, VIFH, VIHF, VIJFH, VIJHF}}
+    (ncol, nl) = trans ? size(array) : reverse(size(array))
+    Ni, Nj, Nk, Nv, Nh = Base.size(data)
+    # check each dimension, not only the product: the loop below indexes
+    # `array` under `@inbounds`
+    @assert nl == Nv && ncol == Ni * Nj * Nk * Nh
+    @assert prod(size(parent(data))) == Ni * Nj * Nk * Nv * Nh # verify Nf == 1
+
+    @inbounds begin
+        # `h` outermost and `i` innermost, as in the `array2data_rrtmgp_*!` loops
+        for h in 1:Nh, k in 1:Nk, j in 1:Nj, i in 1:Ni
+            colidx =
+                i + (j - 1) * Ni + (k - 1) * Ni * Nj + (h - 1) * Ni * Nj * Nk
+            for v in 1:Nv
+                cidx = CartesianIndex(i, j, k, v, h)
+                trans ? (array[colidx, v] = data[cidx][1]) :
+                (array[v, colidx] = data[cidx][1])
+            end
+        end
+    end
+    return nothing
+end
+
+# Select the copy routine matching the layout of `data`.
+_get_array2data_rrtmgp_function(::VIJFH) = array2data_rrtmgp_VIJFH!
+_get_array2data_rrtmgp_function(::VIJHF) = array2data_rrtmgp_VIJHF!
+_get_array2data_rrtmgp_function(::VIHF) = array2data_rrtmgp_VIHF!
+_get_array2data_rrtmgp_function(::VIFH) = array2data_rrtmgp_VIFH!
+_get_array2data_rrtmgp_function(::VF) = array2data_rrtmgp_VF!
+
+"""
+    array2data_rrtmgp!(data, array, ::Val{trans})
+
+Copy the 2D `array` into the parent array of the single-field `data`. Column `c`
+(numbered with `i` varying fastest, then `j` and `h`) is read from `array[:, c]`,
+or from `array[c, :]` when `trans` is `true`.
+"""
+function array2data_rrtmgp!(
+    data::D,
+    array::AbstractArray,
+    ::Val{trans},
+) where {trans, D <: Union{VF, VIFH, VIHF, VIJFH, VIJHF}}
+    (ncol, nl) = trans ? size(array) : reverse(size(array))
+    Ni, Nj, _, Nv, Nh = Base.size(data)
+    # check each dimension, not only the product: the copy routines index
+    # `array` under `@inbounds`
+    @assert nl == Nv && ncol == Ni * Nj * Nh
+    parentdata = parent(data)
+    @assert prod(size(parentdata)) == Ni * Nj * Nv * Nh # verify Nf == 1
+    array2data_func! = _get_array2data_rrtmgp_function(data)
+    array2data_func!(parentdata, array, Val(trans))
+    return nothing
+end
+
+# Copy `array` into a VIJFH parent array of size (Nv, Ni, Nj, 1, Nh).
+function array2data_rrtmgp_VIJFH!(parentdata, array, ::Val{trans}) where {trans}
+    Nv, Ni, Nj, _, Nh = size(parentdata)
+    for h in 1:Nh, j in 1:Nj, i in 1:Ni
+        colidx = i + (j - 1) * Ni + (h - 1) * Ni * Nj
+        for v in 1:Nv
+            @inbounds parentdata[v, i, j, 1, h] =
+                trans ? array[colidx, v] : array[v, colidx]
+        end
+    end
+    return nothing
+end
+
+# Copy `array` into a VIFH parent array of size (Nv, Ni, 1, Nh).
+function array2data_rrtmgp_VIFH!(parentdata, array, ::Val{trans}) where {trans}
+    Nv, Ni, _, Nh = size(parentdata)
+    for h in 1:Nh, i in 1:Ni
+        colidx = i + (h - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, 1, h] =
+                trans ? array[colidx, v] : array[v, colidx]
+        end
+    end
+    return nothing
+end
+
+# Copy `array` into a VF parent array of size (Nv, 1).
+function array2data_rrtmgp_VF!(parentdata, array, ::Val{trans}) where {trans}
+    Nv, _ = size(parentdata)
+    for v in 1:Nv
+        @inbounds parentdata[v, 1] = trans ? array[1, v] : array[v, 1]
+    end
+    return nothing
+end
+
+# Copy `array` into a VIJHF parent array of size (Nv, Ni, Nj, Nh, 1).
+function array2data_rrtmgp_VIJHF!(parentdata, array, ::Val{trans}) where {trans}
+    Nv, Ni, Nj, Nh, _ = size(parentdata)
+    for h in 1:Nh, j in 1:Nj, i in 1:Ni
+        colidx = i + (j - 1) * Ni + (h - 1) * Ni * Nj
+        for v in 1:Nv
+            @inbounds parentdata[v, i, j, h, 1] =
+                trans ? array[colidx, v] : array[v, colidx]
+        end
+    end
+    return nothing
+end
+
+# Copy `array` into a VIHF parent array of size (Nv, Ni, Nh, 1).
+function array2data_rrtmgp_VIHF!(parentdata, array, ::Val{trans}) where {trans}
+    Nv, Ni, Nh, _ = size(parentdata)
+    for h in 1:Nh, i in 1:Ni
+        colidx = i + (h - 1) * Ni
+        for v in 1:Nv
+            @inbounds parentdata[v, i, h, 1] =
+                trans ? array[colidx, v] : array[v, colidx]
+        end
+    end
+    return nothing
+end
+
 """
     device_dispatch(array::AbstractArray)
 
diff --git a/test/DataLayouts/data2dx.jl b/test/DataLayouts/data2dx.jl
index 1b2621b4dd..4ad9b3a89c 100644
--- a/test/DataLayouts/data2dx.jl
+++ b/test/DataLayouts/data2dx.jl
@@ -5,7 +5,16 @@ using Revise; include(joinpath("test", "DataLayouts", "data2dx.jl"))
 using Test
 using ClimaComms
 using ClimaCore.DataLayouts
-import ClimaCore.DataLayouts: VF, IJFH, VIJFH, slab, column, slab_index, vindex
+import ClimaCore.DataLayouts:
+    VF,
+    IJFH,
+    VIJFH,
+    slab,
+    column,
+    slab_index,
+    vindex,
+    data2array_rrtmgp!,
+    array2data_rrtmgp!
 
 device = ClimaComms.device()
 ArrayType = ClimaComms.array_type(device)
diff --git a/test/DataLayouts/datarrtmgp.jl b/test/DataLayouts/datarrtmgp.jl
new file mode 100644
index 0000000000..54bc2ebc58
--- /dev/null
+++ b/test/DataLayouts/datarrtmgp.jl
@@ -0,0 +1,229 @@
+#=
+julia --project=test
+using Revise; include(joinpath("test", "DataLayouts", "datarrtmgp.jl"))
+=#
+using Test
+using ClimaComms
+ClimaComms.@import_required_backends
+using ClimaCore.DataLayouts
+using ClimaCore.Geometry
+import ClimaCore.DataLayouts:
+    VF,
+    VIFH,
+    VIJFH,
+    VIJHF,
+    slab,
+    column,
+    slab_index,
+    vindex,
+    data2array_rrtmgp!,
+    array2data_rrtmgp!
+
+device = ClimaComms.device()
+ArrayType = ClimaComms.array_type(device)
+
+@testset "VIJFH data2array_rrtmgp!" begin
+    Nv = 10 # number of vertical levels
+    Nij = 4 # Nij × Nij nodal points per element
+    Nh = 10 # number of elements
+
+    nl = Nv # number of levels/layers in the array
+    ncol = Nij * Nij * Nh # number of columns in the array
+
+    for FT in (Float32, Float64)
+        data1 = VIJFH{FT, Nv, Nij}(
+            ArrayType(
+                reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, Nij, Nij, 1, Nh),
+            ),
+        )
+
+        array1 = ArrayType{FT}(undef, nl, ncol)
+        array1t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array1, data1, Val(false))
+        data2array_rrtmgp!(array1t, data1, Val(true))
+
+        @test array1 == ArrayType(transpose(array1t))
+
+        array2 = ArrayType(reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, :))
+        array2t = ArrayType{FT}(undef, ncol, nl)
+
+        data2 = VIJFH{FT, Nv, Nij}(ArrayType{FT}(undef, Nv, Nij, Nij, 1, Nh))
+
+        array2data_rrtmgp!(data2, array2, Val(false))
+        @test parent(data2) == parent(data1)
+
+        array2t = ArrayType(transpose(Array(array2)))
+
+        parent(data2) .= NaN
+        array2data_rrtmgp!(data2, array2t, Val(true))
+        @test parent(data2) == parent(data1)
+
+        data3 = VIJFH{WVector{FT}, Nv, Nij}(
+            ArrayType(
+                reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, Nij, Nij, 1, Nh),
+            ),
+        )
+        array3 = ArrayType{FT}(undef, nl, ncol)
+        array3t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array3, data3, Val(false))
+        data2array_rrtmgp!(array3t, data3, Val(true))
+
+        @test array3 == ArrayType(transpose(array3t))
+
+
+        array4 = ArrayType(reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, :))
+        array4t = ArrayType{FT}(undef, ncol, nl)
+
+        data4 = VIJFH{FT, Nv, Nij}(ArrayType{FT}(undef, Nv, Nij, Nij, 1, Nh))
+
+        array2data_rrtmgp!(data4, array4, Val(false))
+        @test parent(data4) == parent(data3)
+    end
+end
+
+@testset "VIFH data2array_rrtmgp!" begin
+    Nv = 10 # number of vertical levels
+    Ni = 4 # number of nodal points per element
+    Nh = 10 # number of elements
+
+    nl = Nv # number of levels/layers in the array
+    ncol = Ni * Nh # number of columns in the array
+
+    for FT in (Float32, Float64)
+        data1 = VIFH{FT, Nv, Ni}(
+            ArrayType(reshape(FT(1.0):(Nv * Ni * Nh), Nv, Ni, 1, Nh)),
+        )
+
+        array1 = ArrayType{FT}(undef, nl, ncol)
+        array1t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array1, data1, Val(false))
+        data2array_rrtmgp!(array1t, data1, Val(true))
+
+        @test array1 == ArrayType(transpose(array1t))
+
+        array2 = ArrayType(reshape(FT(1.0):(Nv * Ni * Nh), Nv, :))
+        array2t = ArrayType{FT}(undef, ncol, nl)
+
+        data2 = VIFH{FT, Nv, Ni}(ArrayType{FT}(undef, Nv, Ni, 1, Nh))
+
+        array2data_rrtmgp!(data2, array2, Val(false))
+        @test parent(data2) == parent(data1)
+
+        array2t = ArrayType(transpose(Array(array2)))
+
+        parent(data2) .= NaN
+        array2data_rrtmgp!(data2, array2t, Val(true))
+        @test parent(data2) == parent(data1)
+
+        data3 = VIFH{WVector{FT}, Nv, Ni}(
+            ArrayType(reshape(FT(1.0):(Nv * Ni * Nh), Nv, Ni, 1, Nh)),
+        )
+        array3 = ArrayType{FT}(undef, nl, ncol)
+        array3t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array3, data3, Val(false))
+        data2array_rrtmgp!(array3t, data3, Val(true))
+
+        @test array3 == ArrayType(transpose(array3t))
+
+
+        array4 = ArrayType(reshape(FT(1.0):(Nv * Ni * Nh), Nv, :))
+        array4t = ArrayType{FT}(undef, ncol, nl)
+
+        data4 = VIFH{FT, Nv, Ni}(ArrayType{FT}(undef, Nv, Ni, 1, Nh))
+
+        array2data_rrtmgp!(data4, array4, Val(false))
+        @test parent(data4) == parent(data3)
+    end
+end
+
+@testset "VF data2array_rrtmgp!" begin
+    Nv = 10 # number of vertical levels
+
+    nl = Nv # number of levels/layers in the array
+    ncol = 1# number of columns in the array
+
+    for FT in (Float32, Float64)
+        data1 = VF{FT, Nv}(ArrayType(reshape(FT(1.0):(Nv), Nv, 1)))
+
+        array1 = ArrayType{FT}(undef, nl, ncol)
+        array1t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array1, data1, Val(false))
+        data2array_rrtmgp!(array1t, data1, Val(true))
+
+        @test array1 == ArrayType(transpose(array1t))
+
+        array2 = ArrayType(reshape(FT(1.0):(Nv), Nv, :))
+        array2t = ArrayType{FT}(undef, ncol, nl)
+
+        data2 = VF{FT, Nv}(ArrayType{FT}(undef, Nv, 1))
+
+        array2data_rrtmgp!(data2, array2, Val(false))
+        @test parent(data2) == parent(data1)
+
+        array2t = ArrayType(transpose(Array(array2)))
+
+        parent(data2) .= NaN
+        array2data_rrtmgp!(data2, array2t, Val(true))
+        @test parent(data2) == parent(data1)
+
+        data3 = VF{WVector{FT}, Nv}(ArrayType(reshape(FT(1.0):(Nv), Nv, 1)))
+        array3 = ArrayType{FT}(undef, nl, ncol)
+        array3t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array3, data3, Val(false))
+        data2array_rrtmgp!(array3t, data3, Val(true))
+
+        @test array3 == ArrayType(transpose(array3t))
+
+
+        array4 = ArrayType(reshape(FT(1.0):(Nv), Nv, :))
+        array4t = ArrayType{FT}(undef, ncol, nl)
+
+        data4 = VF{FT, Nv}(ArrayType{FT}(undef, Nv, 1))
+
+        array2data_rrtmgp!(data4, array4, Val(false))
+        @test parent(data4) == parent(data3)
+    end
+end
+
+@testset "VIJHF data2array_rrtmgp!" begin
+    Nv = 10 # number of vertical levels
+    Nij = 4 # Nij × Nij nodal points per element
+    Nh = 10 # number of elements
+
+    nl = Nv # number of levels/layers in the array
+    ncol = Nij * Nij * Nh # number of columns in the array
+
+    for FT in (Float32, Float64)
+        # NOTE(review): assumes the VIJHF constructor mirrors the VIJFH one, with
+        # a parent array of size (Nv, Nij, Nij, Nh, 1) -- confirm
+        data1 = VIJHF{FT, Nv, Nij}(
+            ArrayType(
+                reshape(FT(1.0):(Nv * Nij * Nij * Nh), Nv, Nij, Nij, Nh, 1),
+            ),
+        )
+
+        array1 = ArrayType{FT}(undef, nl, ncol)
+        array1t = ArrayType{FT}(undef, ncol, nl)
+
+        data2array_rrtmgp!(array1, data1, Val(false))
+        data2array_rrtmgp!(array1t, data1, Val(true))
+
+        @test array1 == ArrayType(transpose(array1t))
+
+        data2 = VIJHF{FT, Nv, Nij}(ArrayType{FT}(undef, Nv, Nij, Nij, Nh, 1))
+
+        # round trip through both array orientations
+        array2data_rrtmgp!(data2, array1, Val(false))
+        @test parent(data2) == parent(data1)
+
+        parent(data2) .= NaN
+        array2data_rrtmgp!(data2, array1t, Val(true))
+        @test parent(data2) == parent(data1)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 7d82c25493..af892fa10d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,6 +22,7 @@ UnitTest("DataLayouts 1D" ,"DataLayouts/data1d.jl"),
 UnitTest("DataLayouts 2D" ,"DataLayouts/data2d.jl"),
 UnitTest("DataLayouts 1dx" ,"DataLayouts/data1dx.jl"),
 UnitTest("DataLayouts 2dx" ,"DataLayouts/data2dx.jl"),
+UnitTest("DataLayouts RRTMGP" ,"DataLayouts/datarrtmgp.jl"),
 UnitTest("DataLayouts mapreduce" ,"DataLayouts/unit_mapreduce.jl"),
 UnitTest("Geometry" ,"Geometry/geometry.jl"),
 UnitTest("rmul_with_projection" ,"Geometry/rmul_with_projection.jl"),