Skip to content

Commit 137ac8b

Browse files
Implement data-specific cartesian index
Extend data-specific cartesian index
1 parent 7ab0acf commit 137ac8b

File tree

9 files changed

+307
-91
lines changed

9 files changed

+307
-91
lines changed

ext/ClimaCoreCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import ClimaCore.Utilities: cart_ind, linear_ind
1717
import ClimaCore.RecursiveApply:
1818
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
1919
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
20+
import ClimaCore.DataLayouts: DataSpecificCartesianIndex, array_size
21+
import ClimaCore.DataLayouts: has_uniform_datalayouts
2022

2123
include(joinpath("cuda", "cuda_utils.jl"))
2224
include(joinpath("cuda", "data_layouts.jl"))

ext/cuda/data_layouts_copyto.jl

Lines changed: 28 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,5 @@
11
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
22

3-
function knl_copyto!(dest, src)
4-
5-
i = CUDA.threadIdx().x
6-
j = CUDA.threadIdx().y
7-
8-
h = CUDA.blockIdx().x
9-
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
10-
11-
if v <= size(dest, 4)
12-
I = CartesianIndex((i, j, 1, v, h))
13-
@inbounds dest[I] = src[I]
14-
end
15-
return nothing
16-
end
17-
18-
function Base.copyto!(
19-
dest::IJFH{S, Nij, Nh},
20-
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
21-
::ToCUDA,
22-
) where {S, Nij, Nh}
23-
if Nh > 0
24-
auto_launch!(
25-
knl_copyto!,
26-
(dest, bc),
27-
dest;
28-
threads_s = (Nij, Nij),
29-
blocks_s = (Nh, 1),
30-
)
31-
end
32-
return dest
33-
end
34-
35-
function Base.copyto!(
36-
dest::VIJFH{S, Nv, Nij, Nh},
37-
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
38-
::ToCUDA,
39-
) where {S, Nv, Nij, Nh}
40-
if Nv > 0 && Nh > 0
41-
Nv_per_block = min(Nv, fld(256, Nij * Nij))
42-
Nv_blocks = cld(Nv, Nv_per_block)
43-
auto_launch!(
44-
knl_copyto!,
45-
(dest, bc),
46-
dest;
47-
threads_s = (Nij, Nij, Nv_per_block),
48-
blocks_s = (Nh, Nv_blocks),
49-
)
50-
end
51-
return dest
52-
end
53-
54-
function Base.copyto!(
55-
dest::VF{S, Nv},
56-
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
57-
::ToCUDA,
58-
) where {S, Nv}
59-
if Nv > 0
60-
auto_launch!(
61-
knl_copyto!,
62-
(dest, bc),
63-
dest;
64-
threads_s = (1, 1),
65-
blocks_s = (1, Nv),
66-
)
67-
end
68-
return dest
69-
end
70-
71-
function Base.copyto!(
72-
dest::DataF{S},
73-
bc::DataLayouts.BroadcastedUnionDataF{S},
74-
::ToCUDA,
75-
) where {S}
76-
auto_launch!(
77-
knl_copyto!,
78-
(dest, bc),
79-
dest;
80-
threads_s = (1, 1),
81-
blocks_s = (1, 1),
82-
)
83-
return dest
84-
end
85-
863
import ClimaCore.DataLayouts: isascalar
874
function knl_copyto_flat!(dest::AbstractData, bc, us)
885
@inbounds begin
@@ -96,24 +13,46 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
9613
return nothing
9714
end
9815

16+
function knl_copyto_flat_specialized!(dest::AbstractData, bc, us)
17+
@inbounds begin
18+
tidx = thread_index()
19+
if tidx get_N(us)
20+
n = array_size(dest)
21+
CIS = CartesianIndices(map(x -> Base.OneTo(x), n))
22+
I = DataSpecificCartesianIndex(CIS[tidx])
23+
dest[I] = bc[I]
24+
end
25+
end
26+
return nothing
27+
end
28+
9929
function cuda_copyto!(dest::AbstractData, bc)
10030
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
101-
us = DataLayouts.UniversalSize(dest)
10231
if Nv > 0 && Nh > 0
103-
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
32+
us = DataLayouts.UniversalSize(dest)
33+
if has_uniform_datalayouts(bc)
34+
auto_launch!(
35+
knl_copyto_flat_specialized!,
36+
(dest, bc, us),
37+
dest;
38+
auto = true,
39+
)
40+
else
41+
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
42+
end
10443
end
10544
return dest
10645
end
10746

10847
# TODO: can we use CUDA's luanch configuration for all data layouts?
10948
# Currently, it seems to have a slight performance degradation.
11049
#! format: off
111-
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
50+
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
11251
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
11352
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
11453
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
11554
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
116-
# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
117-
# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
118-
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
55+
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
56+
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
57+
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
11958
#! format: on

