Skip to content

Commit 6c5f62d

Browse files
Merge #1222
1222: Deprecate comm_context, use ClimaComms.context r=charleskawczynski a=charleskawczynski This PR deprecates - `comm_context` in favor of using `ClimaComms.context` and methods that assume the device is internally determined using `CUDA.functional()`, which is problematic if CUDA is available but we want to run on the CPU: - `HDF5Reader` - `HDF5Writer` - `Topology2D` - (fixes `create_dss_buffer` to use the topology context) A step towards #1170. Co-authored-by: Charles Kawczynski <kawczynski.charles@gmail.com>
2 parents 8b5e267 + b6fadc2 commit 6c5f62d

File tree

14 files changed

+91
-40
lines changed

14 files changed

+91
-40
lines changed

src/ClimaCore.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ module ClimaCore
33
using PkgVersion
44
const VERSION = PkgVersion.@Version
55

6-
function comm_context end
7-
86
include("interface.jl")
97
include("Utilities/Utilities.jl")
108
include("RecursiveApply/RecursiveApply.jl")

src/Fields/Fields.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Fields
22

3-
import ..comm_context
43
import ClimaComms
54
import ..enable_threading
65
import ..slab, ..slab_args, ..column, ..column_args, ..level
@@ -38,18 +37,18 @@ Field(values::V, space::S) where {V <: AbstractData, S <: AbstractSpace} =
3837
Field(::Type{T}, space::S) where {T, S <: AbstractSpace} =
3938
Field(similar(Spaces.coordinates_data(space), T), space)
4039

41-
comm_context(field::Field) = comm_context(axes(field))
40+
ClimaComms.context(field::Field) = ClimaComms.context(axes(field))
4241

43-
comm_context(space::Spaces.ExtrudedFiniteDifferenceSpace) =
44-
comm_context(space.horizontal_space)
45-
comm_context(space::Spaces.SpectralElementSpace2D) =
46-
comm_context(space.topology)
47-
comm_context(space::S) where {S <: Spaces.AbstractSpace} =
48-
ClimaComms.SingletonCommsContext()
42+
ClimaComms.context(space::Spaces.ExtrudedFiniteDifferenceSpace) =
43+
ClimaComms.context(space.horizontal_space)
44+
ClimaComms.context(space::Spaces.SpectralElementSpace2D) =
45+
ClimaComms.context(space.topology)
46+
ClimaComms.context(space::S) where {S <: Spaces.AbstractSpace} =
47+
ClimaComms.context(space.topology)
4948

50-
comm_context(topology::Topologies.Topology2D) = topology.context
51-
comm_context(topology::T) where {T <: Topologies.AbstractTopology} =
52-
ClimaComms.SingletonCommsContext()
49+
ClimaComms.context(topology::Topologies.Topology2D) = topology.context
50+
ClimaComms.context(topology::T) where {T <: Topologies.AbstractTopology} =
51+
topology.context
5352

