1
+ import ClimaCore. DataLayouts: AbstractDataSingleton
1
2
# To implement a single flexible mapreduce, let's define
2
3
# a `OnesArray` that has nothing, and always returns 1:
3
4
struct OnesArray{T, N} <: AbstractArray{T, N} end
@@ -38,6 +39,8 @@ function mapreduce_cuda(
38
39
n_ops_on_load = cld (nitems, nthreads) == 1 ? 0 : 7
39
40
effective_blksize = nthreads * (n_ops_on_load + 1 )
40
41
nblocks = cld (nitems, effective_blksize)
42
+ s = DataLayouts. singleton (data)
43
+ us = DataLayouts. UniversalSize (data)
41
44
42
45
reduce_cuda = CuArray {T} (undef, nblocks, Nf)
43
46
shmemsize = nthreads
@@ -49,6 +52,8 @@ function mapreduce_cuda(
49
52
pdata,
50
53
pwt,
51
54
n_ops_on_load,
55
+ s,
56
+ us,
52
57
Val (shmemsize),
53
58
)
54
59
# reduce block data
@@ -71,19 +76,22 @@ function mapreduce_cuda_kernel!(
71
76
pdata:: AbstractArray{T, N} ,
72
77
pwt:: AbstractArray{T, N} ,
73
78
n_ops_on_load:: Int ,
79
+ s:: AbstractDataSingleton ,
80
+ us:: DataLayouts.UniversalSize ,
74
81
:: Val{shmemsize} ,
75
82
) where {T, N, shmemsize}
76
83
blksize = blockDim (). x
77
84
nblk = gridDim (). x
78
85
tidx = threadIdx (). x
79
86
bidx = blockIdx (). x
80
87
fidx = blockIdx (). y
81
- dataview = _dataview (pdata, fidx)
88
+ dataview = _dataview (pdata, s, fidx)
82
89
effective_blksize = blksize * (n_ops_on_load + 1 )
83
90
gidx = _get_gidx (tidx, bidx, effective_blksize)
84
91
reduction = CUDA. CuStaticSharedArray (T, shmemsize)
85
92
reduction[tidx] = 0
86
- (Nv, Nij, Nf, Nh) = _get_dims (dataview)
93
+ (Nij, _, _, Nv, Nh) = DataLayouts. universal_size (us)
94
+ Nf = 1 # a view into `fidx` always gives a size of Nf = 1
87
95
nitems = Nv * Nij * Nij * Nf * Nh
88
96
89
97
# load shmem
@@ -107,29 +115,13 @@ end
107
115
@inline function _get_gidx (tidx, bidx, effective_blksize)
108
116
return tidx + (bidx - 1 ) * effective_blksize
109
117
end
110
- # for VF DataLayout
111
- @inline function _get_dims (pdata:: AbstractArray{FT, 2} ) where {FT}
112
- (Nv, Nf) = size (pdata)
113
- return (Nv, 1 , Nf, 1 )
114
- end
115
- @inline _dataview (pdata:: AbstractArray{FT, 2} , fidx) where {FT} =
116
- view (pdata, :, fidx: fidx)
117
-
118
- # for IJFH DataLayout
119
- @inline function _get_dims (pdata:: AbstractArray{FT, 4} ) where {FT}
120
- (Nij, _, Nf, Nh) = size (pdata)
121
- return (1 , Nij, Nf, Nh)
122
- end
123
- @inline _dataview (pdata:: AbstractArray{FT, 4} , fidx) where {FT} =
124
- view (pdata, :, :, fidx: fidx, :)
125
118
126
- # for VIJFH DataLayout
127
- @inline function _get_dims (pdata:: AbstractArray{FT, 5} ) where {FT}
128
- (Nv, Nij, _, Nf, Nh) = size (pdata)
129
- return (Nv, Nij, Nf, Nh)
119
+ @inline function _dataview (pdata:: AbstractArray , s:: AbstractDataSingleton , fidx)
120
+ fdim = DataLayouts. field_dim (s)
121
+ Ipre = ntuple (i -> Colon (), Val (fdim - 1 ))
122
+ Ipost = ntuple (i -> Colon (), Val (ndims (pdata) - fdim))
123
+ return @inbounds view (pdata, Ipre... , fidx: fidx, Ipost... )
130
124
end
131
- @inline _dataview (pdata:: AbstractArray{FT, 5} , fidx) where {FT} =
132
- view (pdata, :, :, :, fidx: fidx, :)
133
125
134
126
@inline function _cuda_reduce! (op, reduction, tidx, reduction_size, N)
135
127
if reduction_size > N
0 commit comments