From 70b639df33c59ff1101d9709a3b04cd61e57778b Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 21 Aug 2025 14:51:05 -0600
Subject: [PATCH 01/11] caught mistakes with exo 3d agg methods. refactored
 these and added more tests

---
 sup3r/models/with_obs.py               |  16 +-
 sup3r/postprocessing/writers/base.py   |   4 +
 sup3r/preprocessing/rasterizers/exo.py | 105 +++++++++---
 tests/rasterizers/test_exo.py          | 213 ++++++++++++++++++++-----
 4 files changed, 269 insertions(+), 69 deletions(-)

diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py
index 1921774851..91619b0eda 100644
--- a/sup3r/models/with_obs.py
+++ b/sup3r/models/with_obs.py
@@ -82,17 +82,21 @@ def __init__(
         self.loss_obs_fun = self.get_loss_fun(loss_obs)
         self.loss_obs_weight = loss_obs_weight
 
-    @tf.function
+    # @tf.function
     def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
         """Get loss for observation locations and for non observation
         locations."""
-        hr_true = hi_res_true[..., : len(self.hr_out_features)]
-        loss_obs, _ = self.loss_obs_fun(
-            hi_res_gen[~obs_mask], hr_true[~obs_mask]
+        hr_true = tf.stack(
+            [hi_res_true[..., idx] for idx in self.obs_training_inds],
+            axis=-1,
         )
-        loss_non_obs, _ = self.loss_obs_fun(
-            hi_res_gen[obs_mask], hr_true[obs_mask]
+        hr_gen = tf.stack(
+            [hi_res_gen[..., idx] for idx in self.obs_training_inds],
+            axis=-1,
         )
+        mask = obs_mask[..., : hr_gen.shape[-1]]
+        loss_obs, _ = self.loss_obs_fun(hr_gen[~mask], hr_true[~mask])
+        loss_non_obs, _ = self.loss_obs_fun(hr_gen[mask], hr_true[mask])
         return loss_obs, loss_non_obs
 
     @property
diff --git a/sup3r/postprocessing/writers/base.py b/sup3r/postprocessing/writers/base.py
index 48aae1798e..d1f6384ea8 100644
--- a/sup3r/postprocessing/writers/base.py
+++ b/sup3r/postprocessing/writers/base.py
@@ -454,6 +454,10 @@ def get_lat_lon(cls, low_res_lat_lon, shape):
             (spatial_1, spatial_2, 2)
             Last dimension has ordering (lat, lon)
         """
+        assert (
+            low_res_lat_lon.shape[0] > 1 and low_res_lat_lon.shape[1] > 1
+        ), 'low res lat/lon must have at least 2 rows and 2 columns'
+
         logger.debug('Getting high resolution lat / lon grid')
 
         # ensure lons are between -180 and 180
diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py
index aeb824dc7e..5cb8f440c7 100644
--- a/sup3r/preprocessing/rasterizers/exo.py
+++ b/sup3r/preprocessing/rasterizers/exo.py
@@ -101,6 +101,12 @@ class will output a topography raster corresponding to the file_paths
     fill_nans : bool
         Whether to fill nans in the output data. This should probably be
         True for all cases except for sparse observation data.
+    agg_method : str
+        Method to use for aggregating source data to the target pixels. This
+        can be 'mean', 'idw' (inverse distance weighted average), or 'nn'
+        (nearest neighbor). The default is 'mean'. This is only used for 3D
+        data (time dependent) and will be ignored for 2D data (time
+        independent).
     scale_factor : float
         Scale factor to apply to the raw data from the source_files. This is
        useful for scaling observation data which might systematically under or
@@ -125,6 +131,7 @@ class will output a topography raster corresponding to the file_paths
     chunks: Optional[Union[str, dict]] = 'auto'
     distance_upper_bound: Optional[int] = None
     fill_nans: bool = True
+    agg_method: str = 'mean'
     scale_factor: float = 1.0
     max_workers: int = 1
     verbose: bool = False
@@ -141,6 +148,8 @@ def __post_init__(self):
         self._hr_lat_lon = None
         self._source_lat_lon = None
         self._hr_time_index = None
+        self._nn = None
+        self._dists = None
         self.input_handler_kwargs = self.input_handler_kwargs or {}
         self.source_handler_kwargs = self.source_handler_kwargs or {}
         InputHandler = get_input_handler_class(self.input_handler_name)
@@ -278,6 +287,11 @@ def get_distance_upper_bound(self):
         """Maximum distance (float) to map high-resolution data from
         source_files to the low-resolution file_paths input."""
         if self.distance_upper_bound is None:
+            assert self.hr_lat_lon.shape[0] > 1, (
+                'hr_lat_lon must have at least 2 lat points to calculate '
+                'distance upper bound. Either expand the grid or provide a '
+                'distance_upper_bound explicitly.'
+            )
             diff = da.diff(self.hr_lat_lon, axis=0)
             diff = da.abs(da.median(diff, axis=0)).max()
             self.distance_upper_bound = np.asarray(diff)
@@ -296,15 +310,29 @@ def tree(self):
             self._tree = KDTree(self.hr_lat_lon.reshape((-1, 2)))
         return self._tree
 
+    def query_tree(self, lat_lon):
+        """Query the KDTree for the nearest neighbor indices and distances
+        for the given lat_lon points."""
+        return self.tree.query(
+            lat_lon,
+            distance_upper_bound=self.get_distance_upper_bound(),
+        )
+
     @property
     def nn(self):
         """Get the nearest neighbor indices. This uses a single neighbor by
         default"""
-        _, nn = self.tree.query(
-            self.source_lat_lon,
-            distance_upper_bound=self.get_distance_upper_bound(),
-        )
-        return nn
+        if self._nn is None:
+            self._dists, self._nn = self.query_tree(self.source_lat_lon)
+        return self._nn
+
+    @property
+    def dists(self):
+        """Get the nearest neighbor distances. This uses a single neighbor
+        by default"""
+        if self._dists is None:
+            self._dists, self._nn = self.query_tree(self.source_lat_lon)
+        return self._dists
 
     @property
     def data(self):
@@ -371,6 +399,25 @@ def get_data(self):
             data_vars = {self.feature: data_vars}
         return Sup3rX(xr.Dataset(coords=self.coords, data_vars=data_vars))
 
+    def _idw_fill(self, x):
+        """Compute weighted average for a group of data."""
+        valid_mask = ~np.isnan(x[self.feature].values)
+        if valid_mask.sum() == 0:
+            return np.nan
+        weights = 1 / np.maximum(x['distance'], 1e-6)
+        return np.average(
+            x[self.feature][valid_mask],
+            weights=weights[valid_mask] / weights[valid_mask].sum(),
+        )
+
+    def _mean_fill(self, x):
+        """Compute standard average for a group of data."""
+        return x[self.feature].mean()
+
+    def _nn_fill(self, x):
+        """Select value with min distance."""
+        return x[self.feature].iloc[np.argmin(x['distance'])]
+
     def _get_data_2d(self):
         """Get a raster of source values corresponding to the high-resolution
         grid (the file_paths input grid * s_enhance). This is used for time
@@ -413,27 +460,39 @@ def _get_data_3d(self):
         out : np.ndarray
             3D array of source data with shape (lats, lons, time).
         """
-        assert (
-            len(self.source_data.shape) == 2 and self.source_data.shape[1] > 1
-        )
+        if self.agg_method == 'idw':
+            func = self._idw_fill
+        elif self.agg_method == 'nn':
+            func = self._nn_fill
+        elif self.agg_method == 'mean':
+            func = self._mean_fill
+        else:
+            raise ValueError(
+                f'Unknown aggregation method: {self.agg_method}. '
+                'Must be one of "idw", "nn", or "mean".'
+            )
+        logger.info('Using {} aggregation method'.format(self.agg_method))
         target_tmask = self.hr_time_index.isin(self.source_handler.time_index)
         source_tmask = self.source_handler.time_index.isin(self.hr_time_index)
         data = self.source_data[:, source_tmask]
         rows = pd.MultiIndex.from_product(
             [self.nn, range(data.shape[-1])], names=['gid_target', 'time']
         )
-        df = pd.DataFrame({self.feature: data.flatten()}, index=rows)
+        dists = np.repeat(self.dists[:, None], data.shape[-1], axis=1)
         n_target = np.prod(self.hr_shape[:-1])
+        df = pd.DataFrame(
+            {self.feature: data.flatten(), 'distance': dists.flatten()},
+            index=rows,
+        )
+        df = df[df.index.get_level_values(0) != n_target].sort_values(
+            'gid_target'
+        )
+        df = df.groupby(['gid_target', 'time']).apply(func)
         out = np.full(
             (n_target, len(self.hr_time_index)), np.nan, dtype=np.float32
         )
-        gids = df.index.get_level_values(0)
-        df = df[gids != n_target]
-        df = df.sort_values('gid_target')
-        df = df.groupby(['gid_target', 'time']).mean()
-        inds = gids.unique().values[gids.unique() != n_target][:, None]
-        vals = df[self.feature].values.reshape((-1, data.shape[-1]))
-        out[inds, target_tmask] = vals
+        inds = np.array(df.index.get_level_values(0).unique())[:, None]
+        out[inds, target_tmask] = df.values.reshape((-1, data.shape[-1]))
         out = out.reshape((*self.hr_shape[:-1], -1))
         return out
 
@@ -487,12 +546,6 @@ def _get_data_3d(self):
         """
         hr_data = super()._get_data_3d()
         gid_mask = self.nn != np.prod(self.hr_shape[:-1])
-        logger.info(
-            'Found {} of {} observations within {:4f} degrees of high-'
-            'resolution grid points.'.format(
-                gid_mask.sum(), len(gid_mask), self.distance_upper_bound
-            )
-        )
         cover_frac = (~np.isnan(hr_data)).sum() / hr_data.size
         if cover_frac == 0:
             msg = (
@@ -502,10 +555,14 @@ def _get_data_3d(self):
             warn(msg)
             logger.warning(msg)
         else:
-            msg = 'Observations cover {:.4e} of the high-res domain.'.format(
-                compute_if_dask(cover_frac)
+            msg = (
+                f'Found {gid_mask.sum()} of {len(gid_mask)} observations '
+                f'within {self.distance_upper_bound:4f} degrees of '
+                'high-resolution grid points. Observations cover '
+                f'{compute_if_dask(cover_frac):.4e} of the high-res domain.'
             )
             logger.info(msg)
+        self.fill_nans = False  # override parent attribute to not fill nans
         return hr_data
 
diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py
index cce39847a2..cdd56eb96c 100644
--- a/tests/rasterizers/test_exo.py
+++ b/tests/rasterizers/test_exo.py
@@ -7,8 +7,8 @@
 import numpy as np
 import pandas as pd
 import pytest
-import xarray as xr
 from rex import Resource
+from scipy.spatial import KDTree
 
 from sup3r.postprocessing import RexOutputs
 from sup3r.preprocessing import (
@@ -18,7 +18,7 @@
     ExoDataHandler,
     ExoRasterizer,
 )
-from sup3r.utilities.pytest.helpers import make_fake_dset, make_fake_nc_file
+from sup3r.utilities.pytest.helpers import make_fake_dset
 from sup3r.utilities.utilities import RANDOM_GENERATOR, xr_open_mfdataset
 
 TARGET = (13.67, 125.0)
@@ -329,59 +329,186 @@ def test_topo_extraction_nc():
     assert np.allclose(te.source_data.flatten(), hr_elev.flatten())
 
 
-def test_obs_agg():
-    """Test the aggregation of 3D data"""
+@pytest.mark.parametrize(
+    ['s_enhance', 'with_nans'],
+    [
+        (1, False),
+        (2, False),
+        (4, False),
+        (1, True),
+        (2, True),
+        (4, True),
+    ],
+)
+def test_obs_nn(s_enhance, with_nans):
+    """Test the aggregation of 3D data with nearest neighbor"""
     with TemporaryDirectory() as td:
-        make_fake_nc_file(
-            f'{td}/hr.nc',
-            (20, 20, 6),
+        dset = make_fake_dset(
+            (8, 8, 1),
             ['u_10m'],
-            lat_range=(0, 5),
-            lon_range=(0, 5),
+            lat_range=(20, 24),
+            lon_range=(-120, -116),
         )
-        res = xr.open_mfdataset(f'{td}/hr.nc')
-        res = res.sx.coarsen(
-            {'south_north': 4, 'west_east': 4}
-        ).mean()
+
+        if with_nans:
+            mask = RANDOM_GENERATOR.choice(
+                [True, False], dset['u_10m'].shape, p=[0.6, 0.4]
+            )
+            u_10m = dset['u_10m'].values
+            u_10m[mask] = np.nan
+            dset['u_10m'] = (dset['u_10m'].dims, u_10m)
+
+        res = dset.sx.coarsen({'south_north': 4, 'west_east': 4}).mean()
         res.to_netcdf(f'{td}/lr.nc')
 
+        lats = dset.latitude.values + RANDOM_GENERATOR.uniform(
+            0, 1, dset.latitude.shape
+        )
+        dset['latitude'] = (('south_north', 'west_east'), lats)
+        lons = dset.longitude.values + RANDOM_GENERATOR.uniform(
+            0, 1, dset.longitude.shape
+        )
+        dset['longitude'] = (('south_north', 'west_east'), lons)
+        dset.to_netcdf(f'{td}/hr.nc')
+
         te = ExoRasterizer(
             file_paths=f'{td}/lr.nc',
             source_files=f'{td}/hr.nc',
             feature='u_10m_obs',
-            s_enhance=2,
+            s_enhance=s_enhance,
             t_enhance=1,
+            agg_method='nn',
         )
-        agg_obs = np.asarray(te._get_data_3d())
-        true_obs = te.source_handler['u_10m'].coarsen(
-            {'south_north': 2, 'west_east': 2}
-        ).mean().values
-        assert np.allclose(agg_obs, true_obs, equal_nan=True, rtol=0.01)
-
-
-@pytest.mark.parametrize('s_enhance', [1, 2, 5, 10])
-def test_obs_agg_with_nans(s_enhance):
-    """Test the aggregation of 3D data with some NaNs"""
+        tsteps = len(te.source_handler.time_index)
+        agg_obs = np.asarray(te._get_data_3d()).reshape((-1, tsteps))
+        true_obs = np.full(agg_obs.shape, np.nan, dtype=np.float32)
+        vals = te.source_handler['u_10m'].values.reshape((-1, tsteps))
+        tree = KDTree(te.hr_lat_lon.reshape(-1, 2))
+        dists, nn = tree.query(
+            te.source_lat_lon.reshape(-1, 2),
+            k=1,
+            distance_upper_bound=te.distance_upper_bound,
+        )
+        for i in range(true_obs.shape[0]):
+            d_i = dists[nn == i]
+            if d_i.size == 0:
+                continue
+            idx = np.argmin(np.abs(d_i.min() - dists))
+            true_obs[i, :] = vals[idx, :]
+
+        mask = np.isnan(agg_obs) | np.isnan(true_obs)
+        assert np.allclose(agg_obs[~mask], true_obs[~mask], rtol=0.01)
+
+
+@pytest.mark.parametrize(
+    ['s_enhance', 'with_nans'],
+    [
+        (1, False),
+        (2, False),
+        (4, False),
+        (1, True),
+        (2, True),
True), + (4, True), + ], +) +def test_obs_idw(s_enhance, with_nans): + """Test the aggregation of 3D data with inverse distance weighting""" with TemporaryDirectory() as td: dset = make_fake_dset( - (40, 40, 6), + (8, 8, 1), ['u_10m'], - lat_range=(0, 5), - lon_range=(0, 5), + lat_range=(20, 24), + lon_range=(-120, -116), ) - mask = RANDOM_GENERATOR.choice( - [True, False], dset['u_10m'].shape, p=[0.6, 0.4] + if with_nans: + mask = RANDOM_GENERATOR.choice( + [True, False], dset['u_10m'].shape, p=[0.6, 0.4] + ) + u_10m = dset['u_10m'].values + u_10m[mask] = np.nan + dset['u_10m'] = (dset['u_10m'].dims, u_10m) + + res = dset.sx.coarsen({'south_north': 4, 'west_east': 4}).mean() + res.to_netcdf(f'{td}/lr.nc') + + lats = dset.latitude.values + RANDOM_GENERATOR.uniform( + 0, 1, dset.latitude.shape + ) + dset['latitude'] = (('south_north', 'west_east'), lats) + lons = dset.longitude.values + RANDOM_GENERATOR.uniform( + 0, 1, dset.longitude.shape ) - u_10m = dset['u_10m'].values - u_10m[mask] = np.nan - dset['u_10m'] = (dset['u_10m'].dims, u_10m) + dset['longitude'] = (('south_north', 'west_east'), lons) dset.to_netcdf(f'{td}/hr.nc') - res = dset.sx.coarsen( - {'south_north': 10, 'west_east': 10} - ).mean() + te = ExoRasterizer( + file_paths=f'{td}/lr.nc', + source_files=f'{td}/hr.nc', + feature='u_10m_obs', + s_enhance=s_enhance, + t_enhance=1, + agg_method='idw', + ) + + tsteps = len(te.source_handler.time_index) + agg_obs = np.asarray(te._get_data_3d()).reshape((-1, tsteps)) + true_obs = np.full(agg_obs.shape, np.nan, dtype=np.float32) + vals = te.source_handler['u_10m'].values.reshape((-1, tsteps)) + tree = KDTree(te.hr_lat_lon.reshape(-1, 2)) + dists, nn = tree.query( + te.source_lat_lon.reshape(-1, 2), + k=1, + distance_upper_bound=te.distance_upper_bound, + ) + for i in range(true_obs.shape[0]): + d_i = dists[nn == i] + if d_i.size == 0: + continue + v_i = [] + for d in d_i: + idx = np.argmin(np.abs(d - dists)) + v_i.append(vals[idx, :]) + w_i = 1 / np.maximum(d_i, 1e-6) + w_i /= np.sum(w_i) + true_obs[i, :] = np.average(v_i, axis=0, weights=w_i) + + mask = np.isnan(agg_obs) | np.isnan(true_obs) + assert np.allclose(agg_obs[~mask], true_obs[~mask], rtol=0.01) + + +@pytest.mark.parametrize( + ['s_enhance', 'with_nans'], + [ + (1, False), + (2, False), + (4, False), + (1, True), + (2, True), + (4, True), + ], +) +def test_obs_agg(s_enhance, with_nans): + """Test the aggregation of 3D data with local mean""" + with TemporaryDirectory() as td: + dset = make_fake_dset( + (8, 8, 1), + ['u_10m'], + lat_range=(20, 24), + lon_range=(-120, -116), + ) + + if with_nans: + mask = RANDOM_GENERATOR.choice( + [True, False], dset['u_10m'].shape, p=[0.6, 0.4] + ) + u_10m = dset['u_10m'].values + u_10m[mask] = np.nan + dset['u_10m'] = (dset['u_10m'].dims, u_10m) + + res = dset.sx.coarsen({'south_north': 4, 'west_east': 4}).mean() res.to_netcdf(f'{td}/lr.nc') + dset.to_netcdf(f'{td}/hr.nc') te = ExoRasterizer( file_paths=f'{td}/lr.nc', @@ -389,9 +516,17 @@ def test_obs_agg_with_nans(s_enhance): feature='u_10m_obs', s_enhance=s_enhance, t_enhance=1, + agg_method='mean', ) agg_obs = np.asarray(te._get_data_3d()) - true_obs = te.source_handler['u_10m'].coarsen( - {'south_north': 10 // s_enhance, 'west_east': 10 // s_enhance} - ).mean().values - assert np.allclose(agg_obs, true_obs, equal_nan=True) + true_obs = ( + te.source_handler['u_10m'] + .coarsen({ + 'south_north': 4 // s_enhance, + 'west_east': 4 // s_enhance, + }) + .mean() + .values + ) + mask = np.isnan(agg_obs) | np.isnan(true_obs) + assert 

From bad3b126db4c57090e38618d421dff5eef436382 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Wed, 27 Aug 2025 16:20:20 -0600
Subject: [PATCH 02/11] fix: optimize loss calculation and update test index
 selection methods

---
 sup3r/models/with_obs.py      | 16 ++++++----------
 tests/rasterizers/test_exo.py |  4 ++--
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py
index 91619b0eda..1921774851 100644
--- a/sup3r/models/with_obs.py
+++ b/sup3r/models/with_obs.py
@@ -82,21 +82,17 @@ def __init__(
         self.loss_obs_fun = self.get_loss_fun(loss_obs)
         self.loss_obs_weight = loss_obs_weight
 
-    # @tf.function
+    @tf.function
    def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
         """Get loss for observation locations and for non observation
         locations."""
-        hr_true = tf.stack(
-            [hi_res_true[..., idx] for idx in self.obs_training_inds],
-            axis=-1,
+        hr_true = hi_res_true[..., : len(self.hr_out_features)]
+        loss_obs, _ = self.loss_obs_fun(
+            hi_res_gen[~obs_mask], hr_true[~obs_mask]
         )
-        hr_gen = tf.stack(
-            [hi_res_gen[..., idx] for idx in self.obs_training_inds],
-            axis=-1,
+        loss_non_obs, _ = self.loss_obs_fun(
+            hi_res_gen[obs_mask], hr_true[obs_mask]
         )
-        mask = obs_mask[..., : hr_gen.shape[-1]]
-        loss_obs, _ = self.loss_obs_fun(hr_gen[~mask], hr_true[~mask])
-        loss_non_obs, _ = self.loss_obs_fun(hr_gen[mask], hr_true[mask])
         return loss_obs, loss_non_obs
 
     @property
diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py
index cdd56eb96c..67e5db9db6 100644
--- a/tests/rasterizers/test_exo.py
+++ b/tests/rasterizers/test_exo.py
@@ -393,7 +393,7 @@ def test_obs_nn(s_enhance, with_nans):
             d_i = dists[nn == i]
             if d_i.size == 0:
                 continue
-            idx = np.argmin(np.abs(d_i.min() - dists))
+            idx = np.abs(d_i.min() - dists).idxmin()
             true_obs[i, :] = vals[idx, :]
 
         mask = np.isnan(agg_obs) | np.isnan(true_obs)
@@ -467,7 +467,7 @@ def test_obs_idw(s_enhance, with_nans):
                 continue
             v_i = []
             for d in d_i:
-                idx = np.argmin(np.abs(d - dists))
+                idx = np.abs(d - dists).idxmin()
                 v_i.append(vals[idx, :])
             w_i = 1 / np.maximum(d_i, 1e-6)
             w_i /= np.sum(w_i)

From 57205371467e610dce05bead752f4b0fd30c87d6 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 4 Sep 2025 14:01:33 -0600
Subject: [PATCH 03/11] refactor: move batch timing logging to the correct
 position in Sup3rGan

---
 sup3r/models/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 9ea5b03b88..55e58d56e6 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -1172,8 +1172,6 @@ def _train_epoch(
                 weight_gen_advers,
                 multi_gpu,
             )
-            elapsed = time.time() - start
-            logger.info('Finished batch in {:.4f} seconds'.format(elapsed))
 
             loss_means = self._post_batch(
                 ib,
@@ -1181,6 +1179,8 @@ def _train_epoch(
                 len(batch_handler),
                 loss_means,
             )
+            elapsed = time.time() - start
+            logger.info('Finished batch in {:.4f} seconds'.format(elapsed))
 
         self.total_batches += len(batch_handler)
         loss_details = self._train_record.mean().to_dict()

From 999aa4704a14794d651e0b3c8fe54f02ef8a77bc Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 4 Sep 2025 18:17:43 -0600
Subject: [PATCH 04/11] refactor: adjust timing logging for batch processing
 in Sup3rGan

---
 sup3r/models/base.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 55e58d56e6..cc759fa6dd 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -1154,13 +1154,14 @@ def _train_epoch(
             tf.summary.trace_on(graph=True, profiler=True)
 
         for ib, batch in enumerate(batch_handler):
+            start = time.time()
+            b_loss_details = {}
             loss_disc = loss_means['train_loss_disc']
             disc_too_good = loss_disc <= disc_th_low
             disc_too_bad = (loss_disc > disc_th_high) and train_disc
             gen_too_good = disc_too_bad
 
-            start = time.time()
             b_loss_details = self._train_batch(
                 batch,
                 train_gen,

From 1e4a0a9abe21315be19873a00157b6f8fc161683 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Fri, 5 Sep 2025 07:04:52 -0600
Subject: [PATCH 05/11] refactor: integrate Timer for batch processing in
 Sup3rGan and adjust logging

---
 sup3r/models/abstract.py |  7 ++++---
 sup3r/models/base.py     | 18 ++++++++++--------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index ecad25b618..d3e93a8431 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -22,7 +22,7 @@
 from sup3r.preprocessing.data_handlers import ExoData
 from sup3r.preprocessing.utilities import numpy_if_tensor
 from sup3r.utilities import VERSION_RECORD
-from sup3r.utilities.utilities import camel_to_underscore, safe_cast
+from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast
 
 from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn
 
@@ -40,11 +40,12 @@ def __init__(self):
         super().__init__()
         self.gpu_list = tf.config.list_physical_devices('GPU')
         self.default_device = '/cpu:0' if len(self.gpu_list) == 0 else '/gpu:0'
-        self._version_record = VERSION_RECORD
+        self.timer = Timer()
         self.name = None
-        self._meta = None
         self.loss_name = None
         self.loss_fun = None
+        self._version_record = VERSION_RECORD
+        self._meta = None
         self._history = None
         self._optimizer = None
         self._gen = None
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index cc759fa6dd..6701d974aa 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -1162,7 +1162,7 @@ def _train_epoch(
             disc_too_bad = (loss_disc > disc_th_high) and train_disc
             gen_too_good = disc_too_bad
 
-            b_loss_details = self._train_batch(
+            b_loss_details = self.timer(self._train_batch, log=True)(
                 batch,
                 train_gen,
                 only_gen,
@@ -1174,14 +1174,16 @@ def _train_epoch(
                 multi_gpu,
             )
 
-            loss_means = self._post_batch(
-                ib,
-                b_loss_details,
-                len(batch_handler),
-                loss_means,
+            """
+            loss_means = self.timer(self._post_batch, log=True)(
+                ib, b_loss_details, len(batch_handler), loss_means
+            )
+            """
+
+            logger.info(
+                f'Finished step {ib + 1} / {len(batch_handler)} in '
+                f'{time.time() - start:.4f} seconds'
             )
-            elapsed = time.time() - start
-            logger.info('Finished batch in {:.4f} seconds'.format(elapsed))
 
         self.total_batches += len(batch_handler)
         loss_details = self._train_record.mean().to_dict()

From f274e0964245bd3370983c430cc43f20fa6e1344 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Fri, 5 Sep 2025 08:42:24 -0600
Subject: [PATCH 06/11] refactor: update tensorboard logging conditions and
 clean up batch processing code in Sup3rGan

---
 sup3r/models/base.py                         | 11 +++++------
 sup3r/models/utilities.py                    | 13 ++++++-------
 sup3r/preprocessing/batch_queues/abstract.py |  1 -
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 6701d974aa..6ce24edc8e 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -638,7 +638,7 @@ def train(
         adaptive_update_bounds=(0.9, 0.99),
         adaptive_update_fraction=0.0,
         multi_gpu=False,
-        tensorboard_log=True,
+        tensorboard_log=False,
         tensorboard_profile=False,
     ):
         """Train the GAN model on real low res data and real high res data
@@ -1062,8 +1062,9 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
             prefix='train_',
         )
 
-        self.dict_to_tensorboard(b_loss_details)
-        self.dict_to_tensorboard(self.timer.log)
+        if self._tb_writer is not None:
+            self.dict_to_tensorboard(b_loss_details)
+            self.dict_to_tensorboard(self.timer.log)
 
         trained_gen = bool(self._train_record['gen_train_frac'].values[-1])
         trained_disc = bool(self._train_record['disc_train_frac'].values[-1])
@@ -1174,14 +1175,12 @@ def _train_epoch(
                 multi_gpu,
             )
 
-            """
             loss_means = self.timer(self._post_batch, log=True)(
                 ib, b_loss_details, len(batch_handler), loss_means
             )
-            """
 
             logger.info(
-                f'Finished step {ib + 1} / {len(batch_handler)} in '
+                f'Finished batch step {ib + 1} / {len(batch_handler)} in '
                 f'{time.time() - start:.4f} seconds'
             )
 
diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py
index 96cb86b9ac..07d169c397 100644
--- a/sup3r/models/utilities.py
+++ b/sup3r/models/utilities.py
@@ -109,13 +109,12 @@ def dict_to_tensorboard(self, entry):
         entry: dict
             Dictionary of values to write to tensorboard log file
         """
-        if self._tb_writer is not None:
-            with self._tb_writer.as_default():
-                for name, value in entry.items():
-                    if isinstance(value, str):
-                        tf.summary.text(name, value, self.total_batches)
-                    else:
-                        tf.summary.scalar(name, value, self.total_batches)
+        with self._tb_writer.as_default():
+            for name, value in entry.items():
+                if isinstance(value, str):
+                    tf.summary.text(name, value, self.total_batches)
+                else:
+                    tf.summary.scalar(name, value, self.total_batches)
 
     def profile_to_tensorboard(self, name):
         """Write profile data to tensorboard log file.
diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py
index 3ced2d7eac..3b213c7049 100644
--- a/sup3r/preprocessing/batch_queues/abstract.py
+++ b/sup3r/preprocessing/batch_queues/abstract.py
@@ -275,7 +275,6 @@ def enqueue_batches(self) -> None:
         log_time = time.time()
         while self.running:
             needed = max(self.queue_cap - self.queue_len, 0)
-            needed = min(self.max_workers, needed)
             if needed > 0:
                 batches = self.sample_batches(n_batches=needed)
                 if needed > 1 and self.max_workers > 1:

From d0e0dce503cc562b30d75389f4ee570eb9cf25e4 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Fri, 5 Sep 2025 10:13:28 -0600
Subject: [PATCH 07/11] refactor: streamline batch retrieval and processing
 logic in AbstractBatchQueue

---
 sup3r/preprocessing/batch_queues/abstract.py | 27 ++++++++------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py
index 3b213c7049..1a0dbd54fb 100644
--- a/sup3r/preprocessing/batch_queues/abstract.py
+++ b/sup3r/preprocessing/batch_queues/abstract.py
@@ -240,8 +240,15 @@ def get_batch(self) -> DsetTuple:
         """Get batch from queue or directly from a ``Sampler`` through
         ``sample_batch``."""
         if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
-            return self.sample_batch()
-        return self.queue.dequeue()
+            samples = self.sample_batch()
+        else:
+            samples = self.queue.dequeue()
+        if self.sample_shape[2] == 1:
+            if isinstance(samples, (list, tuple)):
+                samples = tuple(s[..., 0, :] for s in samples)
+            else:
+                samples = samples[..., 0, :]
+        return self.post_proc(samples)
 
     @property
     def running(self):
@@ -299,22 +306,8 @@ def __next__(self) -> DsetTuple:
             Batch object with batch.low_res and batch.high_res attributes
         """
         if self._batch_count < self.n_batches:
-            self.timer.start()
-            samples = self.get_batch()
-            if self.sample_shape[2] == 1:
-                if isinstance(samples, (list, tuple)):
-                    samples = tuple(s[..., 0, :] for s in samples)
-                else:
-                    samples = samples[..., 0, :]
-            batch = self.post_proc(samples)
-            self.timer.stop()
+            batch = self.timer(self.get_batch, log=True)()
             self._batch_count += 1
-            if self.verbose:
-                logger.debug(
-                    'Batch step %s finished in %s.',
-                    self._batch_count,
-                    self.timer.elapsed_str,
-                )
         else:
             raise StopIteration
         return batch

From 38b92cb8f7e3209a91700b1e308d3a41b0250031 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Wed, 10 Sep 2025 05:52:18 -0600
Subject: [PATCH 08/11] refactor: enhance logging verbosity for batch
 retrieval in AbstractBatchQueue

---
 sup3r/preprocessing/batch_queues/abstract.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py
index 1a0dbd54fb..92ffe27fa1 100644
--- a/sup3r/preprocessing/batch_queues/abstract.py
+++ b/sup3r/preprocessing/batch_queues/abstract.py
@@ -306,7 +306,7 @@ def __next__(self) -> DsetTuple:
             Batch object with batch.low_res and batch.high_res attributes
         """
         if self._batch_count < self.n_batches:
-            batch = self.timer(self.get_batch, log=True)()
+            batch = self.timer(self.get_batch, log=self.verbose)()
             self._batch_count += 1
         else:
             raise StopIteration

From 88738f05e4f9a067997c2b45b36c37e58c7d7506 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 11 Sep 2025 16:26:07 -0600
Subject: [PATCH 09/11] refactor: replace idxmin with argmin for improved
 performance in test_obs_nn and test_obs_idw

---
 tests/rasterizers/test_exo.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py
index 67e5db9db6..cdd56eb96c 100644
--- a/tests/rasterizers/test_exo.py
+++ b/tests/rasterizers/test_exo.py
@@ -393,7 +393,7 @@ def test_obs_nn(s_enhance, with_nans):
             d_i = dists[nn == i]
             if d_i.size == 0:
                 continue
-            idx = np.abs(d_i.min() - dists).idxmin()
+            idx = np.argmin(np.abs(d_i.min() - dists))
             true_obs[i, :] = vals[idx, :]
 
         mask = np.isnan(agg_obs) | np.isnan(true_obs)
@@ -467,7 +467,7 @@ def test_obs_idw(s_enhance, with_nans):
                 continue
             v_i = []
             for d in d_i:
-                idx = np.abs(d - dists).idxmin()
+                idx = np.argmin(np.abs(d - dists))
                 v_i.append(vals[idx, :])
             w_i = 1 / np.maximum(d_i, 1e-6)
             w_i /= np.sum(w_i)

From 4899cd86dd38ea11ffbee3fe0e49a94734c5e5cb Mon Sep 17 00:00:00 2001
From: bnb32
Date: Mon, 13 Oct 2025 10:02:07 -0600
Subject: [PATCH 10/11] removed idw and nn agg methods for exo data. haven't
 seen much difference and these don't need to be maintained.

---
 sup3r/preprocessing/rasterizers/exo.py |  42 +------
 tests/rasterizers/test_exo.py          | 149 -------------------------
 2 files changed, 4 insertions(+), 187 deletions(-)

diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py
index 5cb8f440c7..dd4b8329e8 100644
--- a/sup3r/preprocessing/rasterizers/exo.py
+++ b/sup3r/preprocessing/rasterizers/exo.py
@@ -101,12 +101,6 @@ class will output a topography raster corresponding to the file_paths
     fill_nans : bool
         Whether to fill nans in the output data. This should probably be
         True for all cases except for sparse observation data.
-    agg_method : str
-        Method to use for aggregating source data to the target pixels. This
-        can be 'mean', 'idw' (inverse distance weighted average), or 'nn'
-        (nearest neighbor). The default is 'mean'. This is only used for 3D
-        data (time dependent) and will be ignored for 2D data (time
-        independent).
     scale_factor : float
         Scale factor to apply to the raw data from the source_files. This is
         useful for scaling observation data which might systematically under or
@@ -131,7 +125,6 @@ class will output a topography raster corresponding to the file_paths
     chunks: Optional[Union[str, dict]] = 'auto'
     distance_upper_bound: Optional[int] = None
     fill_nans: bool = True
-    agg_method: str = 'mean'
     scale_factor: float = 1.0
     max_workers: int = 1
     verbose: bool = False
@@ -160,9 +153,9 @@ def __post_init__(self):
     @property
     def source_handler(self):
         """Get the Loader object that handles the exogenous data file."""
-        assert (
-            self.source_files is not None
-        ), 'source_files must be provided to BaseExoRasterizer'
+        assert self.source_files is not None, (
+            'source_files must be provided to BaseExoRasterizer'
+        )
         if self._source_handler is None:
             self._source_handler = Loader(
                 self.source_files,
@@ -399,25 +392,10 @@ def get_data(self):
             data_vars = {self.feature: data_vars}
         return Sup3rX(xr.Dataset(coords=self.coords, data_vars=data_vars))
 
-    def _idw_fill(self, x):
-        """Compute weighted average for a group of data."""
-        valid_mask = ~np.isnan(x[self.feature].values)
-        if valid_mask.sum() == 0:
-            return np.nan
-        weights = 1 / np.maximum(x['distance'], 1e-6)
-        return np.average(
-            x[self.feature][valid_mask],
-            weights=weights[valid_mask] / weights[valid_mask].sum(),
-        )
-
     def _mean_fill(self, x):
         """Compute standard average for a group of data."""
         return x[self.feature].mean()
 
-    def _nn_fill(self, x):
-        """Select value with min distance."""
-        return x[self.feature].iloc[np.argmin(x['distance'])]
-
     def _get_data_2d(self):
         """Get a raster of source values corresponding to the high-resolution
         grid (the file_paths input grid * s_enhance). This is used for time
@@ -460,18 +438,6 @@ def _get_data_3d(self):
         out : np.ndarray
             3D array of source data with shape (lats, lons, time).
         """
-        if self.agg_method == 'idw':
-            func = self._idw_fill
-        elif self.agg_method == 'nn':
-            func = self._nn_fill
-        elif self.agg_method == 'mean':
-            func = self._mean_fill
-        else:
-            raise ValueError(
-                f'Unknown aggregation method: {self.agg_method}. '
-                'Must be one of "idw", "nn", or "mean".'
-            )
-        logger.info('Using {} aggregation method'.format(self.agg_method))
         target_tmask = self.hr_time_index.isin(self.source_handler.time_index)
         source_tmask = self.source_handler.time_index.isin(self.hr_time_index)
         data = self.source_data[:, source_tmask]
@@ -487,7 +453,7 @@ def _get_data_3d(self):
         df = df[df.index.get_level_values(0) != n_target].sort_values(
             'gid_target'
         )
-        df = df.groupby(['gid_target', 'time']).apply(func)
+        df = df.groupby(['gid_target', 'time']).apply(self._mean_fill)
         out = np.full(
             (n_target, len(self.hr_time_index)), np.nan, dtype=np.float32
         )
diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py
index cdd56eb96c..bba9642634 100644
--- a/tests/rasterizers/test_exo.py
+++ b/tests/rasterizers/test_exo.py
@@ -8,7 +8,6 @@
 import pandas as pd
 import pytest
 from rex import Resource
-from scipy.spatial import KDTree
 
 from sup3r.postprocessing import RexOutputs
 from sup3r.preprocessing import (
@@ -329,154 +328,6 @@ def test_topo_extraction_nc():
     assert np.allclose(te.source_data.flatten(), hr_elev.flatten())
 
 
-@pytest.mark.parametrize(
-    ['s_enhance', 'with_nans'],
-    [
-        (1, False),
-        (2, False),
-        (4, False),
-        (1, True),
-        (2, True),
-        (4, True),
-    ],
-)
-def test_obs_nn(s_enhance, with_nans):
-    """Test the aggregation of 3D data with nearest neighbor"""
-    with TemporaryDirectory() as td:
-        dset = make_fake_dset(
-            (8, 8, 1),
-            ['u_10m'],
-            lat_range=(20, 24),
-            lon_range=(-120, -116),
-        )
-
-        if with_nans:
-            mask = RANDOM_GENERATOR.choice(
-                [True, False], dset['u_10m'].shape, p=[0.6, 0.4]
-            )
-            u_10m = dset['u_10m'].values
-            u_10m[mask] = np.nan
-            dset['u_10m'] = (dset['u_10m'].dims, u_10m)
-
-        res = dset.sx.coarsen({'south_north': 4, 'west_east': 4}).mean()
-        res.to_netcdf(f'{td}/lr.nc')
-
-        lats = dset.latitude.values + RANDOM_GENERATOR.uniform(
-            0, 1, dset.latitude.shape
-        )
-        dset['latitude'] = (('south_north', 'west_east'), lats)
-        lons = dset.longitude.values + RANDOM_GENERATOR.uniform(
-            0, 1, dset.longitude.shape
-        )
-        dset['longitude'] = (('south_north', 'west_east'), lons)
-        dset.to_netcdf(f'{td}/hr.nc')
-
-        te = ExoRasterizer(
-            file_paths=f'{td}/lr.nc',
-            source_files=f'{td}/hr.nc',
-            feature='u_10m_obs',
-            s_enhance=s_enhance,
-            t_enhance=1,
-            agg_method='nn',
-        )
-        tsteps = len(te.source_handler.time_index)
-        agg_obs = np.asarray(te._get_data_3d()).reshape((-1, tsteps))
-        true_obs = np.full(agg_obs.shape, np.nan, dtype=np.float32)
-        vals = te.source_handler['u_10m'].values.reshape((-1, tsteps))
-        tree = KDTree(te.hr_lat_lon.reshape(-1, 2))
-        dists, nn = tree.query(
-            te.source_lat_lon.reshape(-1, 2),
-            k=1,
-            distance_upper_bound=te.distance_upper_bound,
-        )
-        for i in range(true_obs.shape[0]):
-            d_i = dists[nn == i]
-            if d_i.size == 0:
-                continue
-            idx = np.argmin(np.abs(d_i.min() - dists))
-            true_obs[i, :] = vals[idx, :]
-
-        mask = np.isnan(agg_obs) | np.isnan(true_obs)
-        assert np.allclose(agg_obs[~mask], true_obs[~mask], rtol=0.01)
-
-
-@pytest.mark.parametrize(
-    ['s_enhance', 'with_nans'],
-    [
-        (1, False),
-        (2, False),
-        (4, False),
-        (1, True),
-        (2, True),
-        (4, True),
-    ],
-)
-def test_obs_idw(s_enhance, with_nans):
-    """Test the aggregation of 3D data with inverse distance weighting"""
-    with TemporaryDirectory() as td:
-        dset = make_fake_dset(
-            (8, 8, 1),
-            ['u_10m'],
-            lat_range=(20, 24),
-            lon_range=(-120, -116),
-        )
-        if with_nans:
-            mask = RANDOM_GENERATOR.choice(
-                [True, False], dset['u_10m'].shape, p=[0.6, 0.4]
-            )
-            u_10m = dset['u_10m'].values
-            u_10m[mask] = np.nan
-            dset['u_10m'] = (dset['u_10m'].dims, u_10m)
-
-        res = dset.sx.coarsen({'south_north': 4, 'west_east': 4}).mean()
-        res.to_netcdf(f'{td}/lr.nc')
-
-        lats = dset.latitude.values + RANDOM_GENERATOR.uniform(
-            0, 1, dset.latitude.shape
-        )
-        dset['latitude'] = (('south_north', 'west_east'), lats)
-        lons = dset.longitude.values + RANDOM_GENERATOR.uniform(
-            0, 1, dset.longitude.shape
-        )
-        dset['longitude'] = (('south_north', 'west_east'), lons)
-        dset.to_netcdf(f'{td}/hr.nc')
-
-        te = ExoRasterizer(
-            file_paths=f'{td}/lr.nc',
-            source_files=f'{td}/hr.nc',
-            feature='u_10m_obs',
-            s_enhance=s_enhance,
-            t_enhance=1,
-            agg_method='idw',
-        )
-
-        tsteps = len(te.source_handler.time_index)
-        agg_obs = np.asarray(te._get_data_3d()).reshape((-1, tsteps))
-        true_obs = np.full(agg_obs.shape, np.nan, dtype=np.float32)
-        vals = te.source_handler['u_10m'].values.reshape((-1, tsteps))
-        tree = KDTree(te.hr_lat_lon.reshape(-1, 2))
-        dists, nn = tree.query(
-            te.source_lat_lon.reshape(-1, 2),
-            k=1,
-            distance_upper_bound=te.distance_upper_bound,
-        )
-        for i in range(true_obs.shape[0]):
-            d_i = dists[nn == i]
-            if d_i.size == 0:
-                continue
-            v_i = []
-            for d in d_i:
-                idx = np.argmin(np.abs(d - dists))
-                v_i.append(vals[idx, :])
-            w_i = 1 / np.maximum(d_i, 1e-6)
-            w_i /= np.sum(w_i)
-            true_obs[i, :] = np.average(v_i, axis=0, weights=w_i)
-
-        mask = np.isnan(agg_obs) | np.isnan(true_obs)
-        assert np.allclose(agg_obs[~mask], true_obs[~mask], rtol=0.01)
-
-
 @pytest.mark.parametrize(
     ['s_enhance', 'with_nans'],
     [
         (1, False),

From 08c22969aa41885588fe6e3f30e7ff3ec3f70d8c Mon Sep 17 00:00:00 2001
From: bnb32
Date: Mon, 13 Oct 2025 11:09:33 -0600
Subject: [PATCH 11/11] removed agg_method arg in tests

---
 tests/rasterizers/test_exo.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py
index bba9642634..48e3a8ee58 100644
--- a/tests/rasterizers/test_exo.py
+++ b/tests/rasterizers/test_exo.py
@@ -366,8 +366,7 @@ def test_obs_agg(s_enhance, with_nans):
             source_files=f'{td}/hr.nc',
             feature='u_10m_obs',
             s_enhance=s_enhance,
-            t_enhance=1,
-            agg_method='mean',
+            t_enhance=1
         )
         agg_obs = np.asarray(te._get_data_3d())
         true_obs = (
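
Note on the surviving design: after PATCH 10, exogenous observation data is
aggregated by mapping every source point to its nearest high-res grid cell
with a KDTree (subject to a distance cutoff) and then taking a per-(cell,
time) mean, exactly as _get_data_3d does with its gid_target/time groupby.
The standalone sketch below summarizes that scheme in isolation; the function
name, signature, and variable names here are illustrative assumptions, not
part of the sup3r API.

    # Minimal sketch, assuming flattened inputs: source_data is
    # (n_source, n_time), source_lat_lon is (n_source, 2), and grid_lat_lon
    # is (n_cells, 2). The name aggregate_obs is hypothetical and only
    # mirrors the logic of _get_data_3d / _mean_fill.
    import numpy as np
    import pandas as pd
    from scipy.spatial import KDTree

    def aggregate_obs(source_lat_lon, source_data, grid_lat_lon, cutoff):
        n_cells, n_time = grid_lat_lon.shape[0], source_data.shape[1]
        # KDTree.query reports the index n_cells for points beyond the cutoff
        _, nn = KDTree(grid_lat_lon).query(
            source_lat_lon, distance_upper_bound=cutoff
        )
        # one row per (nearest cell, time step), aligned with
        # source_data.ravel() in row-major order
        rows = pd.MultiIndex.from_product(
            [nn, range(n_time)], names=['gid', 'time']
        )
        df = pd.DataFrame({'val': source_data.ravel()}, index=rows)
        df = df[df.index.get_level_values('gid') != n_cells]  # drop unmatched
        means = df.groupby(['gid', 'time'])['val'].mean()  # NaN-skipping mean
        out = np.full((n_cells, n_time), np.nan, dtype=np.float32)
        gids = means.index.get_level_values('gid').to_numpy()
        times = means.index.get_level_values('time').to_numpy()
        out[gids, times] = means.to_numpy()
        return out

Grid cells with no source point inside the cutoff stay NaN under this scheme,
which is why PATCH 01 also has the observation rasterizer set fill_nans to
False rather than inheriting the parent class fill behavior.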