ext/cuda/data_layouts_fill.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ function knl_fill_flat!(dest::AbstractData, val, us)
22
@inbounds begin
33
tidx = thread_index()
44
if tidx get_N(us)
5-
n = size(dest)
6-
I = kernel_indexes(tidx, n)
5+
n = array_size(dest)
6+
CIS = CartesianIndices(map(x -> Base.OneTo(x), n))
7+
I = DataSpecificCartesianIndex(CIS[tidx])
78
@inbounds dest[I] = val
89
end
910
end

src/DataLayouts/DataLayouts.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ include("struct.jl")
4949

5050
abstract type AbstractData{S} end
5151

52+
abstract type AbstractDataSpecificCartesianIndex{N} <:
53+
Base.AbstractCartesianIndex{N} end
54+
55+
"""
56+
DataSpecificCartesianIndex{N} <: AbstractDataSpecificCartesianIndex{N}
57+
58+
A DataLayout-specific CartesianIndex, which is used to provide support for
59+
`getindex` for DataLayouts such that indices are not swapped. This is used
60+
to improve memory access patterns on GPUs.
61+
"""
62+
struct DataSpecificCartesianIndex{N} <: AbstractDataSpecificCartesianIndex{N}
63+
I::CartesianIndex{N}
64+
end
65+
5266
Base.size(data::AbstractData, i::Integer) = size(data)[i]
5367

