Skip to content

Commit f8b7ad4

Browse files
Merge pull request #1918 from CliMA/ck/thermo_bench
Update flop-inclusive thermo bench script
2 parents 111dc75 + c89881f commit f8b7ad4

File tree

1 file changed

+84
-62
lines changed

1 file changed

+84
-62
lines changed

benchmarks/scripts/thermo_bench.jl

Lines changed: 84 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,29 @@ using Revise; include(joinpath("benchmarks", "scripts", "thermo_bench.jl"))
44
55
This benchmark requires Thermodynamics and ClimaParams
66
to be in your local environment to run.
7+
8+
# Benchmark results:
9+
10+
Clima A100:
11+
```
12+
[ Info: device = ClimaComms.CUDADevice()
13+
Problem size: (63, 4, 4, 1, 5400), float_type = Float32, device_bandwidth_GBs=2039
14+
┌──────────────────────────────────────────────────────────────────┬───────────────────────────────────┬─────────┬─────────────┬────────────────┬────────┐
15+
│ funcs │ time per call │ bw % │ achieved bw │ n-reads/writes │ n-reps │
16+
├──────────────────────────────────────────────────────────────────┼───────────────────────────────────┼─────────┼─────────────┼────────────────┼────────┤
17+
│ TB.thermo_func_bc!(x, thermo_params, us; nreps=100, bm) │ 586 microseconds, 517 nanoseconds │ 15.2602 │ 311.155 │ 9 │ 100 │
18+
│ TB.thermo_func_sol!(x_vec, thermo_params, us; nreps=100, bm) │ 292 microseconds, 178 nanoseconds │ 30.6332 │ 624.611 │ 9 │ 100 │
19+
│ TB.thermo_func_bc!(x, thermo_params, us; nreps=100, bm) │ 586 microseconds, 988 nanoseconds │ 15.2479 │ 310.905 │ 9 │ 100 │
20+
│ TB.thermo_func_sol!(x_vec, thermo_params, us; nreps=100, bm) │ 292 microseconds, 178 nanoseconds │ 30.6332 │ 624.611 │ 9 │ 100 │
21+
└──────────────────────────────────────────────────────────────────┴───────────────────────────────────┴─────────┴─────────────┴────────────────┴────────┘
22+
```
723
=#
824

925
#! format: off
10-
import ClimaCore
11-
import Thermodynamics as TD
12-
import CUDA
13-
using ClimaComms
14-
import ClimaCore: Spaces, Fields
15-
import ClimaCore.Domains: Geometry
26+
module ThermoBench
1627

17-
ENV["CLIMACOMMS_DEVICE"] = "CUDA";
18-
ClimaComms.@import_required_backends
19-
using BenchmarkTools
20-
@isdefined(TU) || include(
21-
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
22-
);
23-
import .TestUtilities as TU;
28+
include("benchmark_utils.jl")
2429

25-
module ThermoBench
2630
import ClimaCore
2731
import CUDA
2832
using ClimaComms
@@ -32,75 +36,65 @@ using JET
3236

3337
import ClimaCore: Spaces, Fields
3438
import ClimaCore.Domains: Geometry
35-
import Dates
36-
print_time_and_units(x) = println(time_and_units_str(x))
37-
time_and_units_str(x::Real) =
38-
trunc_time(string(compound_period(x, Dates.Second)))
39-
function compound_period(x::Real, ::Type{T}) where {T <: Dates.Period}
40-
nf = Dates.value(convert(Dates.Nanosecond, T(1)))
41-
ns = Dates.Nanosecond(ceil(x * nf))
42-
return Dates.canonicalize(Dates.CompoundPeriod(ns))
43-
end
44-
trunc_time(s::String) = count(',', s) > 1 ? join(split(s, ",")[1:2], ",") : s
4539

4640
@inline ts_gs(thermo_params, e_tot, q_tot, K, Φ, ρ) =
4741
thermo_state(thermo_params, e_tot - K - Φ, q_tot, ρ)
4842
@inline thermo_state(thermo_params, ρ, e_int, q_tot) =
4943
TD.PhaseEquil_ρeq(thermo_params,ρ,e_int,q_tot, 3, eltype(thermo_params)(0.003))
5044

