Skip to content

Commit 6cce85f

Browse files
Add linear index support for pointwise kernels
1 parent c2451ad commit 6cce85f

17 files changed

+767
-164
lines changed

benchmarks/scripts/indexing_and_static_ndranges.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ setting where linear indexing is allowed.
3030
nearly the same benefit as linear indexing.
3131
3232
# References:
33-
- https://githubSR.com/CliMA/ClimaCore.jl/issues/1889
34-
- https://githubSR.com/JuliaLang/julia/issues/28126
35-
- https://githubSR.com/JuliaLang/julia/issues/32051
33+
- https://github.com/CliMA/ClimaCore.jl/issues/1889
34+
- https://github.com/JuliaLang/julia/issues/28126
35+
- https://github.com/JuliaLang/julia/issues/32051
3636
3737
# Benchmark results:
3838

ext/cuda/data_layouts.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,16 @@ function Adapt.adapt_structure(
5353
end,
5454
)
5555
end
56+
57+
import Adapt
58+
import CUDA
59+
"""
    Adapt.adapt_structure(to::CUDA.KernelAdaptor, bc::DataLayouts.NonExtrudedBroadcasted)

Rebuild `bc` for the CUDA kernel adaptor `to`. The broadcast function is
converted with `adapt_f`, the arguments and axes with `Adapt.adapt`, and the
`Style` type parameter is carried over unchanged.
"""
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
    adapted_f = adapt_f(to, bc.f)
    adapted_args = Adapt.adapt(to, bc.args)
    adapted_axes = Adapt.adapt(to, bc.axes)
    return DataLayouts.NonExtrudedBroadcasted{Style}(
        adapted_f,
        adapted_args,
        adapted_axes,
    )
end

ext/cuda/data_layouts_copyto.jl

Lines changed: 28 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,9 @@
1+
import ClimaCore.DataLayouts:
2+
to_non_extruded_broadcasted, has_uniform_datalayouts
13
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
24

3-
function knl_copyto!(dest, src)
4-
5-
i = CUDA.threadIdx().x
6-
j = CUDA.threadIdx().y
7-
8-
h = CUDA.blockIdx().x
9-
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
10-
11-
if v <= size(dest, 4)
12-
I = CartesianIndex((i, j, 1, v, h))
13-
@inbounds dest[I] = src[I]
14-
end
15-
return nothing
16-
end
17-
18-
function Base.copyto!(
19-
dest::IJFH{S, Nij, Nh},
20-
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
21-
::ToCUDA,
22-
) where {S, Nij, Nh}
23-
if Nh > 0
24-
auto_launch!(
25-
knl_copyto!,
26-
(dest, bc),
27-
dest;
28-
threads_s = (Nij, Nij),
29-
blocks_s = (Nh, 1),
30-
)
31-
end
32-
return dest
33-
end
34-
35-
function Base.copyto!(
36-
dest::VIJFH{S, Nv, Nij, Nh},
37-
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
38-
::ToCUDA,
39-
) where {S, Nv, Nij, Nh}
40-
if Nv > 0 && Nh > 0
41-
Nv_per_block = min(Nv, fld(256, Nij * Nij))
42-
Nv_blocks = cld(Nv, Nv_per_block)
43-
auto_launch!(
44-
knl_copyto!,
45-
(dest, bc),
46-
dest;
47-
threads_s = (Nij, Nij, Nv_per_block),
48-
blocks_s = (Nh, Nv_blocks),
49-
)
50-
end
51-
return dest
52-
end
53-
54-
function Base.copyto!(
55-
dest::VF{S, Nv},
56-
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
57-
::ToCUDA,
58-
) where {S, Nv}
59-
if Nv > 0
60-
auto_launch!(
61-
knl_copyto!,
62-
(dest, bc),
63-
dest;
64-
threads_s = (1, 1),
65-
blocks_s = (1, Nv),
66-
)
67-
end
68-
return dest
69-
end
70-
71-
function Base.copyto!(
72-
dest::DataF{S},
73-
bc::DataLayouts.BroadcastedUnionDataF{S},
74-
::ToCUDA,
75-
) where {S}
76-
auto_launch!(
77-
knl_copyto!,
78-
(dest, bc),
79-
dest;
80-
threads_s = (1, 1),
81-
blocks_s = (1, 1),
82-
)
83-
return dest
84-
end
85-
865
import ClimaCore.DataLayouts: isascalar
87-
function knl_copyto_flat!(dest::AbstractData, bc, us)
6+
function knl_copyto_cart!(dest::AbstractData, bc, us)
887
@inbounds begin
898
tidx = thread_index()
909
if tidx <= get_N(us)
@@ -96,24 +15,43 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
9615
return nothing
9716
end
9817