5468
"""
@@ -1063,6 +1077,7 @@ function VIJFH{S, Nv, Nij, Nh}(
10631077
array::AbstractArray{T, 5},
10641078
) where {S, Nv, Nij, Nh, T}
10651079
check_basetype(T, S)
1080+
@assert size(array, 1) == Nv
10661081
@assert size(array, 2) == size(array, 3) == Nij
10671082
@assert size(array, 4) == typesize(T, S)
10681083
@assert size(array, 5) == Nh
@@ -1223,6 +1238,7 @@ function VIFH{S, Nv, Ni, Nh}(
12231238
array::AbstractArray{T, 4},
12241239
) where {S, Nv, Ni, Nh, T}
12251240
check_basetype(T, S)
1241+
@assert size(array, 1) == Nv
12261242
@assert size(array, 2) == Ni
12271243
@assert size(array, 3) == typesize(T, S)
12281244
@assert size(array, 4) == Nh
@@ -1604,6 +1620,8 @@ include("copyto.jl")
16041620
include("fused_copyto.jl")
16051621
include("fill.jl")
16061622
include("mapreduce.jl")
1623+
include("cartesian_index.jl")
1624+
include("has_uniform_datalayouts.jl")
16071625

16081626
slab_index(i, j) = CartesianIndex(i, j, 1, 1, 1)
16091627
slab_index(i) = CartesianIndex(i, 1, 1, 1, 1)

src/DataLayouts/cartesian_index.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#! format: off
2+
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
3+
@inline function Base.getindex(bc::Base.Broadcast.Broadcasted, I::DataSpecificCartesianIndex)
4+
@boundscheck checkbounds(bc, I)
5+
@inbounds _broadcast_getindex(bc, I)
6+
end
7+
8+
# This code path is only ever reached when all datalayouts in
9+
# the broadcasted object are the same (e.g., ::VIJFH, ::VIJFH)
10+
# They may have different type parameters, but this means that
11+
# `permute_axes` will still produce the correct axes for all
12+
# datalayouts.
13+
@inline Base.checkbounds(bc::Base.Broadcast.Broadcasted, I::DataSpecificCartesianIndex) =
14+
# Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base
15+
Base.checkbounds_indices(Bool, permute_axes(axes(bc), first_datalayout_in_bc(bc)), (I.I,)) || Base.throw_boundserror(bc, (I,))
16+
17+
Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
18+
Base.@propagate_inbounds _broadcast_getindex(::Ref{Type{T}}, I) where {T} = T
19+
# Tuples are statically known to be singleton or vector-like
20+
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I) = A[1]
21+
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I) = A[I[1]]
22+
# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
23+
# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[Base.Broadcast.newindex(A, I)]
24+
Base.@propagate_inbounds _broadcast_getindex(A, I) = A[I]
25+
26+
# For Broadcasted
27+
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Any}, I)
28+
args = _getindex(bc.args, I)
29+
return _broadcast_getindex_evalf(bc.f, args...)
30+
end
31+
# Hack around losing Type{T} information in the final args tuple. Julia actually
32+
# knows (in `code_typed`) the _value_ of these types, statically displaying them,
33+
# but inference is currently skipping inferring the type of the types as they are
34+
# transiently placed in a tuple as the argument list is lispily constructed. These
35+
# additional methods recover type stability when a `Type` appears in one of the
36+
# first two arguments of a function.
37+
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Vararg{Any}}}, I) where {T}
38+
args = _getindex(Base.tail(bc.args), I)
39+
return _broadcast_getindex_evalf(bc.f, T, args...)
40+
end
41+
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Any,Ref{Type{T}},Vararg{Any}}}, I) where {T}
42+
arg1 = _broadcast_getindex(bc.args[1], I)
43+
args = _getindex(Base.tail(Base.tail(bc.args)), I)
44+
return _broadcast_getindex_evalf(bc.f, arg1, T, args...)
45+
end
46+
Base.@propagate_inbounds function _broadcast_getindex(bc::Base.Broadcast.Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Ref{Type{S}},Vararg{Any}}}, I) where {T,S}
47+
args = _getindex(Base.tail(Base.tail(bc.args)), I)
48+
return _broadcast_getindex_evalf(bc.f, T, S, args...)
49+
end
50+
51+
# Utilities for _broadcast_getindex
52+
Base.@propagate_inbounds _getindex(args::Tuple, I) = (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
53+
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) = (_broadcast_getindex(args[1], I),)
54+
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
55+
56+
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) # not propagate_inbounds
57+
# ============================================================
58+
59+
#! format: on
60+
# Datalayouts
61+
@propagate_inbounds function Base.getindex(
62+
data::AbstractData{S},
63+
I::DataSpecificCartesianIndex,
64+
) where {S}
65+
@inbounds get_struct(parent(data), S, Val(field_dim(data)), I.I)
66+
end
67+
@propagate_inbounds function Base.setindex!(
68+
data::AbstractData{S},
69+
val,
70+
I::DataSpecificCartesianIndex,
71+
) where {S}
72+
@inbounds set_struct!(
73+
parent(data),
74+
convert(S, val),
75+
Val(field_dim(data)),
76+
I.I,
77+
)
78+
end
79+
80+
# Returns the size of the backing array.
81+
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} =
82+
(Nij, Nij, Nk, 1, Nv, Nh)
83+
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
84+
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
85+
@inline array_size(::DataF{S}) where {S} = (1,)
86+
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
87+
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
88+
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
89+
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
90+
(Nv, Nij, Nij, 1, Nh)
91+
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
92+
(Nv, Ni, 1, Nh)
93+
94+
#####
95+
##### Helpers to support `Base.checkbounds`
96+
#####
97+
98+
# Converts axes(::AbstractData) to a Data-specific axes
99+
@inline permute_axes(A, data::AbstractData) =
100+
map(x -> A[x], perm_to_array(data))
101+
102+
# axes for IJF and IF exclude the field dimension
103+
@inline permute_axes(A, ::IJF) = (A[1], A[2], Base.OneTo(1))
104+
@inline permute_axes(A, ::IF) = (A[1], Base.OneTo(1))
105+
106+
# Permute dimensions of size(data) (the universal size) to
107+
# output size of array for example, this should satisfy:
108+
# @test size(parent(data)) == map(size(data)[i], perm_to_array(data))
109+
@inline perm_to_array(::IJKFVH) = (1, 2, 3, 4, 5)
110+
@inline perm_to_array(::IJFH) = (1, 2, 3, 5)
111+
@inline perm_to_array(::IFH) = (1, 3, 5)
112+
@inline perm_to_array(::DataF) = (3,)
113+
@inline perm_to_array(::VF) = (4, 3)
114+
@inline perm_to_array(::VIJFH) = (4, 1, 2, 3, 5)
115+
@inline perm_to_array(::VIFH) = (4, 1, 3, 5)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
2+
x1 = first_datalayout_in_bc(args[1], rargs...)
3+
x1 isa AbstractData && return x1
4+
return first_datalayout_in_bc(Base.tail(args), rargs...)
5+
end
6+
7+
@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
8+
first_datalayout_in_bc(args[1], rargs...)
9+
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
10+
@inline first_datalayout_in_bc(x) = nothing
11+
@inline first_datalayout_in_bc(x::AbstractData) = x
12+
13+
@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
14+
first_datalayout_in_bc(bc.args)
15+
16+
@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
17+
truesofar &&
18+
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
19+
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)
20+
21+
@inline _has_uniform_datalayouts_args(
22+
truesofar,
23+
start,
24+
args::Tuple{Any},
25+
rargs...,
26+
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
27+
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
28+
truesofar
29+
30+
@inline function _has_uniform_datalayouts(
31+
truesofar,
32+
start,
33+
bc::Base.Broadcast.Broadcasted,
34+
)
35+
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
36+
end
37+
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
38+
@eval begin
39+
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
40+
end
41+
end
42+
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
43+
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar
44+
45+
"""
46+
has_uniform_datalayouts
47+
48+
Find the first datalayout in the broadcast expression (BCE),
49+
and compares against every other datalayout in the BCE. Returns
50+
- `true` if the broadcasted object has only a single kind of datalayout (e.g. VF,VF, VIJFH,VIJFH)
51+
- `false` if the broadcasted object has multiple kinds of datalayouts (e.g. VIJFH, VIFH)
52+
53+
Note: a broadcasted object can have different _types_,
54+
e.g., `VIFJH{Float64}` and `VIFJH{Tuple{Float64,Float64}}`
55+
but not different kinds, e.g., `VIFJH{Float64}` and `VF{Float64}`.
56+
"""
57+
function has_uniform_datalayouts end
58+
59+
@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
60+
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)
61+
62+
@inline has_uniform_datalayouts(bc::AbstractData) = true

0 commit comments

Comments
 (0)