diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py
index 05f8764430..db2ba1b31d 100644
--- a/sup3r/pipeline/strategy.py
+++ b/sup3r/pipeline/strategy.py
@@ -227,7 +227,7 @@ class ForwardPassStrategy:
     max_nodes: int = 1
     head_node: bool = False
     redistribute_chunks: bool = False
-    use_cpu: bool = False
+    use_cpu: bool = True
 
     @log_args
     def __post_init__(self):
diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py
index 1885049bca..1907b0ace4 100644
--- a/sup3r/postprocessing/collectors/nc.py
+++ b/sup3r/postprocessing/collectors/nc.py
@@ -10,8 +10,8 @@
 from rex.utilities.loggers import init_logger
 
 from sup3r.preprocessing.cachers import Cacher
+from sup3r.preprocessing.loaders import Loader
 from sup3r.preprocessing.names import Dimension
-from sup3r.utilities.utilities import xr_open_mfdataset
 
 from .base import BaseCollector
 
@@ -32,14 +32,17 @@
         overwrite=True,
         res_kwargs=None,
         cacher_kwargs=None,
+        is_regular_grid=True,
     ):
         """Collect data files from a dir to one output file.
 
-        TODO: This assumes that if there is any spatial chunking it is split
-        by latitude. This should be generalized to allow for any spatial
-        chunking and any dimension. This will either require a new file
-        naming scheme with a spatial index for both latitude and
-        longitude or checking each chunk to see how they are split.
+        TODO: For a regular grid (lat values are constant across lon and vice
+        versa), collecting chunks split along both latitude and longitude is
+        supported. For curvilinear grids, only collection of chunks split by
+        latitude is supported. This should be generalized to allow for any
+        spatial chunking and any dimension. I think this would require a new
+        file naming scheme with a spatial index for both latitude and
+        longitude, or checking each chunk to see how they are split.
 
         Filename requirements:
         - Should end with ".nc"
@@ -68,6 +71,9 @@
             Dictionary of kwargs to pass to xarray.open_mfdataset.
         cacher_kwargs : dict | None
             Dictionary of kwargs to pass to Cacher._write_single.
+        is_regular_grid : bool
+            Whether the data is on a regular grid. If True, spatial chunks
+            can be combined across both latitude and longitude.
         """
         logger.info(f'Initializing collection for file_paths={file_paths}')
 
@@ -94,11 +100,25 @@
             'combine': 'nested',
             'concat_dim': Dimension.TIME,
         }
-        for s_idx in spatial_chunks:
-            spatial_chunks[s_idx] = xr_open_mfdataset(
-                spatial_chunks[s_idx], **res_kwargs
+        for s_idx, sfiles in spatial_chunks.items():
+            schunk = Loader(sfiles, res_kwargs=res_kwargs)
+            spatial_chunks[s_idx] = schunk
+
+        # For a regular grid, set lat / lon as 1D arrays and get the
+        # underlying xr.Dataset (_ds) from each chunk
+        if is_regular_grid:
+            spatial_chunks = {
+                s_idx: schunk.set_regular_grid()._ds
+                for s_idx, schunk in spatial_chunks.items()
+            }
+            out = xr.combine_by_coords(spatial_chunks.values(),
+                                       combine_attrs='override')
+
+        else:
+            out = xr.concat(
+                spatial_chunks.values(), dim=Dimension.SOUTH_NORTH
             )
-        out = xr.concat(spatial_chunks.values(), dim=Dimension.SOUTH_NORTH)
+
         cacher_kwargs = cacher_kwargs or {}
         Cacher._write_single(
             out_file=out_file,
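Note on the regular-grid branch above: xr.combine_by_coords places each
chunk by its 1D monotonic coordinate values, which is what lets chunks
split along both latitude and longitude be stitched back together without
encoding a spatial index for each dimension in the file names. A minimal
standalone sketch of that pattern (illustrative names, not sup3r code):

    import numpy as np
    import xarray as xr

    # Build a 4x4 regular grid and split it into 2x2 tiles, mimicking
    # spatial chunks that carry 1D latitude / longitude coordinates.
    lats = np.linspace(40, 41, 4)
    lons = np.linspace(-105, -104, 4)
    tiles = []
    for lat_chunk in np.array_split(lats, 2):
        for lon_chunk in np.array_split(lons, 2):
            tiles.append(xr.Dataset(
                {'ws': (('latitude', 'longitude'),
                        np.random.rand(len(lat_chunk), len(lon_chunk)))},
                coords={'latitude': lat_chunk, 'longitude': lon_chunk},
            ))

    # Tiles can be given in any order; placement is inferred from coords.
    out = xr.combine_by_coords(tiles, combine_attrs='override')
    assert out['ws'].shape == (4, 4)

This is also why the curvilinear fallback still concatenates along
Dimension.SOUTH_NORTH only: without 1D coords there is nothing for
combine_by_coords to sort on.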
""" logger.info(f'Initializing collection for file_paths={file_paths}') @@ -94,11 +100,25 @@ def collect( 'combine': 'nested', 'concat_dim': Dimension.TIME, } - for s_idx in spatial_chunks: - spatial_chunks[s_idx] = xr_open_mfdataset( - spatial_chunks[s_idx], **res_kwargs + for s_idx, sfiles in spatial_chunks.items(): + schunk = Loader(sfiles, res_kwargs=res_kwargs) + spatial_chunks[s_idx] = schunk + + # Set lat / lon as 1D arrays if regular grid and get the + # xr.Dataset _ds + if is_regular_grid: + spatial_chunks = { + s_idx: schunk.set_regular_grid()._ds + for s_idx, schunk in spatial_chunks.items() + } + out = xr.combine_by_coords(spatial_chunks.values(), + combine_attrs='override') + + else: + out = xr.concat( + spatial_chunks.values(), dim=Dimension.SOUTH_NORTH ) - out = xr.concat(spatial_chunks.values(), dim=Dimension.SOUTH_NORTH) + cacher_kwargs = cacher_kwargs or {} Cacher._write_single( out_file=out_file, diff --git a/sup3r/utilities/pytest/helpers.py b/sup3r/utilities/pytest/helpers.py index cf5654eba6..3380df0d8a 100644 --- a/sup3r/utilities/pytest/helpers.py +++ b/sup3r/utilities/pytest/helpers.py @@ -8,7 +8,7 @@ import pandas as pd import xarray as xr -from sup3r.postprocessing import OutputHandlerH5 +from sup3r.postprocessing import OutputHandlerH5, OutputHandlerNC from sup3r.preprocessing.base import Container, Sup3rDataset from sup3r.preprocessing.batch_handlers import BatchHandlerCC, BatchHandlerDC from sup3r.preprocessing.names import Dimension @@ -268,13 +268,15 @@ def sample_batch(self): return BatchHandlerTester -def make_collect_chunks(td): - """Make fake h5 chunked output files for collection tests. +def make_collect_chunks(td, ext='h5'): + """Make fake chunked output files for collection tests. Parameters ---------- td : tempfile.TemporaryDirectory Test TemporaryDirectory + ext : str + File extension for output files. Either 'h5' or 'nc'. Default is 'h5'. Returns ------- @@ -320,13 +322,15 @@ def make_collect_chunks(td): s_slices_hr = np.array_split(np.arange(shape[0]), 4) s_slices_hr = [slice(s[0], s[-1] + 1) for s in s_slices_hr] - out_pattern = os.path.join(td, 'fp_out_{t}_{s}.h5') + out_pattern = os.path.join(td, 'fp_out_{t}_{s}.' 
diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py
index 636e966c02..43f47d5d0d 100644
--- a/tests/output/test_output_handling.py
+++ b/tests/output/test_output_handling.py
@@ -9,6 +9,7 @@
 from sup3r.postprocessing import (
     CollectorH5,
+    CollectorNC,
     OutputHandlerH5,
     OutputHandlerNC,
 )
 
@@ -128,14 +129,14 @@ def test_invert_uv_inplace():
     assert np.allclose(data[..., 1], wd)
 
 
-def test_general_collect():
+def test_general_h5_collect():
     """Make sure general file collection gives complete meta, time_index, and
     data array."""
 
     with tempfile.TemporaryDirectory() as td:
         fp_out = os.path.join(td, 'out_combined.h5')
 
-        out = make_collect_chunks(td)
+        out = make_collect_chunks(td, ext='h5')
         out_files, data, features, hr_lat_lon, hr_times = (
             out[0],
             out[1],
@@ -160,6 +161,31 @@ def test_general_collect():
     assert np.array_equal(base_data, collect_data)
 
 
+def test_general_nc_collect():
+    """Make sure general NetCDF file collection gives complete time_index,
+    lat_lon, and data array."""
+
+    with tempfile.TemporaryDirectory() as td:
+        fp_out = os.path.join(td, 'out_combined.nc')
+
+        out = make_collect_chunks(td, ext='nc')
+        out_files, base_data, features, hr_lat_lon, hr_times = (
+            out[0],
+            out[1],
+            out[-3],
+            out[-2],
+            out[-1],
+        )
+
+        CollectorNC.collect(out_files, fp_out, features=features,
+                            is_regular_grid=True)
+
+        with Loader(fp_out) as res:
+            assert np.array_equal(hr_times, res.time_index.values)
+            assert np.allclose(hr_lat_lon, res.lat_lon)
+            assert np.allclose(base_data, res.values)
+
+
 def test_h5_out_and_collect(collect_check):
     """Test h5 file output writing and collection with dummy data"""
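A possible follow-up to the TODO in collectors/nc.py: is_regular_grid is
caller-supplied here, but it could in principle be inferred from the chunk
coordinates. A hypothetical check following the docstring's definition
(lat constant across lon and vice versa); the helper name and signature
are assumptions, not part of this diff:

    import numpy as np

    def appears_regular(lat_lon, atol=1e-6):
        """Hypothetical: True if a (south_north, west_east, 2) lat / lon
        array describes a regular grid, i.e. latitude is constant along
        each row and longitude is constant along each column."""
        lats, lons = lat_lon[..., 0], lat_lon[..., 1]
        return bool(np.allclose(lats, lats[:, :1], atol=atol)
                    and np.allclose(lons, lons[:1, :], atol=atol))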