Skip to content

Commit fd730b0

Browse files
SbozzoloimreddyTeja
authored andcommitted
Add support for purely vertical spaces with NetCDF
Leveraging the new feature introduced in ClimaCore, this commit adds support to interpolating purely vertical fields and saving them as a NetCDF file.
1 parent dbe41fe commit fd730b0

File tree

5 files changed

+186
-57
lines changed

5 files changed

+186
-57
lines changed

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ This release adds support for working with `ITime`s. In particular,
2525
provided to make an `EveryCalendarDtSchedule` using `ITime`s. Lastly, there are
2626
small changes to the writers to support `ITime`s.
2727

28+
v0.2.12
29+
-------
30+
## Bug fixes
31+
32+
- `NetCDFWriter` now correctly writes purely vertical spaces.
33+
2834
v0.2.11
2935
-------
3036
## Bug fixes

src/netcdf_writer.jl

Lines changed: 104 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Keyword arguments
113113
- `start_date`: Date of the beginning of the simulation.
114114
"""
115115
function NetCDFWriter(
116-
space,
116+
space::Spaces.AbstractSpace,
117117
output_dir;
118118
num_points = (180, 90, 50),
119119
compression_level = 9,
@@ -122,6 +122,10 @@ function NetCDFWriter(
122122
z_sampling_method = LevelsMethod(),
123123
start_date = nothing,
124124
)
125+
has_horizontal_space =
126+
space isa Spaces.ExtrudedFiniteDifferenceSpace ||
127+
space isa Spaces.AbstractSpectralElementSpace
128+
125129
horizontal_space = Spaces.horizontal_space(space)
126130
is_horizontal_space = horizontal_space == space
127131

@@ -201,6 +205,63 @@ function NetCDFWriter(
201205
)
202206
end
203207

208+
function NetCDFWriter(
209+
space::Spaces.Spaces.FiniteDifferenceSpace,
210+
output_dir;
211+
num_points = (180, 90, 50),
212+
compression_level = 9,
213+
sync_schedule = ClimaComms.device(space) isa ClimaComms.CUDADevice ?
214+
EveryStepSchedule() : nothing,
215+
z_sampling_method = LevelsMethod(),
216+
start_date = nothing,
217+
)
218+
if z_sampling_method isa LevelsMethod
219+
num_vpts = Meshes.nelements(Grids.vertical_topology(space).mesh)
220+
@warn "Disabling vertical interpolation, the provided number of points is ignored (using $num_vpts)"
221+
num_points = (num_vpts,)
222+
end
223+
vpts = target_coordinates(space, num_points, z_sampling_method)
224+
target_zcoords = Geometry.ZPoint.(vpts)
225+
remapper = Remapper(space; target_zcoords)
226+
227+
comms_ctx = ClimaComms.context(space)
228+
229+
coords_z = Fields.coordinate_field(space).z
230+
maybe_move_to_cpu =
231+
ClimaComms.device(coords_z) isa ClimaComms.CUDADevice &&
232+
ClimaComms.iamroot(comms_ctx) ? Array : identity
233+
234+
interpolated_physical_z = maybe_move_to_cpu(interpolate(remapper, coords_z))
235+
236+
preallocated_arrays =
237+
ClimaComms.iamroot(comms_ctx) ?
238+
Dict{ScheduledDiagnostic, ClimaComms.array_type(space)}() :
239+
Dict{ScheduledDiagnostic, Nothing}()
240+
241+
unsynced_datasets = Set{NCDatasets.NCDataset}()
242+
243+
return NetCDFWriter{
244+
typeof(num_points),
245+
typeof(interpolated_physical_z),
246+
typeof(preallocated_arrays),
247+
typeof(sync_schedule),
248+
typeof(z_sampling_method),
249+
typeof(start_date),
250+
}(
251+
output_dir,
252+
Dict{String, Remapper}(),
253+
num_points,
254+
compression_level,
255+
interpolated_physical_z,
256+
Dict{String, NCDatasets.NCDataset}(),
257+
z_sampling_method,
258+
preallocated_arrays,
259+
sync_schedule,
260+
unsynced_datasets,
261+
start_date,
262+
)
263+
end
264+
204265
"""
205266
interpolate_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
206267
@@ -212,61 +273,62 @@ function interpolate_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
212273

213274
space = axes(field)
214275

215-
horizontal_space = Spaces.horizontal_space(space)
276+
has_horizontal_space = !(space isa Spaces.FiniteDifferenceSpace)
216277

217-
# We have to deal with to cases: when we have an horizontal slice (e.g., the
218-
# surface), and when we have a full space. We distinguish these cases by checking if
219-
# the given space has the horizontal_space attribute. If not, it is going to be a
220-
# SpectralElementSpace2D and we don't have to deal with the z coordinates.
221-
is_horizontal_space = horizontal_space == space
278+
if has_horizontal_space
279+
horizontal_space = Spaces.horizontal_space(space)
280+
281+
# We have to deal with to cases: when we have an horizontal slice (e.g., the
282+
# surface), and when we have a full space. We distinguish these cases by checking if
283+
# the given space has the horizontal_space attribute. If not, it is going to be a
284+
# SpectralElementSpace2D and we don't have to deal with the z coordinates.
285+
is_horizontal_space = horizontal_space == space
286+
end
222287

