Skip to content

Commit 8fb42cf

Browse files
Add restart support for masked spaces (#2212)
Remove unused variable Try MPI-safe mktempfile
1 parent 6b9a15d commit 8fb42cf

File tree

3 files changed

+258
-26
lines changed

3 files changed

+258
-26
lines changed

src/InputOutput/readers.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,49 @@ function read_grid(reader, name)
414414
end
415415
end
416416

417+
"""
418+
read_data_layout(dataset, topology)
419+
420+
Read a datalayout from a `dataset`, with a given `topology`.
421+
422+
This should cooperate with datasets written by `write!` for datalayouts.
423+
"""
424+
function read_data_layout(dataset, topology)
425+
ArrayType = ClimaComms.array_type(topology)
426+
data_layout = HDF5.read_attribute(dataset, "type")
427+
has_horizontal = occursin('I', data_layout)
428+
DataLayout = _scan_data_layout(data_layout)
429+
array = HDF5.read(dataset)
430+
has_horizontal &&
431+
(h_dim = DataLayouts.h_dim(DataLayouts.singleton(DataLayout)))
432+
if topology isa Topologies.Topology2D
433+
nd = ndims(array)
434+
localidx = ntuple(d -> d == h_dim ? topology.local_elem_gidx : (:), nd)
435+
data = ArrayType(array[localidx...])
436+
else
437+
data = ArrayType(read(array))
438+
end
439+
has_horizontal && (Nij = size(data, findfirst("I", data_layout)[1]))
440+
# For when `Nh` is added back to the type space
441+
# Nhd = Nh_dim(data_layout)
442+
# Nht = Nhd == -1 ? () : (size(data, Nhd),)
443+
ElType = read_type(HDF5.read_attribute(dataset, "data_eltype"))
444+
if data_layout in ("VIJFH", "VIFH")
445+
Nv = size(data, 1)
446+
# values = DataLayout{ElType, Nv, Nij, Nht...}(data) # when Nh is in type-domain
447+
values = DataLayout{ElType, Nv, Nij}(data)
448+
elseif data_layout in ("VF",)
449+
Nv = size(data, 1)
450+
values = DataLayout{ElType, Nv}(data)
451+
elseif data_layout in ("DataF",)
452+
values = DataLayout{ElType}(data)
453+
else
454+
# values = DataLayout{ElType, Nij, Nht...}(data) # when Nh is in type-domain
455+
values = DataLayout{ElType, Nij}(data)
456+
end
457+
return values
458+
end
459+
417460
function read_grid_new(reader, name)
418461
group = reader.file["grids/$name"]
419462
type = attrs(group)["type"]
@@ -426,11 +469,21 @@ function read_grid_new(reader, name)
426469
if type == "SpectralElementGrid1D"
427470
return Grids.SpectralElementGrid1D(topology, quadrature_style)
428471
else
429-
return Grids.SpectralElementGrid2D(
472+
enable_mask = haskey(attrs(group), "grid_mask")
473+
grid = Grids.SpectralElementGrid2D(
430474
topology,
431475
quadrature_style;
432476
enable_bubble,
477+
enable_mask,
433478
)
479+
if enable_mask
480+
mask_type = keys(reader.file["grid_mask"])[1]
481+
@assert mask_type == "IJHMask"
482+
ds_is_active = reader.file["grid_mask"]["IJHMask"]["is_active"]
483+
is_active = read_data_layout(ds_is_active, topology)
484+
Grids.set_mask!(grid, is_active)
485+
end
486+
return grid
434487
end
435488
elseif type == "FiniteDifferenceGrid"
436489
topology = read_topology(reader, attrs(group)["topology"])

src/InputOutput/writers.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ end
321321

322322
# Grids
323323
#
324+
defaultname(::DataLayouts.IJHMask) = "IJHMask"
324325
defaultname(::Grids.SpectralElementGrid1D) = "horizontal_grid"
325326
defaultname(::Grids.SpectralElementGrid2D) = "horizontal_grid"
326327
defaultname(::Grids.ExtrudedFiniteDifferenceGrid) =
@@ -374,6 +375,13 @@ function write_new!(
374375
)
375376
write_attribute(group, "bubble", grid.enable_bubble ? "true" : "false")
376377
write_attribute(group, "topology", write!(writer, Spaces.topology(grid)))
378+
if !(grid.mask isa DataLayouts.NoMask)
379+
write_attribute(
380+
group,
381+
"grid_mask",
382+
write!(writer, grid.mask, Spaces.topology(grid)),
383+
)
384+
end
377385
return name
378386
end
379387

@@ -450,6 +458,17 @@ function write_new!(
450458
return name
451459
end
452460

461+
function write!(
462+
writer::HDF5Writer,
463+
mask::DataLayouts.IJHMask,
464+
topology::Topologies.AbstractTopology,
465+
name::AbstractString = defaultname(mask),
466+
)
467+
group = create_group(writer.file, "grid_mask/$name")
468+
write!(writer, group, mask.is_active, "is_active", topology)
469+
return name
470+
end
471+
453472
# write fields
454473
"""
455474
write!(writer::HDF5Writer, field::Fields.Field, name::AbstractString)
@@ -511,6 +530,109 @@ function write!(
511530
)
512531
end
513532

533+
"""
534+
write!(
535+
writer::HDF5Writer,
536+
values::DataLayouts.AbstractData,
537+
name::AbstractString,
538+
topology::Topologies.AbstractTopology,
539+
)
540+
541+
Write an object of type `AbstractData` and name `name` to the HDF5 file.
542+
543+
The `values` should belong to a `Field` whose `space`'s topology is
544+
`topology(axes(field))`.
545+
"""
546+
function write!(
547+
writer::HDF5Writer,
548+
group,
549+
values::DataLayouts.AbstractData,
550+
name::AbstractString,
551+
topology::Topologies.AbstractTopology,
552+
)
553+
if topology isa Topologies.Topology2D &&
554+
!(writer.context isa ClimaComms.SingletonCommsContext)
555+
nelems = Topologies.nelems(topology)
556+
(; local_elem_gidx) = topology
557+
_write_mpi!(group, values, name; nelems, local_elem_gidx)
558+
else
559+
_write!(group, values, name)
560+
end
561+
end
562+
563+
function write_plain_array!(group, array::AbstractArray, name::AbstractString)
564+
nd = ndims(array)
565+
dims = size(array)
566+
localidx = ntuple(d -> (:), nd)
567+
dataset =
568+
create_dataset(group, name, datatype(eltype(array)), dataspace(dims))
569+
dataset[localidx...] = array
570+
return dataset
571+
end
572+
573+
"""
574+
_write_mpi!(
575+
writer::HDF5Writer,
576+
data::DataLayouts.AbstractData,
577+
name::AbstractString,
578+
nelems,
579+
local_elem_gidx
580+
)
581+
582+
This is an internal method, meant to be used for writing data layouts to the
583+
HDF5 file.
584+
585+
This method should be used for distributed datalayouts.
586+
"""
587+
function _write_mpi!(
588+
group,
589+
values::DataLayouts.AbstractData,
590+
name::AbstractString;
591+
nelems,
592+
local_elem_gidx,
593+
)
594+
h_dim = DataLayouts.h_dim(DataLayouts.singleton(values))
595+
array = parent(values)
596+
nd = ndims(array)
597+
dims = ntuple(d -> d == h_dim ? nelems : size(array, d), nd)
598+
localidx = ntuple(d -> d == h_dim ? local_elem_gidx : (:), nd)
599+
dataset = create_dataset(
600+
group,
601+
"data/$name",
602+
datatype(eltype(array)),
603+
dataspace(dims);
604+
dxpl_mpio = :collective,
605+
)
606+
dataset[localidx...] = array
607+
write_attribute(dataset, "array", array)
608+
write_attribute(dataset, "data_layout", string(nameof(typeof(values))))
609+
write_attribute(dataset, "data_eltype", string(eltype(values)))
610+
return name
611+
end
612+
613+
"""
614+
_write!(
615+
writer::HDF5Writer,
616+
data::DataLayouts.AbstractData,
617+
name::AbstractString,
618+
)
619+
620+
This is an internal method, meant to be used for writing data layouts to the
621+
HDF5 file.
622+
623+
This method should be used when this is not a distributed datalayout.
624+
"""
625+
function _write!(group, values::DataLayouts.AbstractData, name::AbstractString;)
626+
h_dim = DataLayouts.h_dim(DataLayouts.singleton(values))
627+
array = parent(values)
628+
dataset = write_plain_array!(group, array, name)
629+
write_attribute(dataset, "array", array)
630+
write_attribute(dataset, "type", string(nameof(typeof(values))))
631+
write_attribute(dataset, "data_eltype", string(eltype(values)))
632+
return name
633+
end
634+
635+
514636
"""
515637
write!(
516638
writer::HDF5Writer,

test/InputOutput/unit_spectralelement2d.jl

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
#=
2+
julia --project=.buildkite
3+
using Revise; include("test/InputOutput/unit_spectralelement2d.jl")
4+
=#
15
using Test
26
using ClimaComms
37
ClimaComms.@import_required_backends
48
using LinearAlgebra
59
import ClimaCore
10+
import ClimaCore.Utilities.Cache
611
import ClimaCore:
712
Domains,
813
Fields,
@@ -41,16 +46,7 @@ function init_state(local_geometry, p)
4146
return= ρ, u = u, ρθ = ρ * θ)
4247
end
4348

44-
@testset "restart test for 2D spectral element simulations" begin
45-
parameters = (
46-
ϵ = 0.1, # perturbation size for initial condition
47-
l = 0.5, # Gaussian width
48-
k = 0.5, # Sinusoidal wavenumber
49-
ρ₀ = 1.0, # reference density
50-
c = 2,
51-
g = 10,
52-
D₄ = 1e-4, # hyperdiffusion coefficient
53-
)
49+
function init_space(context; enable_bubble, enable_mask)
5450
domain = Domains.RectangleDomain(
5551
Domains.IntervalDomain(
5652
Geometry.XPoint(-2π),
@@ -67,31 +63,92 @@ end
6763
Nq = 4
6864
quad = Quadratures.GLL{Nq}()
6965
mesh = Meshes.RectilinearMesh(domain, n1, n2)
66+
grid_topology = Topologies.Topology2D(context, mesh)
67+
68+
return Spaces.SpectralElementSpace2D(
69+
grid_topology,
70+
quad;
71+
enable_bubble,
72+
enable_mask,
73+
)
74+
end
75+
76+
# function mktempfile(f)
77+
# mktempdir() do dir
78+
# cd(dir) do
79+
# f(tempname(dir))
80+
# end
81+
# end
82+
# end
83+
84+
function mktempfile(f, context)
85+
filename =
86+
ClimaComms.iamroot(context) ? tempname(pwd(); cleanup = true) : ""
87+
filename = ClimaComms.bcast(context, filename)
88+
f(filename)
89+
end
90+
91+
@testset "restart test for 2D spectral element simulations" begin
7092
device = ClimaComms.device()
7193
@info "Using device" device
72-
context = ClimaComms.SingletonCommsContext(device)
73-
grid_topology = Topologies.Topology2D(context, mesh)
94+
context = ClimaComms.context(device)
95+
parameters = (
96+
ϵ = 0.1, # perturbation size for initial condition
97+
l = 0.5, # Gaussian width
98+
k = 0.5, # Sinusoidal wavenumber
99+
ρ₀ = 1.0, # reference density
100+
c = 2,
101+
g = 10,
102+
D₄ = 1e-4, # hyperdiffusion coefficient
103+
)
104+
space = init_space(context; enable_bubble = true, enable_mask = false)
105+
y0 = init_state.(Fields.local_geometry_field(space), Ref(parameters))
106+
Y = Fields.FieldVector(y0 = y0)
107+
108+
# write field vector to hdf5 file
109+
mktempfile(context) do filename
110+
InputOutput.HDF5Writer(filename, context) do writer
111+
InputOutput.write!(writer, "Y" => Y) # write field vector from hdf5 file
112+
end
113+
Cache.clean_cache!()
114+
InputOutput.HDF5Reader(filename, context) do reader
115+
restart_Y = InputOutput.read_field(reader, "Y") # read fieldvector from hdf5 file
116+
@test restart_Y == Y # test if restart is exact
117+
@test axes(restart_Y) == axes(Y) # test if restart is exact for space
118+
end
119+
end
74120

75-
for enable_bubble in (true, false)
76-
space =
77-
Spaces.SpectralElementSpace2D(grid_topology, quad; enable_bubble)
121+
# Test with masks
122+
space = init_space(context; enable_bubble = true, enable_mask = true)
123+
y0 = init_state.(Fields.local_geometry_field(space), Ref(parameters))
124+
Y = Fields.FieldVector(y0 = y0)
78125

79-
y0 = init_state.(Fields.local_geometry_field(space), Ref(parameters))
80-
Y = Fields.FieldVector(y0 = y0)
126+
Spaces.set_mask!(space) do coords
127+
rand() > 0.5
128+
end
81129

82-
# write field vector to hdf5 file
83-
filename = tempname(pwd())
130+
# write field vector to hdf5 file
131+
mktempfile(context) do filename
84132
InputOutput.HDF5Writer(filename, context) do writer
85133
InputOutput.write!(writer, "Y" => Y) # write field vector from hdf5 file
86134
end
135+
87136
InputOutput.HDF5Reader(filename, context) do reader
137+
# We need to clean the cache so that the next read of space
138+
# does not use a pointer to the cached one.
139+
Cache.clean_cache!()
88140
restart_Y = InputOutput.read_field(reader, "Y") # read fieldvector from hdf5 file
89-
ClimaComms.allowscalar(device) do
90-
@test restart_Y == Y # test if restart is exact
91-
end
92-
ClimaComms.allowscalar(device) do
93-
@test axes(restart_Y) == axes(Y) # test if restart is exact for space
94-
end
141+
142+
is_active_restart =
143+
parent(Spaces.get_mask(axes(restart_Y.y0)).is_active)
144+
is_active = parent(Spaces.get_mask(axes(Y.y0)).is_active)
145+
@test is_active == is_active_restart
146+
147+
# Test that we're not doing trivial pointer comparisons
148+
@test !(axes(Y.y0) === axes(restart_Y.y0))
149+
@test restart_Y == Y
150+
@test typeof(axes(restart_Y.y0)) == typeof(axes(Y.y0))
151+
@test axes(restart_Y) == axes(Y) # test if restart is exact for space
95152
end
96153
end
97154
end

0 commit comments

Comments
 (0)