51-
struct UniversalSizesStatic{Nv, Nij, Nh} end
52-
get_Nv(::UniversalSizesStatic{Nv}) where {Nv} = Nv
53-
get_Nij(::UniversalSizesStatic{Nv, Nij}) where {Nv, Nij} = Nij
54-
get_Nh(::UniversalSizesStatic{Nv, Nij, Nh}) where {Nv, Nij, Nh} = Nh
55-
get_N(us::UniversalSizesStatic{Nv, Nij}) where {Nv, Nij} =
56-
prod((Nv, Nij, Nij, 1, get_Nh(us)))
57-
UniversalSizesStatic(Nv, Nij, Nh) = UniversalSizesStatic{Nv, Nij, Nh}()
5845
import Thermodynamics as TD
5946

60-
function thermo_func_bc!(x, thermo_params, us, niter = 1)
61-
e = CUDA.@elapsed begin
62-
for i in 1:niter # reduce variance / impact of launch latency
63-
(; ts, e_tot, q_tot, K, Φ, ρ) = x
64-
@. ts = ts_gs(thermo_params, e_tot, q_tot, K, Φ, ρ)
47+
function thermo_func_bc!(x, thermo_params, us; nreps = 1, bm=nothing, n_trials = 30)
48+
e = Inf
49+
for t in 1:n_trials
50+
et = CUDA.@elapsed begin
51+
for _ in 1:nreps
52+
(; ts, e_tot, q_tot, K, Φ, ρ) = x
53+
@. ts = ts_gs(thermo_params, e_tot, q_tot, K, Φ, ρ) # 5 reads, 5 writes, many flops
54+
end
6555
end
56+
e = min(e, et)
6657
end
67-
print_time_and_units(e / niter)
58+
push_info(bm; e, nreps, caller = @caller_name(@__FILE__),n_reads_writes=5+4) # TODO: verify this
6859
return nothing
6960
end
7061

71-
function thermo_func_sol!(x, thermo_params, us::UniversalSizesStatic, niter = 1)
72-
e = CUDA.@elapsed begin
73-
for i in 1:niter # reduce variance / impact of launch latency
62+
function thermo_func_sol!(x, thermo_params, us::UniversalSizesStatic; nreps = 1, bm=nothing, n_trials = 30)
63+
e = Inf
64+
for t in 1:n_trials
65+
et = CUDA.@elapsed begin
7466
(; ts, e_tot, q_tot, K, Φ, ρ) = x
7567
kernel = CUDA.@cuda always_inline = true launch = false thermo_func_sol_kernel!(ts,e_tot,q_tot,K,Φ,ρ,thermo_params,us)
7668
N = get_N(us)
7769
config = CUDA.launch_configuration(kernel.fun)
7870
threads = min(N, config.threads)
7971
blocks = cld(N, threads)
80-
kernel(ts,e_tot,q_tot,K,Φ,ρ,thermo_params,us; threads, blocks)
72+
for _ in 1:nreps
73+
kernel(ts,e_tot,q_tot,K,Φ,ρ,thermo_params,us; threads, blocks)
74+
end
8175
end
76+
e = min(e, et)
8277
end
83-
print_time_and_units(e / niter)
78+
push_info(bm; e, nreps, caller = @caller_name(@__FILE__),n_reads_writes=5+4) # TODO: verify this
8479
return nothing
8580
end
8681

8782
# Mimics how indexing works in generalized pointwise kernels
8883
function thermo_func_sol_kernel!(ts, e_tot, q_tot, K, Φ, ρ, thermo_params, us)
8984
@inbounds begin
90-
(; ts_ρ, ts_p, ts_e_int, ts_q_tot, ts_T) = ts
9185
FT = eltype(e_tot)
9286
I = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
9387
if I ≤ get_N(us)
94-
ts_i = ts_gs(thermo_params, e_tot[I], q_tot[I], K[I], Φ[I], ρ[I]) # 5 reads, potentially many flops (see thermodynamics for estimate)
95-
9688
# Data is not read into the correct fields because this is only used
9789
# to compare with the case when the number of flops goes to zero.
98-
# ts_i = TD.PhaseEquil{FT}(ρ[I], K[I], e_tot[I], q_tot[I], Φ[I]) # 5 reads, 0 flops
99-
ts_ρ[I] = ts_i.ρ
100-
ts_p[I] = ts_i.p
101-
ts_T[I] = ts_i.T
102-
ts_e_int[I] = ts_i.e_int
103-
ts_q_tot[I] = ts_i.q_tot
90+
91+
# 5 reads, 5 writes, potentially many flops (see thermodynamics for estimate)
92+
ts_i = ts_gs(thermo_params, e_tot[I], q_tot[I], K[I], Φ[I], ρ[I])
93+
ts.ρ[I] = ts_i.ρ
94+
ts.p[I] = ts_i.p
95+
ts.T[I] = ts_i.T
96+
ts.e_int[I] = ts_i.e_int
97+
ts.q_tot[I] = ts_i.q_tot
10498
end
10599
end
106100
return nothing
@@ -109,10 +103,27 @@ end
109103
end
110104