223288
# Prepare the remapper if we don't have one for the given variable. We need one remapper
224289
# per variable (not one per diagnostic since all the time reductions return the same
225290
# type of space).
226291

227-
# TODO: Expand this once we support spatial reductions
292+
# TODO: Expand this once we support spatial reductions.
293+
# TODO: More generally, this can be clean up to have less conditionals
294+
# depending on the type of space and use dispatch instead
228295
if !haskey(writer.remappers, var.short_name)
229296

230297
# hpts, vpts are ranges of numbers
231-
# hcoords, zcoords are ranges of Geometry.Points
232-
233-
zcoords = []
234-
235-
if is_horizontal_space
236-
hpts = target_coordinates(space, writer.num_points)
237-
vpts = []
298+
# target_hcoords, target_zcoords are ranges of Geometry.Points
299+
300+
target_zcoords = nothing
301+
target_hcoords = nothing
302+
303+
if has_horizontal_space
304+
if is_horizontal_space
305+
hpts = target_coordinates(space, writer.num_points)
306+
vpts = []
307+
else
308+
hpts, vpts = target_coordinates(
309+
space,
310+
writer.num_points,
311+
writer.z_sampling_method,
312+
)
313+
end
314+
315+
target_hcoords = hcoords_from_horizontal_space(
316+
horizontal_space,
317+
Meshes.domain(Spaces.topology(horizontal_space)),
318+
hpts,
319+
)
238320
else
239-
hpts, vpts = target_coordinates(
321+
vpts = target_coordinates(
240322
space,
241323
writer.num_points,
242324
writer.z_sampling_method,
243325
)
244326
end
245327

246-
hcoords = hcoords_from_horizontal_space(
247-
horizontal_space,
248-
Meshes.domain(Spaces.topology(horizontal_space)),
249-
hpts,
250-
)
251-
252-
# When we disable vertical_interpolation, we override the vertical points with
253-
# the reference values for the vertical space.
254-
if writer.z_sampling_method isa LevelsMethod && !is_horizontal_space
255-
# We need Array(parent()) because we want an array of values, not a DataLayout
256-
# of Points
257-
vpts = Array(
258-
parent(
259-
space.grid.vertical_grid.center_local_geometry.coordinates,
260-
),
261-
)[
262-
:,
263-
1,
264-
]
265-
end
328+
target_zcoords = Geometry.ZPoint.(vpts)
266329

267-
zcoords = [Geometry.ZPoint(p) for p in vpts]
268-
269-
writer.remappers[var.short_name] = Remapper(space, hcoords, zcoords)
330+
writer.remappers[var.short_name] =
331+
Remapper(space, target_hcoords, target_zcoords)
270332
end
271333

272334
remapper = writer.remappers[var.short_name]
@@ -321,9 +383,7 @@ function write_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
321383
interpolated_field =
322384
maybe_move_to_cpu(writer.preallocated_output_arrays[diagnostic])
323385

324-
if islatlonbox(
325-
Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))),
326-
)
386+
if islatlonbox(space)
327387
# ClimaCore works with LatLong points, but we want to have longitude
328388
# first in the output, so we have to flip things
329389
perm = collect(1:length(size(interpolated_field)))
@@ -441,14 +501,8 @@ function write_field!(writer::NetCDFWriter, field, diagnostic, u, p, t)
441501
[date_type(nc["date"][time_index - 1]); curr_date]
442502
end
443503

444-
# TODO: It would be nice to find a cleaner way to do this
445-
if length(dim_names) == 3
446-
v[time_index, :, :, :] = interpolated_field
447-
elseif length(dim_names) == 2
448-
v[time_index, :, :] = interpolated_field
449-
elseif length(dim_names) == 1
450-
v[time_index, :] = interpolated_field
451-
end
504+
colons = ntuple(_ -> Colon(), length(dim_names))
505+
v[time_index, colons...] = interpolated_field
452506

453507
# Add file to list of files that might need manual sync
454508
push!(writer.unsynced_datasets, nc)

src/netcdf_writer_coordinates.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ function target_coordinates(
226226
# We assume H to be 7000, which is a good scale height for the Earth atmosphere
227227
H_EARTH = 7000
228228

229-
num_points_z = num_points[]
229+
num_points_z = last(num_points)
230230
FT = Spaces.undertype(space)
231231
topology = Spaces.topology(space)
232232
vert_domain = topology.mesh.domain
@@ -257,9 +257,12 @@ function add_space_coordinates_maybe!(
257257
num_points_z;
258258
z_sampling_method,
259259
names = ("z",),
260+
interpolated_physical_z = nothing, # Not needed here, but needed for consistency of
261+
# interface and dispatch
260262
)
261263
name, _... = names
262-
z_dimension_exists = dimension_exists(nc, name, (num_points_z,))
264+
z_dimension_exists = dimension_exists(nc, name, num_points_z)
265+
263266
if !z_dimension_exists
264267
zpts = target_coordinates(space, num_points_z, z_sampling_method)
265268
add_dimension!(nc, name, zpts, units = "m", axis = "Z")
@@ -335,14 +338,21 @@ function target_coordinates(
335338
return (longpts, latpts)
336339
end
337340

338-
islatlonbox(domain) = false
341+
islatlonbox(space::Spaces.FiniteDifferenceSpace) = false
342+
islatlonbox(space::Domains.AbstractDomain) = false
343+
function islatlonbox(space::Spaces.AbstractSpace)
344+
return islatlonbox(
345+
Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))),
346+
)
347+
end
339348

