Skip to content

Commit 3e9a1e7

Browse files
committed
Add support for PointSpaces with NetCDFWriter
When a point space results from computating a diagnostic, interpolation is not run. Additionally, no space coordinate is added to the output netcdf file. improve tests Fix vertical space coordinates bug Make point space diagnostics work with dict and hdf5 writers fix pointspace mpi
1 parent fd730b0 commit 3e9a1e7

File tree

6 files changed

+219
-20
lines changed

6 files changed

+219
-20
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ v0.2.12
2929
-------
3030
## Bug fixes
3131

32-
- `NetCDFWriter` now correctly writes purely vertical spaces.
32+
- `NetCDFWriter` now correctly writes purely vertical and point spaces.
3333

3434
v0.2.11
3535
-------

src/clima_diagnostics.jl

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import Accessors
22
import SciMLBase
3+
import ClimaComms
4+
import ClimaCore: Spaces
35

46
import .Schedules: DivisorSchedule, EveryDtSchedule
5-
import .Writers: interpolate_field!, write_field!, sync, AbstractWriter
7+
import .Writers:
8+
interpolate_field!, write_field!, sync, AbstractWriter, NetCDFWriter
69

710
# We define all the known identities in reduction_identities.jl
811
include("reduction_identities.jl")
@@ -126,7 +129,24 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
126129

127130
# If it is not a reduction, call the output writer as well
128131
if !isa_time_reduction
129-
interpolate_field!(diag.output_writer, storage[i], diag, Y, p, t)
132+
# no need to interpolate for point spaces
133+
if axes(storage[i]) isa Spaces.PointSpace
134+
# netCDFWriter expects diagnostic to be in preallocated_output_arrays
135+
if diag.output_writer isa NetCDFWriter &&
136+
ClimaComms.iamroot(ClimaComms.context(storage[i]))
137+
diag.output_writer.preallocated_output_arrays[diag] =
138+
copy(parent(storage[i]))
139+
end
140+
else
141+
interpolate_field!(
142+
diag.output_writer,
143+
storage[i],
144+
diag,
145+
Y,
146+
p,
147+
t,
148+
)
149+
end
130150
write_field!(diag.output_writer, storage[i], diag, Y, p, t)
131151
else
132152
# Add to the accumulator
@@ -224,15 +244,25 @@ function orchestrate_diagnostics(
224244
diagnostic_handler.storage[diag_index],
225245
diagnostic_handler.counters[diag_index],
226246
)
227-
228-
interpolate_field!(
229-
diag.output_writer,
230-
diagnostic_handler.storage[diag_index],
231-
diag,
232-
integrator.u,
233-
integrator.p,
234-
integrator.t,
235-
)
247+
# dont interpolate for point spaces
248+
if axes(diagnostic_handler.storage[diag_index]) isa Spaces.PointSpace
249+
# netCDFWriter expects diagnostic to be in preallocated_output_arrays
250+
if diag.output_writer isa NetCDFWriter && ClimaComms.iamroot(
251+
ClimaComms.context(diagnostic_handler.storage[diag_index]),
252+
)
253+
diag.output_writer.preallocated_output_arrays[diag] =
254+
copy(parent(diagnostic_handler.storage[diag_index]))
255+
end
256+
else
257+
interpolate_field!(
258+
diag.output_writer,
259+
diagnostic_handler.storage[diag_index],
260+
diag,
261+
integrator.u,
262+
integrator.p,
263+
integrator.t,
264+
)
265+
end
236266
end
237267

238268
# Save to disk

src/hdf5_writer.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ClimaComms
2-
import ClimaCore.InputOutput
2+
import ClimaCore: InputOutput, Spaces
33

44
import ClimaUtilities.TimeManager: ITime
55

@@ -38,6 +38,11 @@ The name of the file is determined by the `output_short_name` of the output
3838
`Field`s can be read back using the `InputOutput` module in `ClimaCore`.
3939
"""
4040
function write_field!(writer::HDF5Writer, field, diagnostic, u, p, t)
41+
axes(field) isa Spaces.PointSpace &&
42+
pkgversion(InputOutput) < v"0.14.27" &&
43+
error(
44+
"HDF5Writer only supports Fields with PointSpace for ClimaCore >= 0.14.27",
45+
)
4146
var = diagnostic.variable
4247
time = t
4348

src/netcdf_writer.jl

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@ include("netcdf_writer_coordinates.jl")
1818
A struct to remap `ClimaCore` `Fields` to rectangular grids and save the output to NetCDF
1919
files.
2020
"""
21-
struct NetCDFWriter{T, TS, DI, SYNC, ZSM <: AbstractZSamplingMethod, DATE} <:
22-
AbstractWriter
21+
struct NetCDFWriter{
22+
T,
23+
TS,
24+
DI,
25+
SYNC,
26+
ZSM <: Union{AbstractZSamplingMethod, Nothing},
27+
DATE,
28+
} <: AbstractWriter
2329
"""The base folder where to save the files."""
2430
output_dir::String
2531

