17 changes: 13 additions & 4 deletions sup3r/preprocessing/data_handlers/factory.py
@@ -47,6 +47,7 @@ def __init__(
         self,
         file_paths,
         features='all',
+        load_features='all',
         res_kwargs: Optional[dict] = None,
         chunks: Union[str, Dict[str, int]] = 'auto',
         target: Optional[tuple] = None,
@@ -69,9 +70,16 @@
         file_paths : str | list | pathlib.Path
             file_paths input to LoaderClass
         features : list | str
-            Features to load and / or derive. If 'all' then all available raw
-            features will be loaded. Specify explicit feature names for
-            derivations.
+            Features to derive. If 'all' then all available raw features will
+            just be loaded. Specify explicit feature names for derivations.
+        load_features : list | str
+            Features to load and make available for derivations. If 'all'
+            then all available raw features will be loaded and made available
+            for derivations. This can be used to restrict the inputs used for
+            derivations. For example, to derive 'temperature_100m' from
+            pressure-level temperatures only, when the data also includes
+            single-level fields like 'temperature_2m', leave 'temperature_2m'
+            out of the ``load_features`` list.
         res_kwargs : dict
             Additional keyword arguments passed through to the ``BaseLoader``.
             BaseLoader is usually xr.open_mfdataset for NETCDF files and
@@ -146,12 +154,13 @@
         )

         just_coords = not features
-        raster_feats = 'all' if any(missing_features) else []
+        raster_feats = load_features if any(missing_features) else []
         self.rasterizer = self.loader = self.cache = None

         if any(cached_features):
             self.cache = Loader(
                 file_paths=cached_files,
+                features=load_features,
                 res_kwargs=res_kwargs,
                 chunks=chunks,
                 BaseLoader=BaseLoader,
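For illustration, a minimal usage sketch of the new ``load_features`` kwarg. The import path and file names here are assumptions for the example, not part of this diff:

```python
# Hypothetical usage: derive a 100m temperature from pressure-level data
# only, even though the input files also contain 'temperature_2m'.
from sup3r.preprocessing import DataHandler

handler = DataHandler(
    file_paths=['era5_plevels.nc', 'era5_sfc.nc'],  # hypothetical files
    features=['temperature_100m'],  # feature to derive
    # Leaving 'temperature_2m' out restricts the height interpolation to
    # the pressure-level temperatures plus 'zg' and 'topography'.
    load_features=['temperature', 'zg', 'topography'],
)
```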
9 changes: 4 additions & 5 deletions sup3r/preprocessing/derivers/base.py
@@ -56,7 +56,7 @@ def __init__(
             a method to derive the feature in the registry.
         interp_kwargs : dict | None
             Dictionary of kwargs for level interpolation. Can include "method"
-            and "run_level_check" keys. Method specifies how to perform height
+            and "run_level_check". "method" specifies how to perform height
             interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options
             are "linear" and "log". See
             :py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation`
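For reference, a standalone sketch of the two ``method`` options named above, estimating u_20m from u_10m and u_100m. These are the standard linear and log-profile forms; the exact implementation lives in ``do_level_interpolation``:

```python
import numpy as np

h1, h2, target = 10.0, 100.0, 20.0
u1, u2 = 5.0, 9.0  # example wind speeds at h1 and h2

# "linear": interpolate directly in height
u_lin = u1 + (u2 - u1) * (target - h1) / (h2 - h1)  # ~5.44

# "log": interpolate in log-height, matching a logarithmic wind profile
u_log = u1 + (u2 - u1) * np.log(target / h1) / np.log(h2 / h1)  # ~6.20
```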
@@ -65,7 +65,7 @@
         self.FEATURE_REGISTRY = FeatureRegistry

         super().__init__(data=data)
-        self.interp_kwargs = interp_kwargs
+        self.interp_kwargs = interp_kwargs or {}
         features = parse_to_list(data=data, features=features)
         new_features = [f for f in features if f not in self.data]
         for f in new_features:
@@ -269,7 +269,6 @@ def get_single_level_data(self, feature):
         var_array = da.stack(var_list, axis=-1)
         sl_shape = (*var_array.shape[:-1], len(lev_list))
         lev_array = da.broadcast_to(da.from_array(lev_list), sl_shape)
-
         return var_array, lev_array

     def get_multi_level_data(self, feature):
@@ -296,8 +295,8 @@ def get_multi_level_data(self, feature):
         assert can_calc_height or have_height, msg

         if can_calc_height:
-            lev_array = self.data[['zg', 'topography']].as_array()
-            lev_array = lev_array[..., 0] - lev_array[..., 1]
+            lev_array = self.data['zg'] - self.data['topography']
+            lev_array = lev_array.data
         else:
             lev_array = da.broadcast_to(
                 self.data[Dimension.HEIGHT].astype(np.float32),
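The ``can_calc_height`` branch now subtracts the two variables directly instead of stacking them with ``as_array()`` and indexing. A small sketch of why the one-line subtraction works (dimension names here are illustrative):

```python
import numpy as np
import xarray as xr

# 'zg' carries a vertical dimension; 'topography' is surface elevation on
# the same horizontal grid.
zg = xr.DataArray(
    np.random.rand(4, 4, 2, 3),
    dims=('south_north', 'west_east', 'time', 'level'),
)
topo = xr.DataArray(np.random.rand(4, 4), dims=('south_north', 'west_east'))

# xarray broadcasts on dimension names, so a single subtraction yields the
# height above the surface for every (lat, lon, time, level) entry.
# ``.data`` then exposes the underlying (possibly dask) array.
lev_array = (zg - topo).data
```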
17 changes: 7 additions & 10 deletions sup3r/utilities/interpolation.py
@@ -20,9 +20,6 @@ def get_level_masks(cls, lev_array, level):

         Parameters
         ----------
-        var_array : Union[np.ndarray, da.core.Array]
-            Array of variable data, for example u-wind in a 4D array of shape
-            (lat, lon, time, level)
         lev_array : Union[np.ndarray, da.core.Array]
             Height or pressure values for the corresponding entries in
             var_array, in the same shape as var_array. If this is height and
@@ -45,14 +42,14 @@
             to the one requested.
             (lat, lon, time, level)
         """
-        argmin1 = da.argmin(da.abs(lev_array - level), axis=-1, keepdims=True)
+        lev_diff = np.abs(lev_array - level)
+        argmin1 = da.argmin(lev_diff, axis=-1, keepdims=True)
         lev_indices = da.broadcast_to(
             da.arange(lev_array.shape[-1]), lev_array.shape
         )
         mask1 = lev_indices == argmin1
-
-        other_levs = da.ma.masked_array(lev_array, mask1)
-        argmin2 = da.argmin(da.abs(other_levs - level), axis=-1, keepdims=True)
+        lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level)
+        argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True)
         mask2 = lev_indices == argmin2
         return mask1, mask2

@@ -109,16 +106,16 @@ def interp_to_level(

         Parameters
         ----------
+        var_array : xr.DataArray
+            Array of variable data, for example u-wind in a 4D array of shape
+            (lat, lon, time, level)
         lev_array : xr.DataArray
             Height or pressure values for the corresponding entries in
             var_array, in the same shape as var_array. If this is height and
             the requested levels are hub heights above surface, lev_array
             should be the geopotential height corresponding to every var_array
             index relative to the surface elevation (subtract the elevation at
             the surface from the geopotential height)
-        var_array : xr.DataArray
-            Array of variable data, for example u-wind in a 4D array of shape
-            (lat, lon, time, level)
         level : float
             level or levels to interpolate to (e.g. final desired hub height
             above surface elevation)
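A self-contained sketch of what the refactored ``get_level_masks`` above computes, with concrete numbers (shape reduced to (lat, lon, level) for brevity):

```python
import dask.array as da
import numpy as np

lev_array = da.from_array(np.array([[[10.0, 40.0, 100.0]]]))
level = 20.0

# mask1 marks the level closest to the target at each grid point
lev_diff = np.abs(lev_array - level)
argmin1 = da.argmin(lev_diff, axis=-1, keepdims=True)
lev_indices = da.broadcast_to(da.arange(lev_array.shape[-1]), lev_array.shape)
mask1 = lev_indices == argmin1

# mask2 marks the second-closest level, found by masking out the first
lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level)
argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True)
mask2 = lev_indices == argmin2

print(lev_array[mask1].compute(), lev_array[mask2].compute())  # [10.] [40.]
```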
6 changes: 4 additions & 2 deletions tests/data_handlers/test_dh_nc_cc.py
@@ -68,7 +68,7 @@ def test_reload_cache():
         features=features,
         target=target,
         shape=(20, 20),
-        cache_kwargs=cache_kwargs,
+        cache_kwargs=cache_kwargs
     )

     # reload from cache
@@ -80,7 +80,9 @@
         cache_kwargs=cache_kwargs,
     )
     assert all(f in cached for f in features)
-    assert np.array_equal(handler.as_array(), cached.as_array())
+    harr = handler.as_array().compute()
+    carr = cached.as_array().compute()
+    assert np.array_equal(harr, carr)


 @pytest.mark.parametrize(
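The reload assertion now materializes both lazy arrays before comparing; ``np.array_equal`` would otherwise trigger the compute implicitly through ``np.asarray``. A minimal sketch of the pattern:

```python
import dask.array as da
import numpy as np

a = da.ones((4, 4), chunks=2)
b = da.ones((4, 4), chunks=2)

# compute once, then compare concrete numpy arrays
assert np.array_equal(a.compute(), b.compute())
```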
40 changes: 39 additions & 1 deletion tests/derivers/test_height_interp.py
@@ -114,6 +114,41 @@ def test_plevel_height_interp_nc_with_cache():
     )


+def test_plevel_height_interp_with_filtered_load_features():
+    """Test that filtering load features can be used to control the features
+    used in the derivations."""
+
+    with TemporaryDirectory() as td:
+        orog_file = os.path.join(td, 'orog.nc')
+        make_fake_nc_file(orog_file, shape=(10, 10, 20), features=['orog'])
+        sfc_file = os.path.join(td, 'u_10m.nc')
+        make_fake_nc_file(sfc_file, shape=(10, 10, 20), features=['u_10m'])
+        level_file = os.path.join(td, 'wind_levs.nc')
+        make_fake_nc_file(
+            level_file, shape=(10, 10, 20, 3), features=['zg', 'u']
+        )
+        derive_features = ['u_20m']
+        dh_filt = DataHandler(
+            [orog_file, sfc_file, level_file],
+            features=derive_features,
+            load_features=['topography', 'zg', 'u'],
+        )
+        dh_no_filt = DataHandler(
+            [orog_file, sfc_file, level_file],
+            features=derive_features,
+        )
+        dh = DataHandler(
+            [orog_file, level_file],
+            features=derive_features,
+        )
+        assert np.array_equal(
+            dh_filt.data['u_20m'].data, dh.data['u_20m'].data
+        )
+        assert not np.array_equal(
+            dh_filt.data['u_20m'].data, dh_no_filt.data['u_20m'].data
+        )
+
+
 def test_only_interp_method():
     """Test that interp method alone returns the right values"""
     hgt = np.zeros((10, 10, 5, 3))
@@ -197,7 +232,10 @@ def test_plevel_height_interp_with_single_lev_data_nc(
         [wind_file, level_file], target=target, shape=shape
     )

-    transform = Deriver(no_transform.data, derive_features)
+    transform = Deriver(
+        no_transform.data,
+        derive_features,
+    )

     hgt_array = (
         no_transform['zg'].data - no_transform['topography'].data[..., None]
4 changes: 1 addition & 3 deletions tests/forward_pass/test_forward_pass.py
@@ -491,9 +491,7 @@ def test_fwp_chunking(input_files):
         data_chunked[hr_slice][..., t_hr_slice, :] = out

     err = data_chunked - data_nochunk
-    err /= data_nochunk
-
-    assert np.mean(np.abs(err.flatten())) < 0.01
+    assert np.mean(np.abs(err)) < 1e-6


 def test_fwp_nochunking(input_files):
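The chunking test now bounds the absolute rather than the relative error. One plausible reason, sketched with made-up numbers: near-zero entries in ``data_nochunk`` can inflate a relative error arbitrarily, while the mean absolute error stays well-behaved:

```python
import numpy as np

data_nochunk = np.array([1e-12, 1.0])
data_chunked = data_nochunk + 1e-9  # tiny numerical difference

err = data_chunked - data_nochunk
print(np.mean(np.abs(err)))                 # ~1e-9: stable
print(np.mean(np.abs(err / data_nochunk)))  # ~500: blown up by the 1e-12 entry
```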