Implement and extend data-specific cartesian index #1902


Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions ext/ClimaCoreCUDAExt.jl
@@ -17,6 +17,8 @@ import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
⊠, ⊞, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nh
import ClimaCore.DataLayouts: DataSpecificCartesianIndex, array_size
import ClimaCore.DataLayouts: has_uniform_datalayouts

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
116 changes: 28 additions & 88 deletions ext/cuda/data_layouts_copyto.jl
@@ -1,88 +1,5 @@
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()

function knl_copyto!(dest, src)

i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z

if v <= size(dest, 4)
I = CartesianIndex((i, j, 1, v, h))
@inbounds dest[I] = src[I]
end
return nothing
end

function Base.copyto!(
dest::IJFH{S, Nij, Nh},
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
::ToCUDA,
) where {S, Nij, Nh}
if Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (Nij, Nij),
blocks_s = (Nh, 1),
)
end
return dest
end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij, Nh},
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
::ToCUDA,
) where {S, Nv, Nij, Nh}
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Nv_blocks = cld(Nv, Nv_per_block)
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (Nij, Nij, Nv_per_block),
blocks_s = (Nh, Nv_blocks),
)
end
return dest
end

function Base.copyto!(
dest::VF{S, Nv},
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
::ToCUDA,
) where {S, Nv}
if Nv > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (1, Nv),
)
end
return dest
end

function Base.copyto!(
dest::DataF{S},
bc::DataLayouts.BroadcastedUnionDataF{S},
::ToCUDA,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (1, 1),
)
return dest
end

import ClimaCore.DataLayouts: isascalar
function knl_copyto_flat!(dest::AbstractData, bc, us)
@inbounds begin
@@ -96,24 +13,47 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
return nothing
end

function knl_copyto_flat_specialized!(dest::AbstractData, bc, us)
@inbounds begin
tidx = thread_index()
if tidx ≤ get_N(us)
n = array_size(dest)
CIS = CartesianIndices(map(x -> Base.OneTo(x), n))
I = DataSpecificCartesianIndex(CIS[tidx])
dest[I] = bc[I]
end
end
return nothing
end
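For intuition, the linear-to-Cartesian decode used by this kernel can be reproduced in plain Julia; a minimal sketch, assuming illustrative sizes (not values from this PR):

n = (10, 4, 4, 1, 5)                       # e.g. array_size of some VIJFH layout
CIS = CartesianIndices(map(Base.OneTo, n)) # storage-ordered index space
tidx = 17                                  # plays the role of thread_index()
I = CIS[tidx]                              # CartesianIndex(7, 2, 1, 1, 1); dim 1 (v) is fastest

Consecutive thread indices thus touch consecutive entries of the backing array, which is the memory-access improvement the specialized kernel targets.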

function cuda_copyto!(dest::AbstractData, bc)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
if Nv > 0 && Nh > 0
us = DataLayouts.UniversalSize(dest)
if has_uniform_datalayouts(bc)
auto_launch!(
knl_copyto_flat_specialized!,
(dest, bc, us),
dest;
auto = true,
)
else
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
end
end
return dest
end

# TODO: can we use CUDA's launch configuration for all data layouts?
# Currently, it seems to cause a slight performance degradation.
#! format: off
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
#! format: on
5 changes: 3 additions & 2 deletions ext/cuda/data_layouts_fill.jl
@@ -2,8 +2,9 @@ function knl_fill_flat!(dest::AbstractData, val, us)
@inbounds begin
tidx = thread_index()
if tidx ≤ get_N(us)
n = array_size(dest)
CIS = CartesianIndices(map(x -> Base.OneTo(x), n))
I = DataSpecificCartesianIndex(CIS[tidx])
@inbounds dest[I] = val
end
end
16 changes: 16 additions & 0 deletions src/DataLayouts/DataLayouts.jl
@@ -49,6 +49,20 @@ include("struct.jl")

abstract type AbstractData{S} end

abstract type AbstractDataSpecificCartesianIndex{N} <:
Base.AbstractCartesianIndex{N} end

"""
DataSpecificCartesianIndex{N} <: AbstractDataSpecificCartesianIndex{N}

A DataLayout-specific CartesianIndex, used to support `getindex` on
DataLayouts without permuting indices. This improves memory access
patterns on GPUs.
"""
struct DataSpecificCartesianIndex{N} <: AbstractDataSpecificCartesianIndex{N}
I::CartesianIndex{N}
end
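As a hedged illustration (the values and the `data` object below are hypothetical): for a `VIJFH` layout the backing array is ordered `(v, i, j, f, h)`, so the wrapped index is used verbatim rather than permuted:

I = DataSpecificCartesianIndex(CartesianIndex(2, 1, 3, 1, 4))
# data[I] forwards I.I directly to get_struct(parent(data), S, Val(field_dim(data)), I.I),
# as defined in cartesian_index.jl below.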

@inline Base.size(data::AbstractData, i::Integer) = size(data)[i]
@inline Base.size(data::AbstractData) = universal_size(data)

@@ -1354,5 +1368,7 @@ include("copyto.jl")
include("fused_copyto.jl")
include("fill.jl")
include("mapreduce.jl")
include("cartesian_index.jl")
include("has_uniform_datalayouts.jl")

end # module
115 changes: 115 additions & 0 deletions src/DataLayouts/cartesian_index.jl
@@ -0,0 +1,115 @@
#! format: off
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
@inline function Base.getindex(bc::Base.Broadcast.Broadcasted, I::DataSpecificCartesianIndex)
@boundscheck checkbounds(bc, I)
@inbounds _broadcast_getindex(bc, I)
end

# This code path is only ever reached when all datalayouts in
# the broadcasted object are the same (e.g., ::VIJFH, ::VIJFH).
# They may have different type parameters, but in that case
# `permute_axes` still produces the correct axes for all
# datalayouts.
@inline Base.checkbounds(bc::Base.Broadcast.Broadcasted, I::DataSpecificCartesianIndex) =
# Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base
Base.checkbounds_indices(Bool, permute_axes(axes(bc), first_datalayout_in_bc(bc)), (I.I,)) || Base.throw_boundserror(bc, (I,))

Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(::Ref{Type{T}}, I) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I) = A[I[1]]
# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[Base.Broadcast.newindex(A, I)]
Base.@propagate_inbounds _broadcast_getindex(A, I) = A[I]

# For Broadcasted
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Any}, I)
args = _getindex(bc.args, I)
return _broadcast_getindex_evalf(bc.f, args...)
end
# Hack around losing Type{T} information in the final args tuple. Julia actually
# knows (in `code_typed`) the _value_ of these types, statically displaying them,
# but inference is currently skipping inferring the type of the types as they are
# transiently placed in a tuple as the argument list is lispily constructed. These
# additional methods recover type stability when a `Type` appears in one of the
# first two arguments of a function.
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Vararg{Any}}}, I) where {T}
args = _getindex(Base.tail(bc.args), I)
return _broadcast_getindex_evalf(bc.f, T, args...)
end
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Any,Ref{Type{T}},Vararg{Any}}}, I) where {T}
arg1 = _broadcast_getindex(bc.args[1], I)
args = _getindex(Base.tail(Base.tail(bc.args)), I)
return _broadcast_getindex_evalf(bc.f, arg1, T, args...)
end
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Ref{Type{S}},Vararg{Any}}}, I) where {T,S}
args = _getindex(Base.tail(Base.tail(bc.args)), I)
return _broadcast_getindex_evalf(bc.f, T, S, args...)
end

# Utilities for _broadcast_getindex
Base.@propagate_inbounds _getindex(args::Tuple, I) = (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) = (_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()

@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) # not propagate_inbounds
# ============================================================
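For intuition, the `_getindex`/`_broadcast_getindex` recursion above pushes a single index down to every leaf of the broadcast argument tree; a self-contained plain-Julia miniature (the names `leaf` and `walk` are illustrative, not the ClimaCore API):

leaf(A::AbstractArray, I) = A[I]           # like _broadcast_getindex on a leaf
leaf(x::Number, I) = x                     # scalar-likes ignore the index
walk(args::Tuple, I) = (leaf(args[1], I), walk(Base.tail(args), I)...)
walk(::Tuple{}, I) = ()
args = (rand(3), 2.0, rand(3))
+(walk(args, CartesianIndex(2))...)        # == args[1][2] + 2.0 + args[3][2]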

#! format: on
# Datalayouts
@propagate_inbounds function Base.getindex(
data::AbstractData{S},
I::DataSpecificCartesianIndex,
) where {S}
@inbounds get_struct(parent(data), S, Val(field_dim(data)), I.I)
end
@propagate_inbounds function Base.setindex!(
data::AbstractData{S},
val,
I::DataSpecificCartesianIndex,
) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(field_dim(data)),
I.I,
)
end

# Returns the size of the backing array.
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} =
(Nij, Nij, Nk, 1, Nv, Nh)
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
@inline array_size(::DataF{S}) where {S} = (1,)
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
(Nv, Nij, Nij, 1, Nh)
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
(Nv, Ni, 1, Nh)
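For concreteness (parameter values are illustrative, not from the PR), the `VIJFH` method above yields a storage-ordered tuple whose product gives one index per (v, i, j, h) point; fields are handled inside get_struct/set_struct!:

n = (10, 4, 4, 1, 5)        # array_size of a VIJFH{S, 10, 4, 5}: (Nv, Nij, Nij, 1, Nh)
prod(n) == 10 * 4 * 4 * 5   # true; the field slot is held at 1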

#####
##### Helpers to support `Base.checkbounds`
#####

# Converts axes(::AbstractData) to a Data-specific axes
@inline permute_axes(A, data::AbstractData) =
map(x -> A[x], perm_to_array(data))

# axes for IJF and IF exclude the field dimension
@inline permute_axes(A, ::IJF) = (A[1], A[2], Base.OneTo(1))
@inline permute_axes(A, ::IF) = (A[1], Base.OneTo(1))

# Permutes the dimensions of size(data) (the universal size) into
# the size of the backing array. For example, this should satisfy:
# @test size(parent(data)) == map(i -> size(data)[i], perm_to_array(data))
@inline perm_to_array(::IJKFVH) = (1, 2, 3, 4, 5)
@inline perm_to_array(::IJFH) = (1, 2, 3, 5)
@inline perm_to_array(::IFH) = (1, 3, 5)
@inline perm_to_array(::DataF) = (3,)
@inline perm_to_array(::VF) = (4, 3)
@inline perm_to_array(::VIJFH) = (4, 1, 2, 3, 5)
@inline perm_to_array(::VIFH) = (4, 1, 3, 5)
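A quick worked check of the relation stated in the comment above (the universal size here is hypothetical):

usize = (4, 4, 1, 10, 5)         # hypothetical size(data) for a VIJFH layout
perm  = (4, 1, 2, 3, 5)          # perm_to_array(::VIJFH), defined above
map(i -> usize[i], perm)         # (10, 4, 4, 1, 5): the array_size ordering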
62 changes: 62 additions & 0 deletions src/DataLayouts/has_uniform_datalayouts.jl
@@ -0,0 +1,62 @@
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
x1 = first_datalayout_in_bc(args[1], rargs...)
x1 isa AbstractData && return x1
return first_datalayout_in_bc(Base.tail(args), rargs...)
end

@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
first_datalayout_in_bc(args[1], rargs...)
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing
@inline first_datalayout_in_bc(x::AbstractData) = x

@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
first_datalayout_in_bc(bc.args)

@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
truesofar &&
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)

@inline _has_uniform_datalayouts_args(
truesofar,
start,
args::Tuple{Any},
rargs...,
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
truesofar

@inline function _has_uniform_datalayouts(
truesofar,
start,
bc::Base.Broadcast.Broadcasted,
)
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
end
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
@eval begin
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
end
end
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar

"""
has_uniform_datalayouts

Finds the first datalayout in the broadcast expression (BCE)
and compares it against every other datalayout in the BCE. Returns
 - `true` if the broadcasted object has only a single kind of datalayout (e.g., VF, VF or VIJFH, VIJFH);
 - `false` if the broadcasted object has multiple kinds of datalayouts (e.g., VIJFH, VIFH).

Note: a broadcasted object can have different _types_,
e.g., `VIJFH{Float64}` and `VIJFH{Tuple{Float64,Float64}}`,
but not different kinds, e.g., `VIJFH{Float64}` and `VF{Float64}`.
"""
function has_uniform_datalayouts end

@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)

@inline has_uniform_datalayouts(bc::AbstractData) = true
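A hedged usage sketch (`x`, `y`, `z` are hypothetical fields; `x` and `y` are VIJFH-backed, `z` is VF-backed):

# bc  = Base.Broadcast.broadcasted(+, x, y)   # VIJFH, VIJFH
# has_uniform_datalayouts(bc)                 # true: one kind of datalayout
# bc2 = Base.Broadcast.broadcasted(+, x, z)   # VIJFH, VF
# has_uniform_datalayouts(bc2)                # false: mixed kinds

This is the check `cuda_copyto!` uses to choose between `knl_copyto_flat_specialized!` and the generic `knl_copyto_flat!`.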