Skip to content

Commit 84670f0

Browse files
Obtain device array from ClimaComms context. Add shallow-water CUDA
driver template.
1 parent 2ebf534 commit 84670f0

File tree

7 files changed

+66
-1
lines changed

7 files changed

+66
-1
lines changed

examples/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.8.3"
44
manifest_format = "2.0"
5-
project_hash = "1a511843299a40488fdeb0ecc8eb50466fa341cb"
5+
project_hash = "6c93fbc837ba5dffbd7162f915124e75b23659a5"
66

77
[[deps.AMD]]
88
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
45
ClimaCommsMPI = "5f86816e-8b66-43b2-912e-75384f99de49"
56
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"

examples/sphere/shallow_water_cuda.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using CUDA
2+
using ClimaComms
3+
4+
import ClimaCore:
5+
Device,
6+
Domains,
7+
Fields,
8+
Geometry,
9+
Meshes,
10+
Operators,
11+
Spaces,
12+
Topologies,
13+
DataLayouts
14+
15+
function shallow_water_driver_cuda(ARGS, ::Type{FT}) where {FT}
16+
device = Device.device()
17+
context = ClimaComms.SingletonCommsContext(device)
18+
println("running serial simulation on $device device")
19+
# Test case specifications
20+
test_name = get(ARGS, 1, "steady_state") # default test case to run
21+
test_angle_name = get(ARGS, 2, "alpha0") # default test case to run
22+
α = parse(FT, replace(test_angle_name, "alpha" => ""))
23+
24+
println("Test name: $test_name, α = $(α)")
25+
26+
return nothing
27+
end
28+
29+
shallow_water_driver_cuda(ARGS, Float64)

src/ClimaCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const VERSION = PkgVersion.@Version
55

66
function comm_context end
77

8+
include("Device.jl")
89
include("interface.jl")
910
include("Utilities/Utilities.jl")
1011
include("RecursiveApply/RecursiveApply.jl")

src/Device.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module Device
2+
using ClimaComms
3+
using CUDA
4+
5+
device(; disablegpu = false) =
6+
CUDA.has_cuda_gpu() && !disablegpu ? ClimaComms.CUDA() : ClimaComms.CPU()
7+
8+
device_array_type(::ClimaComms.CPU) = Array
9+
device_array_type(::ClimaComms.CUDA) = CUDA.CuArray
10+
11+
end

test/gpu/device.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Test
2+
using CUDA
3+
using ClimaComms
4+
import ClimaCore: Device
5+
6+
@testset "check device detection on GPU" begin
7+
device = Device.device()
8+
cuda_context = ClimaComms.SingletonCommsContext(device)
9+
DA = Device.device_array_type(cuda_context.device)
10+
11+
@test device isa ClimaComms.CUDA
12+
@test cuda_context.device == ClimaComms.CUDA()
13+
@test DA == CuArray
14+
15+
override_device = Device.device(disablegpu = true)
16+
override_cuda_context = ClimaComms.SingletonCommsContext(override_device)
17+
DA = Device.device_array_type(override_cuda_context.device)
18+
19+
@test override_device isa ClimaComms.CPU
20+
@test override_cuda_context.device == ClimaComms.CPU()
21+
@test DA == Array
22+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ end
9696
if "CUDA" in ARGS
9797
@safetestset "GPU - cuda" begin @time include("gpu/cuda.jl") end
9898
@safetestset "GPU - data" begin @time include("gpu/data.jl") end
99+
@safetestset "GPU - device" begin @time include("gpu/device.jl") end
99100
end
100101

101102
#! format: on

0 commit comments

Comments
 (0)