From f54a3468b33346f7a41dd49d36b1ead2f414e819 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 12 Jun 2025 08:13:57 -0600
Subject: [PATCH 1/2] Update ForwardPassStrategy to use CPU by default;
 enhance NC collector for regular grids and add support for NC file
 collection tests

---
 sup3r/pipeline/strategy.py            |  2 +-
 sup3r/postprocessing/collectors/nc.py | 40 ++++++++++++++++++++-------
 sup3r/utilities/pytest/helpers.py     | 14 ++++++----
 tests/output/test_output_handling.py  | 30 ++++++++++++++++++--
 4 files changed, 68 insertions(+), 18 deletions(-)

diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py
index 05f8764430..db2ba1b31d 100644
--- a/sup3r/pipeline/strategy.py
+++ b/sup3r/pipeline/strategy.py
@@ -227,7 +227,7 @@ class ForwardPassStrategy:
     max_nodes: int = 1
     head_node: bool = False
     redistribute_chunks: bool = False
-    use_cpu: bool = False
+    use_cpu: bool = True
 
     @log_args
     def __post_init__(self):
diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py
index 1885049bca..6356d4af37 100644
--- a/sup3r/postprocessing/collectors/nc.py
+++ b/sup3r/postprocessing/collectors/nc.py
@@ -10,8 +10,8 @@
 from rex.utilities.loggers import init_logger
 
 from sup3r.preprocessing.cachers import Cacher
+from sup3r.preprocessing.loaders import Loader
 from sup3r.preprocessing.names import Dimension
-from sup3r.utilities.utilities import xr_open_mfdataset
 
 from .base import BaseCollector
 
@@ -32,14 +32,17 @@ def collect(
         overwrite=True,
         res_kwargs=None,
         cacher_kwargs=None,
+        is_regular_grid=False,
    ):
        """Collect data files from a dir to one output file.
 
-        TODO: This assumes that if there is any spatial chunking it is split
-        by latitude. This should be generalized to allow for any spatial
-        chunking and any dimension. This will either require a new file
-        naming scheme with a spatial index for both latitude and
-        longitude or checking each chunk to see how they are split.
+        TODO: For a regular grid (lat values are constant across lon and
+        vice versa), collecting both lat and lon chunks is supported. For
+        curvilinear grids, only collection of chunks split by latitude is
+        supported. This should be generalized to allow for any spatial
+        chunking along any dimension, which would likely require either a
+        new file naming scheme with a spatial index for both latitude and
+        longitude or checking each chunk to see how it is split.
 
        Filename requirements:
        - Should end with ".nc"
@@ -68,6 +71,9 @@ def collect(
            Dictionary of kwargs to pass to xarray.open_mfdataset.
        cacher_kwargs : dict | None
            Dictionary of kwargs to pass to Cacher._write_single.
+        is_regular_grid : bool
+            Whether the data is on a regular grid. If True, spatial chunks
+            can be combined across both latitude and longitude.
""" logger.info(f'Initializing collection for file_paths={file_paths}') @@ -94,11 +100,25 @@ def collect( 'combine': 'nested', 'concat_dim': Dimension.TIME, } - for s_idx in spatial_chunks: - spatial_chunks[s_idx] = xr_open_mfdataset( - spatial_chunks[s_idx], **res_kwargs + for s_idx, sfiles in spatial_chunks.items(): + schunk = Loader(sfiles, res_kwargs=res_kwargs) + spatial_chunks[s_idx] = schunk + + # Set lat / lon as 1D arrays if regular grid and get the + # xr.Dataset _ds + if is_regular_grid: + spatial_chunks = { + s_idx: schunk.set_regular_grid()._ds + for s_idx, schunk in spatial_chunks.items() + } + out = xr.combine_by_coords(spatial_chunks.values(), + combine_attrs='override') + + else: + out = xr.concat( + spatial_chunks.values(), dim=Dimension.SOUTH_NORTH ) - out = xr.concat(spatial_chunks.values(), dim=Dimension.SOUTH_NORTH) + cacher_kwargs = cacher_kwargs or {} Cacher._write_single( out_file=out_file, diff --git a/sup3r/utilities/pytest/helpers.py b/sup3r/utilities/pytest/helpers.py index cf5654eba6..3380df0d8a 100644 --- a/sup3r/utilities/pytest/helpers.py +++ b/sup3r/utilities/pytest/helpers.py @@ -8,7 +8,7 @@ import pandas as pd import xarray as xr -from sup3r.postprocessing import OutputHandlerH5 +from sup3r.postprocessing import OutputHandlerH5, OutputHandlerNC from sup3r.preprocessing.base import Container, Sup3rDataset from sup3r.preprocessing.batch_handlers import BatchHandlerCC, BatchHandlerDC from sup3r.preprocessing.names import Dimension @@ -268,13 +268,15 @@ def sample_batch(self): return BatchHandlerTester -def make_collect_chunks(td): - """Make fake h5 chunked output files for collection tests. +def make_collect_chunks(td, ext='h5'): + """Make fake chunked output files for collection tests. Parameters ---------- td : tempfile.TemporaryDirectory Test TemporaryDirectory + ext : str + File extension for output files. Either 'h5' or 'nc'. Default is 'h5'. Returns ------- @@ -320,13 +322,15 @@ def make_collect_chunks(td): s_slices_hr = np.array_split(np.arange(shape[0]), 4) s_slices_hr = [slice(s[0], s[-1] + 1) for s in s_slices_hr] - out_pattern = os.path.join(td, 'fp_out_{t}_{s}.h5') + out_pattern = os.path.join(td, 'fp_out_{t}_{s}.' 
     out_files = []
+
+    Writer = OutputHandlerNC if ext == 'nc' else OutputHandlerH5
     for t, slice_hr in enumerate(t_slices_hr):
         for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)):
             out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6))
             out_files.append(out_file)
-            OutputHandlerH5._write_output(
+            Writer._write_output(
                 data[s1_hr, s2_hr, slice_hr, :],
                 features,
                 hr_lat_lon[s1_hr, s2_hr],
diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py
index 636e966c02..43f47d5d0d 100644
--- a/tests/output/test_output_handling.py
+++ b/tests/output/test_output_handling.py
@@ -9,6 +9,7 @@
 from sup3r.postprocessing import (
     CollectorH5,
+    CollectorNC,
     OutputHandlerH5,
     OutputHandlerNC,
 )
@@ -128,14 +129,14 @@ def test_invert_uv_inplace():
     assert np.allclose(data[..., 1], wd)
 
 
-def test_general_collect():
+def test_general_h5_collect():
     """Make sure general file collection gives complete meta, time_index, and
     data array."""
 
     with tempfile.TemporaryDirectory() as td:
         fp_out = os.path.join(td, 'out_combined.h5')
 
-        out = make_collect_chunks(td)
+        out = make_collect_chunks(td, ext='h5')
         out_files, data, features, hr_lat_lon, hr_times = (
             out[0],
             out[1],
@@ -160,6 +161,31 @@ def test_general_collect():
     assert np.array_equal(base_data, collect_data)
 
 
+def test_general_nc_collect():
+    """Make sure general NC file collection gives complete meta, time_index,
+    and data array."""
+
+    with tempfile.TemporaryDirectory() as td:
+        fp_out = os.path.join(td, 'out_combined.nc')
+
+        out = make_collect_chunks(td, ext='nc')
+        out_files, base_data, features, hr_lat_lon, hr_times = (
+            out[0],
+            out[1],
+            out[-3],
+            out[-2],
+            out[-1],
+        )
+
+        CollectorNC.collect(out_files, fp_out, features=features,
+                            is_regular_grid=True)
+
+        with Loader(fp_out) as res:
+            assert np.array_equal(hr_times, res.time_index.values)
+            assert np.allclose(hr_lat_lon, res.lat_lon)
+            assert np.allclose(base_data, res.values)
+
+
 def test_h5_out_and_collect(collect_check):
     """Test h5 file output writing and collection with dummy data"""

From 815aa7baca02e12bcee241bd99e09145870be5a9 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 12 Jun 2025 08:25:55 -0600
Subject: [PATCH 2/2] is_regular_grid=True default

---
 sup3r/postprocessing/collectors/nc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py
index 6356d4af37..1907b0ace4 100644
--- a/sup3r/postprocessing/collectors/nc.py
+++ b/sup3r/postprocessing/collectors/nc.py
@@ -32,7 +32,7 @@ def collect(
         overwrite=True,
         res_kwargs=None,
         cacher_kwargs=None,
-        is_regular_grid=False,
+        is_regular_grid=True,
     ):
         """Collect data files from a dir to one output file.
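
Note on the regular-grid fast path in PATCH 1/2: once lat / lon are reduced
to independent 1D coordinates, xarray can infer each chunk's position from
its coordinate values alone, which is why chunks split across both latitude
and longitude can be collected without a spatial index in the file name.
A minimal sketch of that idea, using plain xarray with made-up variable and
coordinate names rather than sup3r's Loader / set_regular_grid machinery:

import numpy as np
import xarray as xr

def make_chunk(lats, lons):
    # Fake spatial chunk with independent 1D lat / lon coordinates
    data = np.random.rand(len(lats), len(lons))
    return xr.Dataset(
        {'windspeed': (('latitude', 'longitude'), data)},
        coords={'latitude': lats, 'longitude': lons},
    )

# Four chunks split across both latitude and longitude
chunks = [
    make_chunk(np.arange(0, 5), np.arange(0, 5)),
    make_chunk(np.arange(0, 5), np.arange(5, 10)),
    make_chunk(np.arange(5, 10), np.arange(0, 5)),
    make_chunk(np.arange(5, 10), np.arange(5, 10)),
]

# combine_by_coords stitches the chunks back together by coordinate value
combined = xr.combine_by_coords(chunks, combine_attrs='override')
assert combined['windspeed'].shape == (10, 10)

On a curvilinear grid, latitude and longitude are 2D arrays that carry no
such ordering information, which is why the else branch falls back to
xr.concat along Dimension.SOUTH_NORTH and only latitude-split chunks are
supported there.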
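
For reference, the end-to-end call pattern exercised by the new test, with
hypothetical chunk file names and a hypothetical feature list (after
PATCH 2/2 the is_regular_grid=True argument is the default and could be
omitted):

from sup3r.postprocessing import CollectorNC

CollectorNC.collect(
    ['fp_out_000000_000000.nc', 'fp_out_000000_000001.nc'],
    'out_combined.nc',
    features=['windspeed_100m'],
    is_regular_grid=True,
)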