@@ -262,6 +268,42 @@ function NetCDFWriter(
262268
)
263269
end
264270

271+
function NetCDFWriter(
272+
space::Spaces.Spaces.PointSpace,
273+
output_dir;
274+
compression_level = 9,
275+
sync_schedule = ClimaComms.device(space) isa ClimaComms.CUDADevice ?
276+
EveryStepSchedule() : nothing,
277+
start_date = nothing,
278+
kwargs...,
279+
)
280+
comms_ctx = ClimaComms.context(space)
281+
preallocated_arrays =
282+
ClimaComms.iamroot(comms_ctx) ?
283+
Dict{ScheduledDiagnostic, ClimaComms.array_type(space)}() :
284+
Dict{ScheduledDiagnostic, Nothing}()
285+
unsynced_datasets = Set{NCDatasets.NCDataset}()
286+
return NetCDFWriter{
287+
Nothing,
288+
Nothing,
289+
typeof(preallocated_arrays),
290+
typeof(sync_schedule),
291+
Nothing,
292+
typeof(start_date),
293+
}(
294+
output_dir,
295+
Dict{String, Remapper}(),
296+
nothing,
297+
compression_level,
298+
nothing,
299+
Dict{String, NCDatasets.NCDataset}(),
300+
nothing,
301+
preallocated_arrays,
302+
sync_schedule,
303+
unsynced_datasets,
304+
start_date,
305+
)
306+
end
265307
"""
266308
interpolate_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
267309
@@ -278,7 +320,7 @@ function interpolate_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
278320
if has_horizontal_space
279321
horizontal_space = Spaces.horizontal_space(space)
280322

281-
# We have to deal with to cases: when we have an horizontal slice (e.g., the
323+
# We have to deal with two cases: when we have an horizontal slice (e.g., the
282324
# surface), and when we have a full space. We distinguish these cases by checking if
283325
# the given space has the horizontal_space attribute. If not, it is going to be a
284326
# SpectralElementSpace2D and we don't have to deal with the z coordinates.
@@ -391,6 +433,11 @@ function write_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
391433
interpolated_field = permutedims(interpolated_field, perm)
392434
end
393435

436+
if space isa Spaces.PointSpace
437+
# If the space is a point space, we have to remove the singleton dimension
438+
interpolated_field = interpolated_field[]
439+
end
440+
394441
FT = Spaces.undertype(space)
395442

396443
output_path =
@@ -459,7 +506,8 @@ function write_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
459506
# We already have something in the file
460507
v = nc["$(var.short_name)"]
461508
temporal_size, spatial_size... = size(v)
462-
spatial_size == size(interpolated_field) ||
509+
interpolated_size = size(interpolated_field)
510+
spatial_size == interpolated_size ||
463511
error("incompatible dimensions for $(var.short_name)")
464512
else
465513
v = NCDatasets.defVar(

src/netcdf_writer_coordinates.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,19 @@ function add_space_coordinates_maybe!(
270270
return [name]
271271
end
272272

273+
# PointSpace
274+
function add_space_coordinates_maybe!(
275+
nc::NCDatasets.NCDataset,
276+
space::Spaces.PointSpace,
277+
num_points_z;
278+
z_sampling_method,
279+
names = (),
280+
interpolated_physical_z = nothing, # Not needed here, but needed for consistency of
281+
# interface and dispatch
282+
)
283+
return []
284+
end
285+
273286
add_space_coordinates_maybe!(
274287
nc::NCDatasets.NCDataset,
275288
space::Spaces.AbstractSpectralElementSpace,
@@ -338,6 +351,7 @@ function target_coordinates(
338351
return (longpts, latpts)
339352
end
340353

354+
islatlonbox(space::Spaces.PointSpace) = false
341355
islatlonbox(space::Spaces.FiniteDifferenceSpace) = false
342356
islatlonbox(space::Domains.AbstractDomain) = false
343357
function islatlonbox(space::Spaces.AbstractSpace)
@@ -533,16 +547,15 @@ function add_space_coordinates_maybe!(
533547
z_sampling_method,
534548
depending_on_dimensions,
535549
)
536-
num_points_z = num_points
537550
name, _... = names
538551

539552
# Add z_reference
540553
z_reference_dimension_dimension_exists =
541-
dimension_exists(nc, name, (num_points_z,))
554+
dimension_exists(nc, name, num_points)
542555

543556
if !z_reference_dimension_dimension_exists
544557
reference_altitudes =
545-
target_coordinates(space, num_points_z, z_sampling_method)
558+
target_coordinates(space, num_points, z_sampling_method)
546559
add_dimension!(nc, name, reference_altitudes; units = "m", axis = "Z")
547560
end
548561

test/writers.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import ProfileCanvas
88
import NCDatasets
99
import ClimaCore
1010
import ClimaCore.Fields
11+
import ClimaCore.Spaces
12+
import ClimaCore.Geometry
13+
import ClimaComms
1114

1215
import ClimaDiagnostics
1316
import ClimaDiagnostics.Writers
@@ -261,6 +264,106 @@ end
261264
end
262265
end
263266

267+
###############
268+
# Point Space #
269+
###############
270+
point_val = 3.14
271+
point_space =
272+
Spaces.PointSpace(ClimaComms.context(), Geometry.ZPoint(point_val))
273+
point_field = Fields.coordinate_field(point_space)
274+
point_writer = Writers.NetCDFWriter(point_space, output_dir)
275+
276+
point_u = (; field = point_field)
277+
278+
point_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(;
279+
variable = ClimaDiagnostics.DiagnosticVariable(;
280+
compute!,
281+
short_name = "ABC",
282+
),
283+
output_short_name = "my_short_name_point",
284+
output_long_name = "My Long Name Point",
285+
output_writer = point_writer,
286+
)
287+
point_writer.preallocated_output_arrays[point_diagnostic] = [point_val]
288+
# No interpolation needed for point space
289+
Writers.write_field!(
290+
point_writer,
291+
point_field,
292+
point_diagnostic,
293+
point_u,
294+
p,
295+
t,
296+
)
297+
# Write a second time
298+
Writers.write_field!(
299+
point_writer,
300+
point_field,
301+
point_diagnostic,
302+
point_u,
303+
p,
304+
t,
305+
)
306+
close(point_writer)
307+
308+
NCDatasets.NCDataset(joinpath(output_dir, "my_short_name_point.nc")) do nc
309+
@test nc["ABC"][:] == [point_val, point_val]
310+
end
311+
312+
###################
313+
# Horizontal Space#
314+
###################
315+
316+
horizontal_space = ClimaCore.Spaces.level(space, 1)
317+
horizontal_field = Fields.coordinate_field(horizontal_space).z
318+
horizontal_writer = Writers.NetCDFWriter(
319+
horizontal_space,
320+
output_dir;
321+
num_points = (NUM, 2NUM),
322+
)
323+
horizontal_u = (; field = horizontal_field)
324+
325+
horizontal_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(;
326+
variable = ClimaDiagnostics.DiagnosticVariable(;
327+
compute!,
328+
short_name = "ABC",
329+
),
330+
output_short_name = "my_short_name_horizontal",
331+
output_long_name = "My Long Name Point Horizontal",
332+
output_writer = horizontal_writer,
333+
)
334+
335+
Writers.interpolate_field!(
336+
horizontal_writer,
337+
horizontal_field,
338+
horizontal_diagnostic,
339+
horizontal_u,
340+
p,
341+
t,
342+
)
343+
Writers.write_field!(
344+
horizontal_writer,
345+
horizontal_field,
346+
horizontal_diagnostic,
347+
horizontal_u,
348+
p,
349+
t,
350+
)
351+
# Write a second time
352+
Writers.write_field!(
353+
horizontal_writer,
354+
horizontal_field,
355+
horizontal_diagnostic,
356+
horizontal_u,
357+
p,
358+
t,
359+
)
360+
close(horizontal_writer)
361+
NCDatasets.NCDataset(
362+
joinpath(output_dir, "my_short_name_horizontal.nc"),
363+
) do nc
364+
@test size(nc["ABC"]) == (2, NUM, 2NUM)
365+
end
366+
264367
###############
265368
# Performance #
266369
###############

0 commit comments

Comments
 (0)