18+
"""
    knl_copyto_linear!(dest::AbstractData, bc, us)

Kernel body that copies `bc` into `dest` with linear indexing: the index
returned by `thread_index()` is used to read `bc` and to write `dest`.
Only indices up to `get_N(us)` are written. Always returns `nothing`.
"""
function knl_copyto_linear!(dest::AbstractData, bc, us)
    @inbounds begin
        tidx = thread_index()
        # The comparison operator is required here; without it the guard does
        # not parse. Indices past `get_N(us)` are skipped.
        if tidx <= get_N(us)
            dest[tidx] = bc[tidx]
        end
    end
    return nothing
end
27+
28+
"""
    knl_copyto_linear!(dest::DataF, bc, us)

Kernel body for the `DataF` case: writes the value of `bc` at the current
linear thread index into `dest[]`. Always returns `nothing`.
"""
function knl_copyto_linear!(dest::DataF, bc, us)
    # `tidx` must be defined locally; the previous body read it without ever
    # assigning it, which is an undefined-variable error.
    tidx = thread_index()
    # Same guard as the `AbstractData` method, so that an index beyond
    # `get_N(us)` never reads `bc` out of range under `@inbounds`.
    if tidx <= get_N(us)
        @inbounds dest[] = bc[tidx]
    end
    return nothing
end
32+
9933
function cuda_copyto!(dest::AbstractData, bc)
10034
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
35+
(Nv > 0 && Nh > 0) || return dest
10136
us = DataLayouts.UniversalSize(dest)
102-
if Nv > 0 && Nh > 0
103-
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
37+
if has_uniform_datalayouts(bc)
38+
bc′ = to_non_extruded_broadcasted(bc)
39+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), dest; auto = true)
40+
else
41+
auto_launch!(knl_copyto_cart!, (dest, bc, us), dest; auto = true)
10442
end
10543
return dest
10644
end
10745

10846
# TODO: can we use CUDA's launch configuration for all data layouts?
10947
# Currently, it seems to have a slight performance degradation.
11048
#! format: off
111-
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
49+
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
11250
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
11351
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
11452
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
11553
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
116-
# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
117-
# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
118-
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
54+
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
55+
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
56+
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
11957
#! format: on

ext/cuda/data_layouts_fill.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ function knl_fill_flat!(dest::AbstractData, val, us)
22
@inbounds begin
33
tidx = thread_index()
44
if tidx <= get_N(us)
5-
n = size(dest)
6-
I = kernel_indexes(tidx, n)
7-
@inbounds dest[I] = val
5+
@inbounds dest[tidx] = val
86
end
97
end
108
return nothing
119
end
1210

11+
# `DataF` method of the fill kernel: stores `val` through zero-index
# `setindex!` (the call form of `dest[] = val`) and ignores `us`.
function knl_fill_flat!(dest::DataF, val, us)
    @inbounds setindex!(dest, val)
    return nothing
end
15+
1316
function cuda_fill!(dest::AbstractData, val)
1417
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1518
us = DataLayouts.UniversalSize(dest)

src/DataLayouts/DataLayouts.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,27 @@ empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
11941194
@inline get_Nij(::IJF{S, Nij}) where {S, Nij} = Nij
11951195
@inline get_Nij(::IF{S, Nij}) where {S, Nij} = Nij
11961196

1197+
# Returns the size of the backing array, one method per datalayout.
# The sizes are read from the layout's type parameters; the slot that
# `farray_size` fills with `ncomponents(data)` is fixed to 1 here.
@inline function array_size(
    ::IJKFVH{S, Nij, Nk, Nv, Nh},
) where {S, Nij, Nk, Nv, Nh}
    return (Nij, Nij, Nk, 1, Nv, Nh)
end
@inline function array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh}
    return (Nij, Nij, 1, Nh)
end
@inline function array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh}
    return (Ni, 1, Nh)
end
@inline function array_size(::DataF{S}) where {S}
    return (1,)
end
@inline function array_size(::IJF{S, Nij}) where {S, Nij}
    return (Nij, Nij, 1)
end
@inline function array_size(::IF{S, Ni}) where {S, Ni}
    return (Ni, 1)
