Skip to content

Commit 81645c9

Browse files
Tune kernels for use with FastCartesianIndices (#2296)
1 parent f2c91ee commit 81645c9

10 files changed

+176
-407
lines changed

ext/cuda/data_layouts.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ Base.similar(
2727
dims::Dims{N},
2828
) where {T, N, B} = similar(CUDA.CuArray{T, N, B}, dims)
2929

30+
unval(::Val{CI}) where {CI} = CI
31+
unval(CI) = CI
32+
33+
@inline linear_thread_idx() =
34+
threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
35+
3036
include("data_layouts_fill.jl")
3137
include("data_layouts_copyto.jl")
3238
include("data_layouts_fused_copyto.jl")

ext/cuda/data_layouts_copyto.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()
22

3-
function knl_copyto!(dest, src, us, mask)
4-
I = if mask isa NoMask
5-
universal_index(dest)
6-
else
7-
masked_universal_index(mask)
8-
end
9-
if is_valid_index(dest, I, us)
3+
function knl_copyto!(dest, src, us, mask, cart_inds)
4+
tidx = linear_thread_idx()
5+
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
6+
I = if mask isa NoMask
7+
unval(cart_inds)[tidx]
8+
else
9+
masked_universal_index(mask, cart_inds)
10+
end
1011
@inbounds dest[I] = src[I]
1112
end
1213
return nothing
1314
end
1415

1516
function knl_copyto_linear!(dest, src, us)
16-
i = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
17+
i = linear_thread_idx()
1718
if linear_is_valid_index(i, us)
1819
@inbounds dest[i] = src[i]
1920
end
@@ -32,13 +33,18 @@ if VERSION ≥ v"1.11.0-beta"
3233
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
3334
us = DataLayouts.UniversalSize(dest)
3435
if Nv > 0 && Nh > 0
35-
args = (dest, bc, us, mask)
36+
cart_inds = if mask isa NoMask
37+
cartesian_indices(us)
38+
else
39+
cartesian_indicies_mask(us, mask)
40+
end
41+
args = (dest, bc, us, mask, cart_inds)
3642
threads = threads_via_occupancy(knl_copyto!, args)
3743
n_max_threads = min(threads, get_N(us))
3844
p = if mask isa NoMask
39-
partition(dest, n_max_threads)
45+
linear_partition(prod(size(dest)), n_max_threads)
4046
else
41-
masked_partition(us, n_max_threads, mask)
47+
masked_partition(mask, n_max_threads, us)
4248
end
4349
auto_launch!(
4450
knl_copyto!,
@@ -72,13 +78,18 @@ else
7278
blocks_s = p.blocks,
7379
)
7480
else
75-
args = (dest, bc, us, mask)
81+
cart_inds = if mask isa NoMask
82+
cartesian_indices(us)
83+
else
84+
cartesian_indicies_mask(us, mask)
85+
end
86+
args = (dest, bc, us, mask, cart_inds)
7687
threads = threads_via_occupancy(knl_copyto!, args)
7788
n_max_threads = min(threads, get_N(us))
7889
p = if mask isa NoMask
79-
partition(dest, n_max_threads)
90+
linear_partition(prod(size(dest)), n_max_threads)
8091
else
81-
masked_partition(us, n_max_threads, mask)
92+
masked_partition(mask, n_max_threads, us)
8293
end
8394
auto_launch!(
8495
knl_copyto!,

ext/cuda/data_layouts_fill.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
1-
function knl_fill!(dest, val, us, mask)
2-
I = if mask isa NoMask
3-
universal_index(dest)
4-
else
5-
masked_universal_index(mask)
6-
end
7-
if is_valid_index(dest, I, us)
1+
function knl_fill!(dest, val, us, mask, cart_inds)
2+
tidx = linear_thread_idx()
3+
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
4+
I = if mask isa NoMask
5+
unval(cart_inds)[tidx]
6+
else
7+
masked_universal_index(mask, cart_inds)
8+
end
89
@inbounds dest[I] = val
910
end
1011
return nothing
1112
end
1213

1314
function knl_fill_linear!(dest, val, us)
14-
i = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
15+
i = linear_thread_idx()
1516
if linear_is_valid_index(i, us)
1617
@inbounds dest[i] = val
1718
end
1819
return nothing
1920
end
2021

2122
function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
22-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
23+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(dest)
2324
us = DataLayouts.UniversalSize(dest)
2425
if Nv > 0 && Nh > 0
2526
if !(VERSION ≥ v"1.11.0-beta") &&
@@ -36,13 +37,18 @@ function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
3637
blocks_s = p.blocks,
3738
)
3839
else
39-
args = (dest, bc, us, mask)
40+
cart_inds = if mask isa NoMask
41+
cartesian_indices(us)
42+
else
43+
cartesian_indicies_mask(us, mask)
44+
end
45+
args = (dest, bc, us, mask, cart_inds)
4046
threads = threads_via_occupancy(knl_fill!, args)
4147
n_max_threads = min(threads, get_N(us))
4248
p = if mask isa NoMask
43-
partition(dest, n_max_threads)
49+
linear_partition(prod(size(dest)), n_max_threads)
4450
else
45-
masked_partition(us, n_max_threads, mask)
51+
masked_partition(mask, n_max_threads, us)
4652
end
4753
auto_launch!(
4854
knl_fill!,

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,44 @@
11
Base.@propagate_inbounds function rcopyto_at!(
22
pair::Pair{<:AbstractData, <:Any},
3-
I,
3+
cart_inds,
4+
tidx,
45
us,
56
)
67
dest, bc = pair.first, pair.second
7-
if is_valid_index(dest, I, us)
8+
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
9+
I = unval(cart_inds)[tidx]
810
dest[I] = isascalar(bc) ? bc[] : bc[I]
911
end
1012
return nothing
1113
end
12-
Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, us)
14+
Base.@propagate_inbounds function rcopyto_at!(
15+
pair::Pair{<:DataF, <:Any},
16+
cart_inds,
17+
tidx,
18+
us,
19+
)
1320
dest, bc = pair.first, pair.second
14-
if is_valid_index(dest, I, us)
21+
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
22+
I = unval(cart_inds)[tidx]
1523
bcI = isascalar(bc) ? bc[] : bc[I]
1624
dest[] = bcI
1725
end
1826
return nothing
1927
end
20-
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, us)
21-
rcopyto_at!(first(pairs), I, us)
22-
rcopyto_at!(Base.tail(pairs), I, us)
28+
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, cart_inds, tidx, us)
29+
rcopyto_at!(first(pairs), cart_inds, tidx, us)
30+
rcopyto_at!(Base.tail(pairs), cart_inds, tidx, us)
2331
end
24-
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, us) =
25-
rcopyto_at!(first(pairs), I, us)
26-
@inline rcopyto_at!(pairs::Tuple{}, I, us) = nothing
32+
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, cart_inds, tidx, us) =
33+
rcopyto_at!(first(pairs), cart_inds, tidx, us)
34+
@inline rcopyto_at!(pairs::Tuple{}, cart_inds, tidx, us) = nothing
2735

28-
function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us)
36+
function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us, cart_inds)
2937
@inbounds begin
30-
I = universal_index(dest1)
31-
if is_valid_index(dest1, I, us)
38+
tidx = linear_thread_idx()
39+
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
3240
(; pairs) = fmbc
33-
rcopyto_at!(pairs, I, us)
41+
rcopyto_at!(pairs, cart_inds, tidx, us)
3442
end
3543
end
3644
return nothing
@@ -138,10 +146,11 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast)
138146
blocks_s = p.blocks,
139147
)
140148
else
141-
args = (fmb, dest1, us)
149+
cart_inds = cartesian_indices(us)
150+
args = (fmb, dest1, us, cart_inds)
142151
threads = threads_via_occupancy(knl_fused_copyto!, args)
143152
n_max_threads = min(threads, get_N(us))
144-
p = partition(dest1, n_max_threads)
153+
p = linear_partition(prod(size(dest1)), n_max_threads)
145154
auto_launch!(
146155
knl_fused_copyto!,
147156
args;

0 commit comments

Comments
 (0)