diff --git a/sup3r/preprocessing/data_handlers/factory.py b/sup3r/preprocessing/data_handlers/factory.py index 1115ec792a..1eec3c46ce 100644 --- a/sup3r/preprocessing/data_handlers/factory.py +++ b/sup3r/preprocessing/data_handlers/factory.py @@ -47,6 +47,7 @@ def __init__( self, file_paths, features='all', + load_features='all', res_kwargs: Optional[dict] = None, chunks: Union[str, Dict[str, int]] = 'auto', target: Optional[tuple] = None, @@ -69,9 +70,16 @@ def __init__( file_paths : str | list | pathlib.Path file_paths input to LoaderClass features : list | str - Features to load and / or derive. If 'all' then all available raw - features will be loaded. Specify explicit feature names for - derivations. + Features to derive. If 'all' then all available raw features will + just be loaded. Specify explicit feature names for derivations. + load_features : list | str + Features to load and make available for derivations. If 'all' then + all available raw features will be loaded and made available for + derivations. This can be used to restrict features used for + derivations. For example, to derive 'temperature_100m' from only + temperature isobars, from data that includes single level values as + well (like temperature_2m), don't include 'temperature_2m' in the + ``load_features`` list. res_kwargs : dict Additional keyword arguments passed through to the ``BaseLoader``. BaseLoader is usually xr.open_mfdataset for NETCDF files and @@ -146,12 +154,13 @@ def __init__( ) just_coords = not features - raster_feats = 'all' if any(missing_features) else [] + raster_feats = load_features if any(missing_features) else [] self.rasterizer = self.loader = self.cache = None if any(cached_features): self.cache = Loader( file_paths=cached_files, + features=load_features, res_kwargs=res_kwargs, chunks=chunks, BaseLoader=BaseLoader, diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 9a76705285..d431d02c51 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -56,7 +56,7 @@ def __init__( a method to derive the feature in the registry. interp_kwargs : dict | None Dictionary of kwargs for level interpolation. Can include "method" - and "run_level_check" keys. Method specifies how to perform height + and "run_level_check". "method" specifies how to perform height interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options are "linear" and "log". See :py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation` @@ -65,7 +65,7 @@ def __init__( self.FEATURE_REGISTRY = FeatureRegistry super().__init__(data=data) - self.interp_kwargs = interp_kwargs + self.interp_kwargs = interp_kwargs or {} features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] for f in new_features: @@ -269,7 +269,6 @@ def get_single_level_data(self, feature): var_array = da.stack(var_list, axis=-1) sl_shape = (*var_array.shape[:-1], len(lev_list)) lev_array = da.broadcast_to(da.from_array(lev_list), sl_shape) - return var_array, lev_array def get_multi_level_data(self, feature): @@ -296,8 +295,8 @@ def get_multi_level_data(self, feature): assert can_calc_height or have_height, msg if can_calc_height: - lev_array = self.data[['zg', 'topography']].as_array() - lev_array = lev_array[..., 0] - lev_array[..., 1] + lev_array = self.data['zg'] - self.data['topography'] + lev_array = lev_array.data else: lev_array = da.broadcast_to( self.data[Dimension.HEIGHT].astype(np.float32), diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index dd0d8c1bbd..79f0935182 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -20,9 +20,6 @@ def get_level_masks(cls, lev_array, level): Parameters ---------- - var_array : Union[np.ndarray, da.core.Array] - Array of variable data, for example u-wind in a 4D array of shape - (lat, lon, time, level) lev_array : Union[np.ndarray, da.core.Array] Height or pressure values for the corresponding entries in var_array, in the same shape as var_array. If this is height and @@ -45,14 +42,14 @@ def get_level_masks(cls, lev_array, level): to the one requested. (lat, lon, time, level) """ - argmin1 = da.argmin(da.abs(lev_array - level), axis=-1, keepdims=True) + lev_diff = np.abs(lev_array - level) + argmin1 = da.argmin(lev_diff, axis=-1, keepdims=True) lev_indices = da.broadcast_to( da.arange(lev_array.shape[-1]), lev_array.shape ) mask1 = lev_indices == argmin1 - - other_levs = da.ma.masked_array(lev_array, mask1) - argmin2 = da.argmin(da.abs(other_levs - level), axis=-1, keepdims=True) + lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level) + argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True) mask2 = lev_indices == argmin2 return mask1, mask2 @@ -109,9 +106,6 @@ def interp_to_level( Parameters ---------- - var_array : xr.DataArray - Array of variable data, for example u-wind in a 4D array of shape - (lat, lon, time, level) lev_array : xr.DataArray Height or pressure values for the corresponding entries in var_array, in the same shape as var_array. If this is height and @@ -119,6 +113,9 @@ def interp_to_level( should be the geopotential height corresponding to every var_array index relative to the surface elevation (subtract the elevation at the surface from the geopotential height) + var_array : xr.DataArray + Array of variable data, for example u-wind in a 4D array of shape + (lat, lon, time, level) level : float level or levels to interpolate to (e.g. final desired hub height above surface elevation) diff --git a/tests/data_handlers/test_dh_nc_cc.py b/tests/data_handlers/test_dh_nc_cc.py index 9dd8680234..740b6f7605 100644 --- a/tests/data_handlers/test_dh_nc_cc.py +++ b/tests/data_handlers/test_dh_nc_cc.py @@ -68,7 +68,7 @@ def test_reload_cache(): features=features, target=target, shape=(20, 20), - cache_kwargs=cache_kwargs, + cache_kwargs=cache_kwargs ) # reload from cache @@ -80,7 +80,9 @@ def test_reload_cache(): cache_kwargs=cache_kwargs, ) assert all(f in cached for f in features) - assert np.array_equal(handler.as_array(), cached.as_array()) + harr = handler.as_array().compute() + carr = cached.as_array().compute() + assert np.array_equal(harr, carr) @pytest.mark.parametrize( diff --git a/tests/derivers/test_height_interp.py b/tests/derivers/test_height_interp.py index 9464aa49eb..a747f7b254 100644 --- a/tests/derivers/test_height_interp.py +++ b/tests/derivers/test_height_interp.py @@ -114,6 +114,41 @@ def test_plevel_height_interp_nc_with_cache(): ) +def test_plevel_height_interp_with_filtered_load_features(): + """Test that filtering load features can be used to control the features + used in the derivations.""" + + with TemporaryDirectory() as td: + orog_file = os.path.join(td, 'orog.nc') + make_fake_nc_file(orog_file, shape=(10, 10, 20), features=['orog']) + sfc_file = os.path.join(td, 'u_10m.nc') + make_fake_nc_file(sfc_file, shape=(10, 10, 20), features=['u_10m']) + level_file = os.path.join(td, 'wind_levs.nc') + make_fake_nc_file( + level_file, shape=(10, 10, 20, 3), features=['zg', 'u'] + ) + derive_features = ['u_20m'] + dh_filt = DataHandler( + [orog_file, sfc_file, level_file], + features=derive_features, + load_features=['topography', 'zg', 'u'], + ) + dh_no_filt = DataHandler( + [orog_file, sfc_file, level_file], + features=derive_features, + ) + dh = DataHandler( + [orog_file, level_file], + features=derive_features, + ) + assert np.array_equal( + dh_filt.data['u_20m'].data, dh.data['u_20m'].data + ) + assert not np.array_equal( + dh_filt.data['u_20m'].data, dh_no_filt.data['u_20m'].data + ) + + def test_only_interp_method(): """Test that interp method alone returns the right values""" hgt = np.zeros((10, 10, 5, 3)) @@ -197,7 +232,10 @@ def test_plevel_height_interp_with_single_lev_data_nc( [wind_file, level_file], target=target, shape=shape ) - transform = Deriver(no_transform.data, derive_features) + transform = Deriver( + no_transform.data, + derive_features, + ) hgt_array = ( no_transform['zg'].data - no_transform['topography'].data[..., None] diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 69d6cfda7d..76d2687626 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -491,9 +491,7 @@ def test_fwp_chunking(input_files): data_chunked[hr_slice][..., t_hr_slice, :] = out err = data_chunked - data_nochunk - err /= data_nochunk - - assert np.mean(np.abs(err.flatten())) < 0.01 + assert np.mean(np.abs(err)) < 1e-6 def test_fwp_nochunking(input_files):