end
@inline function array_size(::VF{S, Nv}) where {S, Nv}
    return (Nv, 1)
end
@inline function array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh}
    return (Nv, Nij, Nij, 1, Nh)
end
@inline function array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh}
    return (Nv, Ni, 1, Nh)
end
1207+
1208+
# Same shapes as `array_size`, except that the slot `array_size` sets to 1
# holds `ncomponents(data)`.
@inline function farray_size(
    data::IJKFVH{S, Nij, Nk, Nv, Nh},
) where {S, Nij, Nk, Nv, Nh}
    return (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
end
@inline function farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh}
    return (Nij, Nij, ncomponents(data), Nh)
end
@inline function farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh}
    return (Ni, ncomponents(data), Nh)
end
@inline function farray_size(data::DataF{S}) where {S}
    return (ncomponents(data),)
end
@inline function farray_size(data::IJF{S, Nij}) where {S, Nij}
    return (Nij, Nij, ncomponents(data))
end
@inline function farray_size(data::IF{S, Ni}) where {S, Ni}
    return (Ni, ncomponents(data))
end
@inline function farray_size(data::VF{S, Nv}) where {S, Nv}
    return (Nv, ncomponents(data))
end
@inline function farray_size(
    data::VIJFH{S, Nv, Nij, Nh},
) where {S, Nv, Nij, Nh}
    return (Nv, Nij, Nij, ncomponents(data), Nh)
end
@inline function farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh}
    return (Nv, Ni, ncomponents(data), Nh)
