Skip to content

Commit db622ba

Browse files
Refactor fill benchmark
1 parent d0680b8 commit db622ba

File tree

1 file changed

+36
-19
lines changed

1 file changed

+36
-19
lines changed

test/DataLayouts/benchmark_fill.jl

Lines changed: 36 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -3,21 +3,36 @@ julia --project
33
using Revise; include(joinpath("test", "DataLayouts", "benchmark_fill.jl"))
44
=#
55
using Test
6+
using ClimaCore
67
using ClimaCore.DataLayouts
78
using BenchmarkTools
89
import ClimaComms
910
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
1011

11-
function benchmarkfill!(device, data, val, name)
12-
println("Benchmarking ClimaCore fill! for $name DataLayout")
12+
if ClimaComms.device() isa ClimaComms.CUDADevice
13+
import CUDA
14+
device_name = CUDA.name(CUDA.device()) # Move to ClimaComms
15+
else
16+
device_name = "CPU"
17+
end
18+
19+
include(joinpath(pkgdir(ClimaCore), "benchmarks/scripts/benchmark_utils.jl"))
20+
21+
function benchmarkfill!(bm, device, data, val, name)
22+
caller = string(nameof(typeof(data)))
23+
@info "Benchmarking $caller..."
1324
trial = @benchmark ClimaComms.@cuda_sync $device fill!($data, $val)
14-
show(stdout, MIME("text/plain"), trial)
15-
println()
16-
println("Benchmarking array fill! for $name DataLayout")
17-
trial =
18-
@benchmark ClimaComms.@cuda_sync $device fill!($(parent(data)), $val)
19-
show(stdout, MIME("text/plain"), trial)
20-
println()
25+
t_min = minimum(trial.times) * 1e-9 # to seconds
26+
nreps = length(trial.times)
27+
n_reads_writes = DataLayouts.ncomponents(data) * 2
28+
push_info(
29+
bm;
30+
kernel_time_s = t_min,
31+
nreps = nreps,
32+
caller,
33+
problem_size = size(data),
34+
n_reads_writes,
35+
)
2136
end
2237

2338
@testset "fill! with Nf = 1" begin
@@ -30,17 +45,19 @@ end
3045
Nij = 4
3146
Nh = 30 * 30 * 6
3247
Nk = 6
48+
bm = Benchmark(; float_type = FT, device_name)
3349
#! format: off
34-
data = DataF{S}(device_zeros(FT,Nf)); benchmarkfill!(device, data, 3, "DataF" ); @test all(parent(data) .== 3)
35-
data = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "IJFH" ); @test all(parent(data) .== 3)
36-
data = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "IFH" ); @test all(parent(data) .== 3)
37-
data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkfill!(device, data, 3, "IJF" ); @test all(parent(data) .== 3)
38-
data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkfill!(device, data, 3, "IF" ); @test all(parent(data) .== 3)
39-
data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); benchmarkfill!(device, data, 3, "VF" ); @test all(parent(data) .== 3)
40-
data = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));benchmarkfill!(device, data, 3, "VIJFH" ); @test all(parent(data) .== 3)
41-
data = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh)); benchmarkfill!(device, data, 3, "VIFH" ); @test all(parent(data) .== 3)
50+
data = DataF{S}(device_zeros(FT,Nf)); benchmarkfill!(bm, device, data, 3, "DataF" ); @test all(parent(data) .== 3)
51+
data = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "IJFH" ); @test all(parent(data) .== 3)
52+
data = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "IFH" ); @test all(parent(data) .== 3)
53+
data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkfill!(bm, device, data, 3, "IJF" ); @test all(parent(data) .== 3)
54+
data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkfill!(bm, device, data, 3, "IF" ); @test all(parent(data) .== 3)
55+
data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); benchmarkfill!(bm, device, data, 3, "VF" ); @test all(parent(data) .== 3)
56+
data = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));benchmarkfill!(bm, device, data, 3, "VIJFH" ); @test all(parent(data) .== 3)
57+
data = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh)); benchmarkfill!(bm, device, data, 3, "VIFH" ); @test all(parent(data) .== 3)
4258
#! format: on
4359

44-
# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(device, data, 3); @test all(parent(data) .== 3) # TODO: test
45-
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(device, data, 3); @test all(parent(data) .== 3) # TODO: test
60+
# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
61+
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
62+
tabulate_benchmark(bm)
4663
end

0 commit comments

Comments (0)