5453
Adapt.adapt_structure(to, field::Field) = Field(
5554
Adapt.adapt(to, Fields.field_values(field)),

src/Fields/mapreduce.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function Base.sum(
4343
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
4444
::ClimaComms.CPUDevice,
4545
)
46-
context = comm_context(axes(field))
46+
context = ClimaComms.context(axes(field))
4747
data_sum = DataLayouts.DataF(local_sum(field))
4848
ClimaComms.allreduce!(context, parent(data_sum), +)
4949
return data_sum[]
@@ -62,7 +62,7 @@ Approximate maximum of `v` or `f.(v)` over the domain.
6262
If `v` is a distributed field, this uses a `ClimaComms.allreduce` operation.
6363
"""
6464
function Base.maximum(fn, field::Field, ::ClimaComms.CPUDevice)
65-
context = comm_context(axes(field))
65+
context = ClimaComms.context(axes(field))
6666
data_max = DataLayouts.DataF(mapreduce(fn, max, todata(field)))
6767
ClimaComms.allreduce!(context, parent(data_max), max)
6868
return data_max[]
@@ -74,7 +74,7 @@ Base.maximum(fn, field::Field) =
7474
Base.maximum(field::Field) = Base.maximum(field, ClimaComms.device(field))
7575

7676
function Base.minimum(fn, field::Field, ::ClimaComms.CPUDevice)
77-
context = comm_context(axes(field))
77+
context = ClimaComms.context(axes(field))
7878
data_min = DataLayouts.DataF(mapreduce(fn, min, todata(field)))
7979
ClimaComms.allreduce!(context, parent(data_min), min)
8080
return data_min[]
@@ -109,7 +109,7 @@ function Statistics.mean(
109109
::ClimaComms.CPUDevice,
110110
)
111111
space = axes(field)
112-
context = comm_context(space)
112+
context = ClimaComms.context(space)
113113
data_combined =
114114
DataLayouts.DataF((local_sum(field), Spaces.local_area(space)))
115115
ClimaComms.allreduce!(context, parent(data_combined), +)

src/InputOutput/readers.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,14 @@ struct HDF5Reader{C <: ClimaComms.AbstractCommsContext}
9292
space_cache::Dict{Any, Any}
9393
end
9494

95+
@deprecate HDF5Reader(filename::AbstractString) HDF5Reader(
96+
filename,
97+
ClimaComms.SingletonCommsContext(),
98+
)
99+
95100
function HDF5Reader(
96101
filename::AbstractString,
97-
context::ClimaComms.AbstractCommsContext = ClimaComms.SingletonCommsContext(),
102+
context::ClimaComms.AbstractCommsContext,
98103
)
99104
if context isa ClimaComms.SingletonCommsContext
100105
file = h5open(filename, "r")

src/InputOutput/writers.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,14 @@ struct HDF5Writer{C <: ClimaComms.AbstractCommsContext} <: AbstractWriter
3737
cache::Dict{String, String}
3838
end
3939

40+
@deprecate HDF5Writer(filename::AbstractString) HDF5Writer(
41+
filename,
42+
ClimaComms.SingletonCommsContext(),
43+
)
44+
4045
function HDF5Writer(
4146
filename::AbstractString,
42-
context::ClimaComms.AbstractCommsContext = ClimaComms.SingletonCommsContext(),
47+
context::ClimaComms.AbstractCommsContext,
4348
)
4449
if context isa ClimaComms.SingletonCommsContext
4550
file = h5open(filename, "w")

src/Spaces/Spaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ weights ``W_i`` multiplied by the Jacobian determinants ``J_i``:
6767
If `space` is distributed, this uses a `ClimaComms.allreduce` operation.
6868
"""
6969
area(space::Spaces.AbstractSpace) =
70-
ClimaComms.allreduce(comm_context(space), local_area(space), +)
70+
ClimaComms.allreduce(ClimaComms.context(space), local_area(space), +)
7171

7272
ClimaComms.array_type(space::AbstractSpace) =
7373
ClimaComms.array_type(ClimaComms.device(space))

src/Spaces/dss.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ function create_dss_buffer(
5858
local_geometry = nothing,
5959
local_weights = nothing,
6060
) where {S, Nij}
61-
context =
62-
topology isa Topologies.Topology2D ? topology.context :
63-
ClimaComms.SingletonCommsContext()
61+
context = topology.context
6462
DA = ClimaComms.array_type(topology)
6563
convert_to_array = DA isa Array ? false : true
6664
(_, _, _, Nv, nelems) = Base.size(data)

src/Spaces/finitedifference.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,14 @@ Base.@propagate_inbounds function level(
212212
v::PlusHalf,
213213
)
214214
@inbounds local_geometry = level(local_geometry_data(space), v.i + 1)
215-
PointSpace(local_geometry)
215+
context = ClimaComms.context(space)
216+
PointSpace(context, local_geometry)
216217
end
217218
Base.@propagate_inbounds function level(
218219
space::CenterFiniteDifferenceSpace,
219220
v::Int,
220221
)
221222
local_geometry = level(local_geometry_data(space), v)
222-
PointSpace(local_geometry)
223+
context = ClimaComms.context(space)
224+
PointSpace(context, local_geometry)
223225
end

src/Spaces/pointspace.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,35 @@ local_geometry_data(space::AbstractPointSpace) = space.local_geometry
77
88
A zero-dimensional space.
99
"""
10-
struct PointSpace{LG} <: AbstractPointSpace
10+
struct PointSpace{C <: ClimaComms.AbstractCommsContext, LG} <:
11+
AbstractPointSpace
12+
context::C
1113
local_geometry::LG
1214
end
1315

14-
ClimaComms.device(space::PointSpace) = ClimaComms.CPUDevice()
16+
ClimaComms.device(space::PointSpace) = ClimaComms.device(space.context)
17+
ClimaComms.context(space::PointSpace) = space.context
1518

16-
function PointSpace(local_geometry::LG) where {LG <: Geometry.LocalGeometry}
19+
@deprecate PointSpace(x::Geometry.LocalGeometry) PointSpace(
20+
ClimaComms.SingletonCommsContext(ClimaComms.CPUDevice()),
21+
x,
22+
) false
23+
24+
function PointSpace(
25+
context::ClimaComms.AbstractCommsContext,
26+
local_geometry::LG,
27+
) where {LG <: Geometry.LocalGeometry}
1728
FT = Geometry.undertype(LG)
29+
# TODO: inherit array type
1830
local_geometry_data = DataLayouts.DataF{LG}(Array{FT})
1931
local_geometry_data[] = local_geometry
20-
return PointSpace(local_geometry_data)
32+
return PointSpace(context, local_geometry_data)
2133
end
2234

23-
function PointSpace(coord::Geometry.Abstract1DPoint{FT}) where {FT}
35+
function PointSpace(
36+
context::ClimaComms.AbstractCommsContext,
37+
coord::Geometry.Abstract1DPoint{FT},
38+
) where {FT}
2439
CoordType = typeof(coord)
2540
AIdx = Geometry.coordinate_axis(CoordType)
2641
local_geometry = Geometry.LocalGeometry(
@@ -32,5 +47,5 @@ function PointSpace(coord::Geometry.Abstract1DPoint{FT}) where {FT}
3247
FT(1.0),
3348
),
3449
)
35-
return PointSpace(local_geometry)
50+
return PointSpace(context, local_geometry)
3651
end

src/Spaces/spectralelement.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,14 +644,16 @@ Base.@propagate_inbounds slab(space::AbstractSpectralElementSpace, h) =
644644

645645
Base.@propagate_inbounds function column(space::SpectralElementSpace1D, i, h)
646646
local_geometry = column(local_geometry_data(space), i, h)
647-
PointSpace(local_geometry)
647+
context = ClimaComms.context(space)
648+
PointSpace(context, local_geometry)
648649
end
649650
Base.@propagate_inbounds column(space::SpectralElementSpace1D, i, j, h) =
650651
column(space, i, h)
651652

652653
Base.@propagate_inbounds function column(space::SpectralElementSpace2D, i, j, h)
653654
local_geometry = column(local_geometry_data(space), i, j, h)
654-
PointSpace(local_geometry)
655+
context = ClimaComms.context(space)
656+
PointSpace(context, local_geometry)
655657
end
656658

657659
# XXX: this cannot take `space` as it must be constructed beforehand so

0 commit comments

Comments
 (0)