340349
# Box
341350
function islatlonbox(domain::Domains.RectangleDomain)
342351
return domain.interval1.coord_max isa Geometry.LatPoint &&
343352
domain.interval2.coord_max isa Geometry.LongPoint
344353
end
345354

355+
346356
function add_space_coordinates_maybe!(
347357
nc::NCDatasets.NCDataset,
348358
space::Spaces.SpectralElementSpace2D,
@@ -478,14 +488,14 @@ function add_space_coordinates_maybe!(
478488
vdims_names = add_space_coordinates_maybe!(
479489
nc,
480490
vertical_space,
481-
num_points_vertic;
491+
(num_points_vertic,);
482492
z_sampling_method,
483493
)
484494
else
485495
vdims_names = add_space_coordinates_maybe!(
486496
nc,
487497
vertical_space,
488-
num_points_vertic,
498+
(num_points_vertic,),
489499
interpolated_physical_z;
490500
z_sampling_method,
491501
names = ("z_reference",),

test/TestTools.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@ function ColumnCenterFiniteDifferenceSpace(
1414
context = ClimaComms.SingletonCommsContext();
1515
FT = Float64,
1616
)
17+
return _column(
18+
zelem,
19+
ClimaCore.Spaces.CenterFiniteDifferenceSpace,
20+
context,
21+
FT,
22+
)
23+
end
24+
25+
function ColumnFaceFiniteDifferenceSpace(
26+
zelem = 10,
27+
context = ClimaComms.SingletonCommsContext();
28+
FT = Float64,
29+
)
30+
return _column(
31+
zelem,
32+
ClimaCore.Spaces.FaceFiniteDifferenceSpace,
33+
context,
34+
FT,
35+
)
36+
end
37+
38+
function _column(zelem, constructor, context, FT)
1739
zlim = (FT(0.0), FT(1.0))
1840
domain = ClimaCore.Domains.IntervalDomain(
1941
ClimaCore.Geometry.ZPoint(zlim[1]),
@@ -22,9 +44,10 @@ function ColumnCenterFiniteDifferenceSpace(
2244
)
2345
mesh = ClimaCore.Meshes.IntervalMesh(domain, nelems = zelem)
2446
topology = ClimaCore.Topologies.IntervalTopology(context, mesh)
25-
return ClimaCore.Spaces.CenterFiniteDifferenceSpace(topology)
47+
return constructor(topology)
2648
end
2749

50+
2851
function SphericalShellSpace(;
2952
radius = 6371.0,
3053
height = 10.0,

test/writers.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Profile
66
using BenchmarkTools
77
import ProfileCanvas
88
import NCDatasets
9+
import ClimaCore
910
import ClimaCore.Fields
1011

1112
import ClimaDiagnostics
@@ -17,7 +18,7 @@ include("TestTools.jl")
1718

1819
# The temporary directory where we write the file cannot be in /tmp, it has
1920
# to be on disk
20-
output_dir = mktempdir(".")
21+
output_dir = mktempdir(pwd())
2122

2223
@testset "DictWriter" begin
2324
writer = Writers.DictWriter()
@@ -225,6 +226,41 @@ end
225226
t,
226227
)
227228

229+
# Check columns
230+
if pkgversion(ClimaCore) >= v"0.14.23"
231+
# Center space
232+
for (i, colspace) in enumerate((
233+
ColumnCenterFiniteDifferenceSpace(),
234+
ColumnFaceFiniteDifferenceSpace(),
235+
))
236+
colfield = Fields.coordinate_field(colspace).z
237+
238+
colwriter =
239+
Writers.NetCDFWriter(colspace, output_dir; num_points = (NUM,))
240+
coldiagnostic = ClimaDiagnostics.ScheduledDiagnostic(;
241+
variable = ClimaDiagnostics.DiagnosticVariable(;
242+
compute!,
243+
short_name = "ABC",
244+
),
245+
output_short_name = "my_short_name_c$(i)",
246+
output_long_name = "My Long Name",
247+
output_writer = colwriter,
248+
)
249+
colu = (; colfield)
250+
Writers.interpolate_field!(
251+
colwriter,
252+
colfield,
253+
coldiagnostic,
254+
colu,
255+
p,
256+
t,
257+
)
258+
Writers.write_field!(colwriter, colfield, coldiagnostic, colu, p, t)
259+
# Write a second time, to check consistency
260+
Writers.write_field!(colwriter, colfield, coldiagnostic, colu, p, t)
261+
end
262+
end
263+
228264
###############
229265
# Performance #
230266
###############

0 commit comments

Comments
 (0)