2 changes: 1 addition & 1 deletion sup3r/pipeline/strategy.py
@@ -227,7 +227,7 @@ class ForwardPassStrategy:
     max_nodes: int = 1
     head_node: bool = False
     redistribute_chunks: bool = False
-    use_cpu: bool = False
+    use_cpu: bool = True

     @log_args
     def __post_init__(self):
40 changes: 30 additions & 10 deletions sup3r/postprocessing/collectors/nc.py
@@ -10,8 +10,8 @@
 from rex.utilities.loggers import init_logger

 from sup3r.preprocessing.cachers import Cacher
+from sup3r.preprocessing.loaders import Loader
 from sup3r.preprocessing.names import Dimension
-from sup3r.utilities.utilities import xr_open_mfdataset

 from .base import BaseCollector

@@ -32,14 +32,17 @@ def collect(
         overwrite=True,
         res_kwargs=None,
         cacher_kwargs=None,
+        is_regular_grid=True,
     ):
         """Collect data files from a dir to one output file.

-        TODO: This assumes that if there is any spatial chunking it is split
-        by latitude. This should be generalized to allow for any spatial
-        chunking and any dimension. This will either require a new file
-        naming scheme with a spatial index for both latitude and
-        longitude or checking each chunk to see how they are split.
+        TODO: For a regular grid (lat values are constant across lon and
+        vice versa), collecting lat / lon chunks is supported. For
+        curvilinear grids, only collection of chunks split by latitude is
+        supported. This should be generalized to allow for any spatial
+        chunking and any dimension, which would likely require a new file
+        naming scheme with a spatial index for both latitude and
+        longitude, or checking each chunk to see how it is split.

         Filename requirements:
          - Should end with ".nc"
@@ -68,6 +71,9 @@
             Dictionary of kwargs to pass to xarray.open_mfdataset.
         cacher_kwargs : dict | None
             Dictionary of kwargs to pass to Cacher._write_single.
+        is_regular_grid : bool
+            Whether the data is on a regular grid. If True then spatial
+            chunks can be combined across both latitude and longitude.
         """
         logger.info(f'Initializing collection for file_paths={file_paths}')

@@ -94,11 +100,25 @@
             'combine': 'nested',
             'concat_dim': Dimension.TIME,
         }
-        for s_idx in spatial_chunks:
-            spatial_chunks[s_idx] = xr_open_mfdataset(
-                spatial_chunks[s_idx], **res_kwargs)
+        for s_idx, sfiles in spatial_chunks.items():
+            schunk = Loader(sfiles, res_kwargs=res_kwargs)
+            spatial_chunks[s_idx] = schunk

-        out = xr.concat(spatial_chunks.values(), dim=Dimension.SOUTH_NORTH)
+        # Set lat / lon as 1D arrays if regular grid and get the
+        # xr.Dataset _ds
+        if is_regular_grid:
+            spatial_chunks = {
+                s_idx: schunk.set_regular_grid()._ds
+                for s_idx, schunk in spatial_chunks.items()
+            }
+            out = xr.combine_by_coords(spatial_chunks.values(),
+                                       combine_attrs='override')
+
+        else:
+            out = xr.concat(
+                spatial_chunks.values(), dim=Dimension.SOUTH_NORTH
+            )

         cacher_kwargs = cacher_kwargs or {}
         Cacher._write_single(
             out_file=out_file,
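
For context on the new regular-grid branch above: with 1D latitude/longitude coordinates, xarray's combine_by_coords can reassemble spatial chunks split along both dimensions purely by coordinate value, which is what lets this path drop the latitude-only restriction. A minimal self-contained sketch of the idea (synthetic data; the dim and variable names 'latitude', 'longitude', and 'ws' are illustrative, not sup3r's internals):

```python
import numpy as np
import xarray as xr

lats = np.linspace(0.0, 3.0, 4)
lons = np.linspace(10.0, 13.0, 4)
rng = np.random.default_rng(0)

# Four chunks tiling a 4x4 regular grid, split along BOTH lat and lon.
# With 1D coords, each tile's position in the grid is unambiguous.
tiles = [
    xr.Dataset(
        {'ws': (('latitude', 'longitude'), rng.random((2, 2)))},
        coords={'latitude': la, 'longitude': lo},
    )
    for la in (lats[:2], lats[2:])
    for lo in (lons[:2], lons[2:])
]

# Stitch the tiles back together by coordinate value alone.
full = xr.combine_by_coords(tiles, combine_attrs='override')
assert full['ws'].shape == (4, 4)
```

On a curvilinear grid the coordinates are 2D and give combine_by_coords nothing to align on, hence the fallback to xr.concat along Dimension.SOUTH_NORTH for latitude-split chunks only.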
14 changes: 9 additions & 5 deletions sup3r/utilities/pytest/helpers.py
@@ -8,7 +8,7 @@
 import pandas as pd
 import xarray as xr

-from sup3r.postprocessing import OutputHandlerH5
+from sup3r.postprocessing import OutputHandlerH5, OutputHandlerNC
 from sup3r.preprocessing.base import Container, Sup3rDataset
 from sup3r.preprocessing.batch_handlers import BatchHandlerCC, BatchHandlerDC
 from sup3r.preprocessing.names import Dimension
@@ -268,13 +268,15 @@ def sample_batch(self):
     return BatchHandlerTester


-def make_collect_chunks(td):
-    """Make fake h5 chunked output files for collection tests.
+def make_collect_chunks(td, ext='h5'):
+    """Make fake chunked output files for collection tests.

     Parameters
     ----------
     td : tempfile.TemporaryDirectory
         Test TemporaryDirectory
+    ext : str
+        File extension for output files. Either 'h5' or 'nc'. Default is
+        'h5'.

     Returns
     -------
@@ -320,13 +322,15 @@ def make_collect_chunks(td):
     s_slices_hr = np.array_split(np.arange(shape[0]), 4)
     s_slices_hr = [slice(s[0], s[-1] + 1) for s in s_slices_hr]

-    out_pattern = os.path.join(td, 'fp_out_{t}_{s}.h5')
+    out_pattern = os.path.join(td, 'fp_out_{t}_{s}.' + ext)
     out_files = []

+    Writer = OutputHandlerNC if ext == 'nc' else OutputHandlerH5
     for t, slice_hr in enumerate(t_slices_hr):
         for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)):
             out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6))
             out_files.append(out_file)
-            OutputHandlerH5._write_output(
+            Writer._write_output(
                 data[s1_hr, s2_hr, slice_hr, :],
                 features,
                 hr_lat_lon[s1_hr, s2_hr],
30 changes: 28 additions & 2 deletions tests/output/test_output_handling.py
@@ -9,6 +9,7 @@

 from sup3r.postprocessing import (
     CollectorH5,
+    CollectorNC,
     OutputHandlerH5,
     OutputHandlerNC,
 )
@@ -128,14 +129,14 @@ def test_invert_uv_inplace():
     assert np.allclose(data[..., 1], wd)


-def test_general_collect():
+def test_general_h5_collect():
     """Make sure general file collection gives complete meta, time_index, and
     data array."""

     with tempfile.TemporaryDirectory() as td:
         fp_out = os.path.join(td, 'out_combined.h5')

-        out = make_collect_chunks(td)
+        out = make_collect_chunks(td, ext='h5')
         out_files, data, features, hr_lat_lon, hr_times = (
             out[0],
             out[1],
@@ -160,6 +161,31 @@ def test_general_collect():
         assert np.array_equal(base_data, collect_data)


+def test_general_nc_collect():
+    """Make sure general file collection gives complete meta, time_index, and
+    data array."""
+
+    with tempfile.TemporaryDirectory() as td:
+        fp_out = os.path.join(td, 'out_combined.nc')
+
+        out = make_collect_chunks(td, ext='nc')
+        out_files, base_data, features, hr_lat_lon, hr_times = (
+            out[0],
+            out[1],
+            out[-3],
+            out[-2],
+            out[-1],
+        )
+
+        CollectorNC.collect(out_files, fp_out, features=features,
+                            is_regular_grid=True)
+
+        with Loader(fp_out) as res:
+            assert np.array_equal(hr_times, res.time_index.values)
+            assert np.allclose(hr_lat_lon, res.lat_lon)
+            assert np.allclose(base_data, res.values)
+
+
 def test_h5_out_and_collect(collect_check):
     """Test h5 file output writing and collection with dummy data"""
