From ed90e4e8467346ff8a067f2863b775f347900e31 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 24 Jul 2024 13:43:33 -0400 Subject: [PATCH] Implement data-specific cartesian index --- ext/ClimaCoreCUDAExt.jl | 1 + ext/cuda/data_layouts_fill.jl | 5 ++-- src/DataLayouts/DataLayouts.jl | 13 +++++++++ src/DataLayouts/cartesian_index.jl | 46 ++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 src/DataLayouts/cartesian_index.jl diff --git a/ext/ClimaCoreCUDAExt.jl b/ext/ClimaCoreCUDAExt.jl index 167696e93d..9621b9b0ee 100644 --- a/ext/ClimaCoreCUDAExt.jl +++ b/ext/ClimaCoreCUDAExt.jl @@ -17,6 +17,7 @@ import ClimaCore.Utilities: cart_ind, linear_ind import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh +import ClimaCore.DataLayouts: DataSpecificCartesianIndex, array_size include(joinpath("cuda", "cuda_utils.jl")) include(joinpath("cuda", "data_layouts.jl")) diff --git a/ext/cuda/data_layouts_fill.jl b/ext/cuda/data_layouts_fill.jl index 9999c65a8a..dd9b12d74e 100644 --- a/ext/cuda/data_layouts_fill.jl +++ b/ext/cuda/data_layouts_fill.jl @@ -2,8 +2,9 @@ function knl_fill_flat!(dest::AbstractData, val, us) @inbounds begin tidx = thread_index() if tidx ≤ get_N(us) - n = size(dest) - I = kernel_indexes(tidx, n) + n = array_size(dest) + CIS = CartesianIndices(map(x -> Base.OneTo(x), n)) + I = DataSpecificCartesianIndex(CIS[tidx]) @inbounds dest[I] = val end end diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index c8f0cb0b16..77f1bf6a5b 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -1102,6 +1102,7 @@ function VIJFH{S, Nv, Nij, Nh}( array::AbstractArray{T, 5}, ) where {S, Nv, Nij, Nh, T} check_basetype(T, S) + @assert size(array, 1) == Nv @assert size(array, 2) == size(array, 3) == Nij @assert size(array, 4) == typesize(T, S) @assert size(array, 5) == Nh @@ -1271,6 +1272,7 @@ function VIFH{S, Nv, Ni, Nh}( array::AbstractArray{T, 4}, ) where {S, Nv, Ni, Nh, T} check_basetype(T, S) + @assert size(array, 1) == Nv @assert size(array, 2) == Ni @assert size(array, 3) == typesize(T, S) @assert size(array, 4) == Nh @@ -1568,6 +1570,16 @@ get_Nij(::IFH{S, Nij}) where {S, Nij} = Nij get_Nij(::IJF{S, Nij}) where {S, Nij} = Nij get_Nij(::IF{S, Nij}) where {S, Nij} = Nij +@inline field_dim(::IJKFVH) = 4 +@inline field_dim(::IJFH) = 3 +@inline field_dim(::IFH) = 2 +@inline field_dim(::DataF) = 1 +@inline field_dim(::IJF) = 3 +@inline field_dim(::IF) = 2 +@inline field_dim(::VF) = 2 +@inline field_dim(::VIJFH) = 4 +@inline field_dim(::VIFH) = 3 + Base.ndims(data::AbstractData) = Base.ndims(typeof(data)) Base.ndims(::Type{T}) where {T <: AbstractData} = Base.ndims(parent_array_type(T)) @@ -1641,5 +1653,6 @@ include("copyto.jl") include("fused_copyto.jl") include("fill.jl") include("mapreduce.jl") +include("cartesian_index.jl") end # module diff --git a/src/DataLayouts/cartesian_index.jl b/src/DataLayouts/cartesian_index.jl new file mode 100644 index 0000000000..8c76a5811e --- /dev/null +++ b/src/DataLayouts/cartesian_index.jl @@ -0,0 +1,46 @@ +abstract type AbstractDataSpecificCartesianIndex{N} <: + Base.AbstractCartesianIndex{N} end + +struct DataSpecificCartesianIndex{N} <: AbstractDataSpecificCartesianIndex{N} + I::CartesianIndex{N} +end + +# Generic fallback +@propagate_inbounds Base.getindex(x, I::DataSpecificCartesianIndex) = + Base.getindex(x, I.I) + +@propagate_inbounds Base.setindex!(x, val, I::DataSpecificCartesianIndex) = + Base.setindex!(x, val, I.I) + +# Datalayouts +@propagate_inbounds function Base.getindex( + data::AbstractData{S}, + I::DataSpecificCartesianIndex, +) where {S} + @inbounds get_struct(parent(data), S, Val(field_dim(data)), I.I) +end +@propagate_inbounds function Base.setindex!( + data::AbstractData{S}, + val, + I::DataSpecificCartesianIndex, +) where {S} + @inbounds set_struct!( + parent(data), + convert(S, val), + Val(field_dim(data)), + I.I, + ) +end + +@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} = + (Nij, Nij, Nk, 1, Nv, Nh) +@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh) +@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh) +@inline array_size(::DataF{S}) where {S} = (1,) +@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1) +@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1) +@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1) +@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} = + (Nv, Nij, Nij, 1, Nh) +@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} = + (Nv, Ni, 1, Nh)