end
1217+
11971218
"""
11981219
field_dim(data::AbstractData)
11991220
field_dim(::Type{<:AbstractData})
@@ -1350,9 +1371,11 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
13501371
_device_dispatch(x::SArray) = ToCPU()
13511372
_device_dispatch(x::MArray) = ToCPU()
13521373

1374+
include("non_extruded_broadcasted.jl")
13531375
include("copyto.jl")
13541376
include("fused_copyto.jl")
13551377
include("fill.jl")
13561378
include("mapreduce.jl")
1379+
include("has_uniform_datalayouts.jl")
13571380

13581381
end # module

src/DataLayouts/broadcast.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
7373
#####
7474

7575
#! format: off
76+
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
7677
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
7778
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
7879
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}

src/DataLayouts/copyto.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@
22
##### Dispatching and edge cases
33
#####
44

5-
Base.copyto!(
6-
dest::AbstractData,
5+
function Base.copyto!(
6+
dest::AbstractData{S},
77
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
8-
) = Base.copyto!(dest, bc, device_dispatch(dest))
8+
) where {S}
9+
dev = device_dispatch(dest)
10+
if dev isa ToCPU && has_uniform_datalayouts(bc) && !(dest isa DataF)
11+
# Specialize on linear indexing case:
12+
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
13+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
14+
dest[I] = convert(S, bc′[I])
15+
end
16+
else
17+
Base.copyto!(dest, bc, device_dispatch(dest))
18+
end
19+
return dest
20+
end
921

1022
# Specialize on non-Broadcasted objects
1123
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}

src/DataLayouts/fill.jl

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,13 @@
1-
function Base.fill!(data::IJFH, val, ::ToCPU)
2-
(_, _, _, _, Nh) = size(data)
3-
@inbounds for h in 1:Nh
4-
fill!(slab(data, h), val)
1+
function Base.fill!(dest::AbstractData, val, ::ToCPU)
2+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
3+
dest[I] = val
54
end
6-
return data
5+
return dest
76
end
87

9-
function Base.fill!(data::IFH, val, ::ToCPU)
10-
(_, _, _, _, Nh) = size(data)
11-
@inbounds for h in 1:Nh
12-
fill!(slab(data, h), val)
13-
end
14-
return data
15-
end
16-
17-
function Base.fill!(data::DataF, val, ::ToCPU)
18-
@inbounds data[] = val
19-
return data
20-
end
21-
22-
function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
23-
@inbounds for j in 1:Nij, i in 1:Nij
24-
data[CartesianIndex(i, j, 1, 1, 1)] = val
25-
end
26-
return data
27-
end
28-
29-
function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
30-
@inbounds for i in 1:Ni
31-
data[CartesianIndex(i, 1, 1, 1, 1)] = val
32-
end
33-
return data
34-
end
35-
36-
function Base.fill!(data::VF, val, ::ToCPU)
37-
Nv = nlevels(data)
38-
@inbounds for v in 1:Nv
39-
data[CartesianIndex(1, 1, 1, v, 1)] = val
40-
end
41-
return data
42-
end
43-
44-
function Base.fill!(data::VIJFH, val, ::ToCPU)
45-
(Ni, Nj, _, Nv, Nh) = size(data)
46-
@inbounds for h in 1:Nh, v in 1:Nv
47-
fill!(slab(data, v, h), val)
48-
end
49-
return data
50-
end
51-
52-
function Base.fill!(data::VIFH, val, ::ToCPU)
53-
(Ni, _, _, Nv, Nh) = size(data)
54-
@inbounds for h in 1:Nh, v in 1:Nv
55-
fill!(slab(data, v, h), val)
56-
end
57-
return data
8+
function Base.fill!(dest::DataF, val, ::ToCPU)
9+
@inbounds dest[] = val
10+
return dest
5811
end
5912

6013
Base.fill!(dest::AbstractData, val) =
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Search a broadcast expression depth-first, left to right, and return the
# first `AbstractData` encountered, or `nothing` when there is none.

# Leaves: a datalayout is returned as is, anything else yields `nothing`.
@inline first_datalayout_in_bc(x::AbstractData) = x
@inline first_datalayout_in_bc(x) = nothing

# A `Broadcasted` node is searched through its argument tuple.
@inline function first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted)
    return first_datalayout_in_bc(bc.args)
end

# Argument tuples: empty, exactly one element, and the general case.
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline function first_datalayout_in_bc(args::Tuple{Any}, rargs...)
    return first_datalayout_in_bc(first(args), rargs...)
end
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
    candidate = first_datalayout_in_bc(first(args), rargs...)
    if candidate isa AbstractData
        return candidate
    end
    # Nothing found in the head; continue with the remaining arguments.
    return first_datalayout_in_bc(Base.tail(args), rargs...)
end
15+
16+
# Walk an argument tuple and check each element against `start` with
# `_has_uniform_datalayouts`, stopping at the first `false`.

# General case: check the head, then recurse on the tail.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    start,
    args::Tuple,
    rargs...,
)
    truesofar || return truesofar
    head_ok = _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
    head_ok || return head_ok
    return _has_uniform_datalayouts_args(
        truesofar,
        start,
        Base.tail(args),
        rargs...,
    )
end

# Single-element tuple: only the head needs checking.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    start,
    args::Tuple{Any},
    rargs...,
)
    truesofar || return truesofar
    return _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
end

# Empty tuple: nothing left to check, so the running result is the answer.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    _,
    args::Tuple{},
    rargs...,
)
    return truesofar
end
29+
30+
# Nested broadcast: descend into its arguments unless the check already failed.
@inline function _has_uniform_datalayouts(
    truesofar,
    start,
    bc::Base.Broadcast.Broadcasted,
)
    truesofar || return truesofar
    return _has_uniform_datalayouts_args(truesofar, start, bc.args)
end

# Two datalayouts of the same kind are uniform, whatever their type parameters.
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
    @eval @inline _has_uniform_datalayouts(truesofar, ::$DL, ::$DL) = true
end

# Any other datalayout differs in kind from `start`, so the check fails.
@inline function _has_uniform_datalayouts(truesofar, _, x::AbstractData)
    return false
end

# Non-datalayout arguments leave the running result unchanged.
@inline function _has_uniform_datalayouts(truesofar, _, x)
    return truesofar
end
44+
"""
    has_uniform_datalayouts(bc)

Find the first datalayout in the broadcast expression (BCE) `bc` and compare
it against every other datalayout in the BCE. Returns

 - `true` if the broadcasted object has only a single kind of datalayout
   (e.g. `VF, VF` or `VIJFH, VIJFH`)
 - `false` if the broadcasted object has multiple kinds of datalayouts
   (e.g. `VIJFH, VIFH`)

Note: a broadcasted object can have different _types_, e.g.,
`VIJFH{Float64}` and `VIJFH{Tuple{Float64,Float64}}`, but not different
kinds, e.g., `VIJFH{Float64}` and `VF{Float64}`.
"""
function has_uniform_datalayouts end

@inline function has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted)
    start = first_datalayout_in_bc(bc)
    return _has_uniform_datalayouts_args(true, start, bc.args)
end

# A single datalayout is always uniform.
@inline has_uniform_datalayouts(::AbstractData) = true

0 commit comments

Comments
 (0)