Skip to content

Commit 96bb8f4

Browse files
Merge pull request #2053 from CliMA/ck/refactor_mapreduce
Reduce use of internals in cuda mapreduce
2 parents 67dd505 + 0f9ad8b commit 96bb8f4

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 15 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import ClimaCore.DataLayouts: AbstractDataSingleton
12
# To implement a single flexible mapreduce, let's define
23
# a `OnesArray` that has nothing, and always returns 1:
34
"""
    OnesArray{T, N} <: AbstractArray{T, N}

Field-less array type with element type `T` and `N` dimensions. It stores no
data; per the accompanying comment it is meant to always return 1 when read
(the indexing methods are defined elsewhere and are not shown here).
"""
struct OnesArray{T, N} <: AbstractArray{T, N}
end
@@ -38,6 +39,8 @@ function mapreduce_cuda(
3839
n_ops_on_load = cld(nitems, nthreads) == 1 ? 0 : 7
3940
effective_blksize = nthreads * (n_ops_on_load + 1)
4041
nblocks = cld(nitems, effective_blksize)
42+
s = DataLayouts.singleton(data)
43+
us = DataLayouts.UniversalSize(data)
4144

4245
reduce_cuda = CuArray{T}(undef, nblocks, Nf)
4346
shmemsize = nthreads
@@ -49,6 +52,8 @@ function mapreduce_cuda(
4952
pdata,
5053
pwt,
5154
n_ops_on_load,
55+
s,
56+
us,
5257
Val(shmemsize),
5358
)
5459
# reduce block data
@@ -71,19 +76,22 @@ function mapreduce_cuda_kernel!(
7176
pdata::AbstractArray{T, N},
7277
pwt::AbstractArray{T, N},
7378
n_ops_on_load::Int,
79+
s::AbstractDataSingleton,
80+
us::DataLayouts.UniversalSize,
7481
::Val{shmemsize},
7582
) where {T, N, shmemsize}
7683
blksize = blockDim().x
7784
nblk = gridDim().x
7885
tidx = threadIdx().x
7986
bidx = blockIdx().x
8087
fidx = blockIdx().y
81-
dataview = _dataview(pdata, fidx)
88+
dataview = _dataview(pdata, s, fidx)
8289
effective_blksize = blksize * (n_ops_on_load + 1)
8390
gidx = _get_gidx(tidx, bidx, effective_blksize)
8491
reduction = CUDA.CuStaticSharedArray(T, shmemsize)
8592
reduction[tidx] = 0
86-
(Nv, Nij, Nf, Nh) = _get_dims(dataview)
93+
(Nij, _, _, Nv, Nh) = DataLayouts.universal_size(us)
94+
Nf = 1 # a view into `fidx` always gives a size of Nf = 1
8795
nitems = Nv * Nij * Nij * Nf * Nh
8896

8997
# load shmem
@@ -107,29 +115,13 @@ end
107115
# Compute the global linear index for thread `tidx` of block `bidx`, where
# consecutive blocks are spaced `effective_blksize` items apart and block
# `bidx == 1` starts at offset 0.
@inline function _get_gidx(tidx, bidx, effective_blksize)
    block_offset = (bidx - 1) * effective_blksize
    return tidx + block_offset
end
110-
# for VF DataLayout
111-
@inline function _get_dims(pdata::AbstractArray{FT, 2}) where {FT}
112-
(Nv, Nf) = size(pdata)
113-
return (Nv, 1, Nf, 1)
114-
end
115-
@inline _dataview(pdata::AbstractArray{FT, 2}, fidx) where {FT} =
116-
view(pdata, :, fidx:fidx)
117-
118-
# for IJFH DataLayout
119-
@inline function _get_dims(pdata::AbstractArray{FT, 4}) where {FT}
120-
(Nij, _, Nf, Nh) = size(pdata)
121-
return (1, Nij, Nf, Nh)
122-
end
123-
@inline _dataview(pdata::AbstractArray{FT, 4}, fidx) where {FT} =
124-
view(pdata, :, :, fidx:fidx, :)
125118

126-
# for VIJFH DataLayout
127-
@inline function _get_dims(pdata::AbstractArray{FT, 5}) where {FT}
128-
(Nv, Nij, _, Nf, Nh) = size(pdata)
129-
return (Nv, Nij, Nf, Nh)
119+
# Return a view of `pdata` restricted to the single field index `fidx`.
#
# The position of the field dimension is taken from the layout singleton `s`
# via `DataLayouts.field_dim`; every other dimension is selected in full.
# Indexing with the range `fidx:fidx` (rather than the scalar `fidx`) keeps the
# field dimension in the view with length 1, so the view has the same number
# of dimensions as `pdata`.
@inline function _dataview(pdata::AbstractArray, s::AbstractDataSingleton, fidx)
    field_axis = DataLayouts.field_dim(s)
    # Colons for the dimensions before and after the field dimension.
    leading = ntuple(_ -> Colon(), Val(field_axis - 1))
    trailing = ntuple(_ -> Colon(), Val(ndims(pdata) - field_axis))
    field_slice = fidx:fidx
    return @inbounds view(pdata, leading..., field_slice, trailing...)
end
131-
@inline _dataview(pdata::AbstractArray{FT, 5}, fidx) where {FT} =
132-
view(pdata, :, :, :, fidx:fidx, :)
133125

134126
@inline function _cuda_reduce!(op, reduction, tidx, reduction_size, N)
135127
if reduction_size > N

0 commit comments

Comments
 (0)