Skip to content

Commit 77c94c6

Browse files
Merge pull request #1898 from CliMA/ck/use_us
Use UniversalSize struct in more kernels
2 parents 4736aea + 82151d9 commit 77c94c6

File tree

7 files changed

+128
-28
lines changed

7 files changed

+128
-28
lines changed

.buildkite/pipeline.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ steps:
9393
key: data_opt_similar
9494
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/opt_similar.jl"
9595

96+
- label: "Unit: opt_universal_size"
97+
key: opt_universal_size
98+
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/opt_universal_size.jl"
99+
96100
- label: "Unit: data_ndims"
97101
key: unit_data_ndims
98102
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_ndims.jl"

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import ClimaCore.Utilities: half
1616
import ClimaCore.Utilities: cart_ind, linear_ind
1717
import ClimaCore.RecursiveApply:
1818
⊠, ⊞, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
19+
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
1920

2021
include(joinpath("cuda", "cuda_utils.jl"))
2122
include(joinpath("cuda", "data_layouts.jl"))

ext/cuda/data_layouts_copyto.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ function Base.copyto!(
8484
end
8585

8686
import ClimaCore.DataLayouts: isascalar
87-
function knl_copyto_flat!(dest::AbstractData, bc)
87+
function knl_copyto_flat!(dest::AbstractData, bc, us)
8888
@inbounds begin
89-
n = size(dest)
9089
tidx = thread_index()
91-
if valid_range(tidx, prod(n))
90+
if tidx ≤ get_N(us)
91+
n = size(dest)
9292
I = kernel_indexes(tidx, n)
9393
dest[I] = bc[I]
9494
end
@@ -98,14 +98,15 @@ end
9898

9999
function cuda_copyto!(dest::AbstractData, bc)
100100
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
101+
us = DataLayouts.UniversalSize(dest)
101102
if Nv > 0 && Nh > 0
102-
auto_launch!(knl_copyto_flat!, (dest, bc), dest; auto = true)
103+
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
103104
end
104105
return dest
105106
end
106107

107108
# TODO: can we use CUDA's launch configuration for all data layouts?
108-
# Currently, it seems to have a slight performance degredation.
109+
# Currently, it seems to have a slight performance degradation.
109110
#! format: off
110111
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
111112
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)

ext/cuda/data_layouts_fill.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
function knl_fill_flat!(dest::AbstractData, val)
1+
function knl_fill_flat!(dest::AbstractData, val, us)
22
@inbounds begin
33
tidx = thread_index()
4-
n = size(dest)
5-
if valid_range(tidx, prod(n))
4+
if tidx ≤ get_N(us)
5+
n = size(dest)
66
I = kernel_indexes(tidx, n)
77
@inbounds dest[I] = val
88
end
@@ -12,8 +12,9 @@ end
1212

1313
function cuda_fill!(dest::AbstractData, val)
1414
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
15+
us = DataLayouts.UniversalSize(dest)
1516
if Nv > 0 && Nh > 0
16-
auto_launch!(knl_fill_flat!, (dest, val), dest; auto = true)
17+
auto_launch!(knl_fill_flat!, (dest, val, us), dest; auto = true)
1718
end
1819
return dest
1920
end