111105
import ClimaParams # trigger Thermo extension
112-
import .ThermoBench
106+
import .ThermoBench as TB
107+
108+
import Thermodynamics as TD
109+
import CUDA
110+
using ClimaComms
111+
using ClimaCore
112+
import ClimaCore: Spaces, Fields
113+
import ClimaCore.Domains: Geometry
114+
115+
ENV["CLIMACOMMS_DEVICE"] = "CUDA";
116+
ClimaComms.@import_required_backends
117+
using BenchmarkTools
118+
@isdefined(TU) || include(
119+
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
120+
);
121+
import .TestUtilities as TU;
122+
113123
using Test
114124
@testset "Thermo state" begin
115125
FT = Float32
126+
bm = TB.Benchmark(;problem_size=(63,4,4,1,5400), float_type=FT)
116127
device = ClimaComms.device()
117128
context = ClimaComms.context(device)
118129
cspace = TU.CenterExtrudedFiniteDifferenceSpace(
@@ -125,18 +136,19 @@ using Test
125136
fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace)
126137
@info "device = $device"
127138
thermo_params = TD.Parameters.ThermodynamicsParameters(FT)
139+
# TODO: fill with non-trivial values (e.g., use Thermodynamics TestedProfiles) to verify correctness.
128140
nt_core = (; K = FT(0), Φ = FT(1), ρ = FT(0), e_tot = FT(1), q_tot = FT(0.001))
129141
nt_ts = (;
130-
ts_ρ = FT(0),
131-
ts_p = FT(0),
132-
ts_e_int = FT(0),
133-
ts_q_tot = FT(0),
134-
ts_T = FT(0),
142+
ρ = FT(0),
143+
p = FT(0),
144+
e_int = FT(0),
145+
q_tot = FT(0),
146+
T = FT(0),
135147
)
136148
x = fill((; ts = zero(TD.PhaseEquil{FT}), nt_core...), cspace)
137149
xv = fill((; ts = nt_ts, nt_core...), cspace)
138150
(_, Nij, _, Nv, Nh) = size(Fields.field_values(x.ts))
139-
us = ThermoBench.UniversalSizesStatic(Nv, Nij, Nh)
151+
us = TB.UniversalSizesStatic(Nv, Nij, Nh)
140152
function to_vec(ξ)
141153
pns = propertynames(ξ)
142154
dl_vals = map(pns) do pn
@@ -148,10 +160,20 @@ using Test
148160
end
149161
x_vec = to_vec(xv)
150162

151-
ThermoBench.thermo_func_bc!(x, thermo_params, us)
152-
ThermoBench.thermo_func_sol!(x_vec, thermo_params, us)
163+
TB.thermo_func_bc!(x, thermo_params, us; nreps=1, n_trials = 1)
164+
TB.thermo_func_sol!(x_vec, thermo_params, us; nreps=1, n_trials = 1)
165+
166+
rc = Fields.rcompare(x_vec, to_vec(x))
167+
rc || Fields.rprint_diff(x_vec, to_vec(x)) # test correctness (should print nothing)
168+
@test rc # test correctness
169+
170+
TB.thermo_func_bc!(x, thermo_params, us; nreps=100, bm)
171+
TB.thermo_func_sol!(x_vec, thermo_params, us; nreps=100, bm)
172+
173+
TB.thermo_func_bc!(x, thermo_params, us; nreps=100, bm)
174+
TB.thermo_func_sol!(x_vec, thermo_params, us; nreps=100, bm)
175+
176+
TB.tabulate_benchmark(bm)
153177

154-
ThermoBench.thermo_func_bc!(x, thermo_params, us, 100)
155-
ThermoBench.thermo_func_sol!(x_vec, thermo_params, us, 100)
156178
end
157179
#! format: on

0 commit comments

Comments
 (0)