Skip to content

Commit 93194ba

Browse files
Merge pull request #1837 from CliMA/ck/gpu_kernel_tweaks
Some minor improvement to cuda kernel utils
2 parents 2b4ecc8 + 4b48df1 commit 93194ba

File tree

6 files changed

+65
-44
lines changed

6 files changed

+65
-44
lines changed

ext/cuda/cuda_utils.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,38 +97,44 @@ function auto_launch!(
9797
end
9898

9999
"""
100-
kernel_indexes(n)
100+
thread_index()
101+
102+
Return the thread index:
103+
```
104+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
105+
```
106+
"""
107+
@inline thread_index() =
108+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
109+
110+
"""
111+
kernel_indexes(tidx, n)
101112
Return a tuple of indexes from the kernel,
102-
where `n` is a tuple of max lengths along each
113+
where `tidx` is the cuda thread index and
114+
`n` is a tuple of max lengths along each
103115
dimension of the accessed data.
104116
"""
105-
function kernel_indexes(n::Tuple)
106-
tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
107-
inds = if 1 ≤ tidx ≤ prod(n)
108-
CartesianIndices(map(x -> Base.OneTo(x), n))[tidx].I
109-
else
110-
ntuple(x -> -1, length(n))
111-
end
112-
return inds
113-
end
117+
Base.@propagate_inbounds kernel_indexes(tidx, n::Tuple) =
118+
CartesianIndices(map(x -> Base.OneTo(x), n))[tidx]
114119

115120
"""
116-
valid_range(inds, n)
121+
valid_range(tidx, n::Int)
122+
117123
Returns a `Bool` indicating if the thread index
118-
is in the valid range, based on `inds` (the result
119-
of `kernel_indexes`) and `n`, a tuple of max lengths
120-
along each dimension of the accessed data.
124+
(`tidx`) is in the valid range, based on `n`, the
125+
total number of accessed elements (the product of the
126+
127+
max lengths along each dimension of the accessed data).
121128
```julia
122129
function kernel!(data, n)
123-
inds = kernel_indexes(n)
124-
if valid_range(inds, n)
125-
do_work!(data[inds...])
130+
@inbounds begin
131+
tidx = thread_index()
132+
if valid_range(tidx, n)
133+
I = kernel_indexes(tidx, n)
134+
do_work!(data[I])
135+
end
126136
end
127137
end
128138
```
129139
"""
130-
valid_range(inds::NTuple, n::Tuple) = all(i -> 1 ≤ inds[i] ≤ n[i], 1:length(n))
131-
function valid_range(n::Tuple)
132-
inds = kernel_indexes(n)
133-
return all(i -> 1 ≤ inds[i] ≤ n[i], 1:length(n))
134-
end
140+
@inline valid_range(tidx, n) = 1 ≤ tidx ≤ n

ext/cuda/fill.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
cartesian_index(::AbstractData, inds) = CartesianIndex(inds)
2-
31
function knl_fill_flat!(dest::AbstractData, val)
4-
n = DataLayouts.universal_size(dest)
5-
inds = kernel_indexes(n)
6-
if valid_range(inds, n)
7-
I = cartesian_index(dest, inds)
8-
@inbounds dest[I] = val
2+
@inbounds begin
3+
tidx = thread_index()
4+
n = DataLayouts.universal_size(dest)
5+
if valid_range(tidx, prod(n))
6+
I = kernel_indexes(tidx, n)
7+
@inbounds dest[I] = val
8+
end
99
end
1010
return nothing
1111
end

ext/cuda/limiters.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ function compute_element_bounds_kernel!(
5454
::Val{Nj},
5555
) where {Ni, Nj}
5656
n = (Nv, Nh)
57-
if valid_range(n)
58-
inds = kernel_indexes(n)
59-
(v, h) = inds
57+
tidx = thread_index()
58+
@inbounds if valid_range(tidx, prod(n))
59+
(v, h) = kernel_indexes(tidx, n).I
6060
(; q_bounds) = limiter
6161
local q_min, q_max
6262
slab_ρq = slab(ρq, v, h)
@@ -118,9 +118,9 @@ function compute_neighbor_bounds_local_kernel!(
118118
) where {Ni, Nj}
119119

120120
n = (Nv, Nh)
121-
if valid_range(n)
122-
inds = kernel_indexes(n)
123-
(v, h) = inds
121+
tidx = thread_index()
122+
@inbounds if valid_range(tidx, prod(n))
123+
(v, h) = kernel_indexes(tidx, n).I
124124
(; q_bounds, q_bounds_nbr, ghost_buffer, rtol) = limiter
125125
slab_q_bounds = slab(q_bounds, v, h)
126126
q_min = slab_q_bounds[1]
@@ -188,9 +188,9 @@ function apply_limiter_kernel!(
188188
(; q_bounds_nbr, rtol) = limiter
189189
converged = true
190190
n = (Nv, Nh)
191-
if valid_range(n)
192-
inds = kernel_indexes(n)
193-
(v, h) = inds
191+
tidx = thread_index()
192+
@inbounds if valid_range(tidx, prod(n))
193+
(v, h) = kernel_indexes(tidx, n).I
194194

195195
slab_ρ = slab(ρ_data, v, h)
196196
slab_ρq = slab(ρq_data, v, h)

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ function multiple_field_solve_kernel!(
8585
) where {Nnames}
8686
@inbounds begin
8787
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
88-
tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
89-
if 1 ≤ tidx ≤ prod((Ni, Nj, Nh, Nnames))
90-
(i, j, h, iname) =
91-
CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
88+
tidx = thread_index()
89+
n = (Ni, Nj, Nh, Nnames)
90+
if valid_range(tidx, prod(n))
91+
(i, j, h, iname) = kernel_indexes(tidx, n).I
9292
generated_single_field_solve!(
9393
device,
9494
caches,

src/DataLayouts/DataLayouts.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ abstract type AbstractData{S} end
4646

4747
Base.size(data::AbstractData, i::Integer) = size(data)[i]
4848

49+
"""
50+
(Ni, Nj, Nf, Nv, Nh) = universal_size(data::AbstractData)
51+
52+
Returns dimensions in a universal
53+
format for all data layouts:
54+
- `Ni` number of spectral element nodal degrees of freedom in first horizontal direction
55+
- `Nj` number of spectral element nodal degrees of freedom in second horizontal direction
56+
- `Nf` number of field components
57+
- `Nv` number of vertical degrees of freedom
58+
- `Nh` number of horizontal elements
59+
60+
Note: this is similar to `Base.size`, except
61+
that `universal_size` does not return 1
62+
for the number of field components.
63+
"""
4964
function universal_size(data::AbstractData)
5065
s = size(data)
5166
return (s[1], s[2], ncomponents(data), s[4], s[5])

test/Spaces/distributed_cuda/ddss3.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,5 @@ partition numbers
147147
end
148148
#! format: on
149149
p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
150-
iamroot && @test p ≤ 10816
150+
iamroot && @test p ≤ 11072
151151
end

0 commit comments

Comments
 (0)