ext/cuda/operators_finite_difference.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,11 @@ function Base.copyto!(
2929
(li, lw, rw, ri) = bounds = Operators.window_bounds(space, bc)
3030
Nv = ri - li + 1
3131
max_threads = 256
32+
us = DataLayouts.UniversalSize(Fields.field_values(out))
3233
nitems = Nv * Nq * Nq * Nh # # of independent items
3334
(nthreads, nblocks) = _configure_threadblock(max_threads, nitems)
34-
args = (
35-
strip_space(out, space),
36-
strip_space(bc, space),
37-
axes(out),
38-
bounds,
39-
Val(Nv),
40-
Val(Nq),
41-
Nh,
42-
)
35+
args =
36+
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
4337
auto_launch!(
4438
copyto_stencil_kernel!,
4539
args,
@@ -49,20 +43,16 @@ function Base.copyto!(
4943
)
5044
return out
5145
end
46+
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
5247

53-
function copyto_stencil_kernel!(
54-
out,
55-
bc,
56-
space,
57-
bds,
58-
::Val{Nv},
59-
::Val{Nq},
60-
Nh,
61-
) where {Nv, Nq}
48+
function copyto_stencil_kernel!(out, bc, space, bds, us)
6249
@inbounds begin
6350
gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x
64-
if gid ≤ Nv * Nq * Nq * Nh
51+
if gid ≤ get_N(us)
6552
(li, lw, rw, ri) = bds
53+
Nv = get_Nv(us)
54+
Nq = get_Nij(us)
55+
Nh = get_Nh(us)
6656
(v, i, j, h) = cart_ind((Nv, Nq, Nq, Nh), gid).I
6757
hidx = (i, j, h)
6858
idx = v - 1 + li

src/DataLayouts/DataLayouts.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,41 @@ corresponding to `UniversalSize`.
7676
"""
7777
@inline universal_size(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
7878
(Ni, Nj, Nv, Nh)
79+
80+
"""
81+
get_N(::AbstractData)
82+
get_N(::UniversalSize)
83+
84+
Statically returns `prod((Ni, Nj, Nv, Nh))`
85+
"""
86+
@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
87+
prod((Ni, Nj, Nv, Nh))
88+
89+
@inline get_N(data::AbstractData) = get_N(UniversalSize(data))
90+
91+
"""
92+
get_Nv(::UniversalSize)
93+
94+
Statically returns `Nv`.
95+
"""
96+
get_Nv(::UniversalSize{Ni, Nj, Nv}) where {Ni, Nj, Nv} = Nv
97+
98+
"""
99+
get_Nij(::UniversalSize)
100+
101+
Statically returns `Nij`.
102+
"""
103+
get_Nij(::UniversalSize{Nij}) where {Nij} = Nij
104+
105+
"""
106+
get_Nh(::UniversalSize)
107+
108+
Statically returns `Nh`.
109+
"""
110+
get_Nh(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} = Nh
111+
112+
# Forwards to the `UniversalSize` method (same pattern as
# `get_N(data::AbstractData)`), so `Nh` is read from the layout's static size
# rather than from an unbound name.
@inline get_Nh(data::AbstractData) = get_Nh(UniversalSize(data))
113+
79114
@inline universal_size(data::AbstractData) = universal_size(UniversalSize(data))
80115

81116
function Base.show(io::IO, data::AbstractData)
test/DataLayouts/opt_universal_size.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#=
2+
julia --project
3+
using Revise; include(joinpath("test", "DataLayouts", "opt_universal_size.jl"))
4+
=#
5+
using Test
6+
using ClimaCore.DataLayouts
7+
using ClimaCore: DataLayouts, Geometry
8+
import ClimaComms
9+
using StaticArrays: SMatrix
10+
ClimaComms.@import_required_backends
11+
using JET
12+
using InteractiveUtils: @code_typed
13+
14+
function test_universal_size(data)
15+
us = DataLayouts.UniversalSize(data)
16+
# Make sure the result is statically returned / constant propagated
17+
ct = @code_typed DataLayouts.get_N(us)
18+
@test ct.first.code[1] isa Core.ReturnNode
19+
@test ct.first.code[end].val == DataLayouts.get_N(us)
20+
21+
ct = @code_typed DataLayouts.get_Nv(us)
22+
@test ct.first.code[1] isa Core.ReturnNode
23+
@test ct.first.code[end].val == DataLayouts.get_Nv(us)
24+
25+
ct = @code_typed DataLayouts.get_Nij(us)
26+
@test ct.first.code[1] isa Core.ReturnNode
27+
@test ct.first.code[end].val == DataLayouts.get_Nij(us)
28+
29+
ct = @code_typed DataLayouts.get_Nh(us)
30+
@test ct.first.code[1] isa Core.ReturnNode
31+
@test ct.first.code[end].val == DataLayouts.get_Nh(us)
32+
33+
ct = @code_typed size(data)
34+
@test ct.first.code[1] isa Core.ReturnNode
35+
@test ct.first.code[end].val == size(data)
36+
37+
ct = @code_typed DataLayouts.get_N(data)
38+
@test ct.first.code[1] isa Core.ReturnNode
39+
@test ct.first.code[end].val == DataLayouts.get_N(data)
40+
41+
# Demo of failed constant prop:
42+
ct = @code_typed prod(size(data))
43+
@test ct.first.code[1] isa Expr # first element is not a return node, but an expression
44+
end
45+
46+
@testset "UniversalSize" begin
47+
device = ClimaComms.device()
48+
device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
49+
FT = Float64
50+
S = FT
51+
Nf = 1
52+
Nv = 4
53+
Nij = 3
54+
Nh = 5
55+
Nk = 6
56+
#! format: off
57+
data = DataF{S}(device_zeros(FT,Nf)); test_universal_size(data)
58+
data = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh)); test_universal_size(data)
59+
data = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh)); test_universal_size(data)
60+
data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); test_universal_size(data)
61+
data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); test_universal_size(data)
62+
data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); test_universal_size(data)
63+
data = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));test_universal_size(data)
64+
data = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh)); test_universal_size(data)
65+
#! format: on
66+
# data = DataLayouts.IJKFVH{S, Nij, Nk, Nv, Nh}(device_zeros(FT,Nij,Nij,Nk,Nf,Nv,Nh)); test_universal_size(data) # TODO: test
67+
# data = DataLayouts.IH1JH2{S, Nij}(device_zeros(FT,2*Nij,3*Nij)); test_universal_size(data) # TODO: test
68+
end

0 commit comments

Comments
 (0)