diff --git a/.github/workflows/on-push.yaml b/.github/workflows/on-push.yaml index bbcb9f1..5786f97 100644 --- a/.github/workflows/on-push.yaml +++ b/.github/workflows/on-push.yaml @@ -80,6 +80,35 @@ jobs: run: | make unit-tests COV_REPORT=xml + type-check: + needs: combine-environments + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: combined-environments + path: ci + - name: Get current date + id: date + run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}" + - uses: mamba-org/setup-micromamba@v2 + with: + environment-file: ci/combined-environment-ci.yml + environment-name: DEVELOP + cache-environment: true + cache-environment-key: environment-${{ steps.date.outputs.date }} + cache-downloads-key: downloads-${{ steps.date.outputs.date }} + create-args: >- + python=3.11 + - name: Install package + run: | + python -m pip install --no-deps -e . + - name: Run code quality checks + run: | + make type-check + docs-build: needs: [combine-environments] runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index dbdbc00..15a2c75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ rasterio = [ [tool.coverage.run] branch = true +[tool.mypy] +ignore_missing_imports = true + [tool.ruff] # Black line length is 88, but black does not format comments. line-length = 110 diff --git a/src/earthkit/transforms/aggregate/climatology.py b/src/earthkit/transforms/aggregate/climatology.py index f786cb5..f8ec630 100644 --- a/src/earthkit/transforms/aggregate/climatology.py +++ b/src/earthkit/transforms/aggregate/climatology.py @@ -715,11 +715,11 @@ def _anomaly_dataarray( if var_name in climatology: climatology_da = climatology[var_name] else: - potential_clim_vars = [c_var for c_var in climatology.data_vars if var_name in c_var] + potential_clim_vars = [c_var for c_var in climatology.data_vars if str(var_name) in str(c_var)] if len(potential_clim_vars) == 1: climatology_da = climatology[potential_clim_vars[0]] - elif var_name + "_" + climatology_how_tag in potential_clim_vars: - climatology_da = climatology[var_name + "_" + climatology_how_tag] + elif f"{var_name}_{climatology_how_tag}" in potential_clim_vars: + climatology_da = climatology[f"{var_name}_{climatology_how_tag}"] elif len(potential_clim_vars) > 1: raise KeyError( "Multiple potential climatologies found in climatology dataset, " @@ -854,7 +854,7 @@ def auto_anomaly( xr.DataArray """ # If climate range is defined, use it - if all(c_r is not None for c_r in climatology_range): + if climatology_range is not None and all(c_r is not None for c_r in climatology_range): selection = dataarray.sel(time=slice(*climatology_range)) else: selection = dataarray diff --git a/src/earthkit/transforms/aggregate/general.py b/src/earthkit/transforms/aggregate/general.py index d9f9fcc..419cdb6 100644 --- a/src/earthkit/transforms/aggregate/general.py +++ b/src/earthkit/transforms/aggregate/general.py @@ -13,7 +13,7 @@ def how_label_rename( # Update variable names, depends on dataset or dataarray format if isinstance(dataarray, xr.Dataset): renames = {data_arr: f"{data_arr}_{how_label}" for data_arr in dataarray} - dataarray = dataarray.rename(**renames) + dataarray = dataarray.rename(renames) else: dataarray = dataarray.rename(f"{dataarray.name}_{how_label}") @@ -62,16 +62,16 @@ def _reduce_dataarray( if weights is not None: # Create any standard weights, e.g. latitude if isinstance(weights, str): - weights = tools.standard_weights(dataarray, weights, **kwargs) + _weights = tools.standard_weights(dataarray, weights, **kwargs) + else: + _weights = weights # We ensure the callable is always a string if callable(how): - how = how.__name__ + how = weighted_how = how.__name__ # map any alias methods: - how = tools.WEIGHTED_HOW_METHODS.get(how, how) - - dataarray = dataarray.weighted(weights) + weighted_how = tools.WEIGHTED_HOW_METHODS.get(how, how) - red_array = dataarray.__getattribute__(how)(**kwargs) + red_array = dataarray.weighted(_weights).__getattribute__(weighted_how)(**kwargs) else: # If how is string, fetch function from dictionary: @@ -80,7 +80,7 @@ def _reduce_dataarray( else: if isinstance(how, str): how = tools.get_how(how) - assert isinstance(how, T.Callable), f"how method not recognised: {how}" + assert callable(how), f"how method not recognised: {how}" red_array = dataarray.reduce(how, **kwargs) @@ -129,14 +129,19 @@ def reduce( """ # handle how as arg or kwarg kwargs["how"] = _args[0] if _args else kwargs.get("how", "mean") - out = _reduce_dataarray(dataarray, **kwargs) - # Ensure any input attributes are preserved (maybe not necessary) + if isinstance(dataarray, xr.Dataset): - out.attrs.update(dataarray.attrs) + out_ds = xr.Dataset().assign_attrs(dataarray.attrs) + for var in dataarray.data_vars: + out_da = _reduce_dataarray(dataarray[var], **kwargs) + out_ds[out_da.name] = out_da + return out_ds + + out = _reduce_dataarray(dataarray, **kwargs) return out -def rolling_reduce(dataarray: xr.Dataset | xr.DataArray, *_args, **kwargs) -> xr.DataArray: +def rolling_reduce(dataarray: xr.Dataset | xr.DataArray, *_args, **kwargs) -> xr.Dataset | xr.DataArray: """Return reduced data using a moving window over which to apply the reduction. Parameters @@ -214,7 +219,7 @@ def _rolling_reduce_dataarray( if isinstance(kwargs.get("dim"), dict): kwargs.update(kwargs.pop("dim")) - window_dims = [_dim for _dim in list(dataarray.dims) if _dim in list(kwargs)] + window_dims = [str(_dim) for _dim in list(dataarray.dims) if _dim in list(kwargs)] rolling_kwargs_keys = ["min_periods", "center"] + window_dims rolling_kwargs_keys = [_kwarg for _kwarg in kwargs if _kwarg in rolling_kwargs_keys] rolling_kwargs = {_kwarg: kwargs.pop(_kwarg) for _kwarg in rolling_kwargs_keys} @@ -225,7 +230,7 @@ def _rolling_reduce_dataarray( data_rolling = dataarray.rolling(**rolling_kwargs) reduce_kwargs.setdefault("how", how_reduce) - data_windowed = _reduce_dataarray(data_rolling, **reduce_kwargs) + data_windowed = _reduce_dataarray(data_rolling, **reduce_kwargs) # type: ignore data_windowed = _dropna(data_windowed, window_dims, how_dropna) diff --git a/src/earthkit/transforms/aggregate/spatial.py b/src/earthkit/transforms/aggregate/spatial.py index fb053b9..df80835 100644 --- a/src/earthkit/transforms/aggregate/spatial.py +++ b/src/earthkit/transforms/aggregate/spatial.py @@ -37,7 +37,6 @@ def rasterize( coords: xr.core.coordinates.Coordinates, lat_key: str = "latitude", lon_key: str = "longitude", - dtype: type = int, **kwargs, ) -> xr.DataArray: """Rasterize a list of geometries onto the given xarray coordinates. @@ -70,12 +69,18 @@ def rasterize( transform = _transform_from_latlon(coords[lat_key], coords[lon_key]) out_shape = (len(coords[lat_key]), len(coords[lon_key])) - raster = features.rasterize(shape_list, out_shape=out_shape, transform=transform, dtype=dtype, **kwargs) + raster = features.rasterize(shape_list, out_shape=out_shape, transform=transform, **kwargs) spatial_coords = {lat_key: coords[lat_key], lon_key: coords[lon_key]} return xr.DataArray(raster, coords=spatial_coords, dims=(lat_key, lon_key)) -def mask_contains_points(shape_list, coords, lat_key="lat", lon_key="lon", **_kwargs) -> xr.DataArray: +def mask_contains_points( + shape_list: T.List, + coords: xr.core.coordinates.Coordinates, + lat_key: str = "latitude", + lon_key: str = "longitude", + **_kwargs, +) -> xr.DataArray: """Return a mask array for the spatial points of data that lie within shapes in shape_list. Function uses matplotlib.Path so can accept a list of points, this is much faster than shapely. @@ -360,7 +365,7 @@ def mask( if union_geometries: out = masked_arrays[0] else: - out = xr.concat(masked_arrays, dim=mask_dim_index.name) + out = xr.concat(masked_arrays, dim=mask_dim_index.name) # type: ignore if chunk: out = out.chunk({mask_dim_index.name: 1}) @@ -375,7 +380,6 @@ def reduce( dataarray: xr.Dataset | xr.DataArray, geodataframe: gpd.GeoDataFrame | None = None, mask_arrays: xr.DataArray | list[xr.DataArray] | None = None, - *_args, **kwargs, ) -> xr.Dataset | xr.DataArray: """Apply a shape object to an xarray.DataArray object using the specified 'how' method. @@ -426,15 +430,21 @@ def reduce( Each slice of layer corresponds to a feature in layer. """ + assert not ( + geodataframe is not None and mask_arrays is not None + ), "Either a geodataframe or mask arrays must be provided, not both" if mask_arrays is not None: - mask_arrays = ensure_list(mask_arrays) + _mask_arrays: list[xr.DataArray] | None = ensure_list(mask_arrays) + else: + _mask_arrays = None + if isinstance(dataarray, xr.Dataset): return_as: str = kwargs.get("return_as", "xarray") if return_as in ["xarray"]: out_ds = xr.Dataset().assign_attrs(dataarray.attrs) for var in dataarray.data_vars: out_da = _reduce_dataarray( - dataarray[var], geodataframe=geodataframe, mask_arrays=mask_arrays, *_args, **kwargs + dataarray[var], geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs ) out_ds[out_da.name] = out_da return out_ds @@ -446,24 +456,23 @@ def reduce( if geodataframe is not None: out = geodataframe for var in dataarray.data_vars: - out = _reduce_dataarray(dataarray[var], geodataframe=out, *_args, **kwargs) + out = _reduce_dataarray(dataarray[var], geodataframe=out, **kwargs) else: out = None for var in dataarray.data_vars: - _out = _reduce_dataarray(dataarray[var], mask_arrays=mask_arrays, *_args, **kwargs) + _out = _reduce_dataarray(dataarray[var], mask_arrays=_mask_arrays, **kwargs) if out is None: out = _out else: - out = pd.merge(out, _out) + out = pd.merge(out, _out) # type: ignore return out else: raise TypeError("Return as type not recognised or incompatible with inputs") else: - return _reduce_dataarray( - dataarray, geodataframe=geodataframe, mask_arrays=mask_arrays, *_args, **kwargs - ) + return _reduce_dataarray(dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs) # type: ignore +# TODO: split into two functions, one for xarray and one for pandas def _reduce_dataarray( dataarray: xr.DataArray, geodataframe: gpd.GeoDataFrame | None = None, @@ -540,8 +549,10 @@ def _reduce_dataarray( # convert how string to a method to apply if isinstance(how, str): how_str = deepcopy(how) - how = get_how(how) - assert isinstance(how, T.Callable), f"how must be a callable: {how}" + how = reduce_how = get_how(how) + # else: + # reduce_how = how + assert callable(how), f"how must be a callable: {how}" if how_str is None: # get label from how method how_str = how.__name__ @@ -549,10 +560,14 @@ def _reduce_dataarray( # Create any standard weights, e.g. latitude. # TODO: handle kwargs better, currently only lat_key is accepted if isinstance(weights, str): - weights = standard_weights(dataarray, weights, lat_key=lat_key) + _weights = standard_weights(dataarray, weights, lat_key=lat_key) + else: + _weights = weights # We ensure the callable is a string if callable(how): - how = how.__name__ + how = weighted_how = how.__name__ + else: + weighted_how = how if how_str is None: how_str = how @@ -561,7 +576,7 @@ def _reduce_dataarray( comp for comp in [how_str, dataarray.attrs.get("long_name", dataarray.name)] if comp is not None ] new_long_name = " ".join(new_long_name_components) - new_short_name_components = [comp for comp in [dataarray.name, how_label] if comp is not None] + new_short_name_components = [f"{comp}" for comp in [dataarray.name, how_label] if comp is not None] new_short_name = "_".join(new_short_name_components) if isinstance(extra_reduce_dims, str): @@ -597,10 +612,10 @@ def _reduce_dataarray( # If weighted, use xarray weighted arrays which correctly handle missing values etc. if weights is not None: - this = this.weighted(weights) - reduced_list.append(this.__getattribute__(how)(**reduce_kwargs)) + this_weighted = this.weighted(_weights) + reduced_list.append(this_weighted.__getattribute__(weighted_how)(**reduce_kwargs)) else: - reduced = this.reduce(how, **reduce_kwargs).compute() + reduced = this.reduce(reduce_how, **reduce_kwargs).compute() reduced = reduced.assign_attrs(dataarray.attrs) reduced_list.append(reduced) @@ -646,8 +661,9 @@ def _reduce_dataarray( # add the reduced data into a new column as a numpy array, # store the dim information in the attributes + # TODO: fix typing out_dims = { - dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None + dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None # type: ignore for dim in reduced_list[0].dims } reduce_attrs[f"{new_short_name}"].update({"dims": out_dims}) diff --git a/src/earthkit/transforms/aggregate/temporal.py b/src/earthkit/transforms/aggregate/temporal.py index 05a4c29..447dec9 100644 --- a/src/earthkit/transforms/aggregate/temporal.py +++ b/src/earthkit/transforms/aggregate/temporal.py @@ -425,7 +425,7 @@ def daily_reduce( # If how is string, fetch function from dictionary: if isinstance(how, str): how = tools.get_how(how) - assert isinstance(how, T.Callable), f"how method not recognised: {how}" + assert callable(how), f"how method not recognised: {how}" red_array = grouped_data.reduce(how, **kwargs) try: @@ -684,7 +684,7 @@ def monthly_reduce( # If how is string, fetch function from dictionary: if isinstance(how, str): how = tools.get_how(how) - assert isinstance(how, T.Callable), f"how method not recognised: {how}" + assert callable(how), f"how method not recognised: {how}" red_array = grouped_data.reduce(how, **kwargs) # Remove the year_months coordinate @@ -935,5 +935,5 @@ def rolling_reduce( A dataarray reduced values with a rolling window applied along the time dimension. """ if window_length is not None: - kwargs.update({time_dim: window_length}) + kwargs.update({str(time_dim): window_length}) return _rolling_reduce(dataarray, **kwargs) diff --git a/src/earthkit/transforms/tools.py b/src/earthkit/transforms/tools.py index 73282e3..3b1f612 100644 --- a/src/earthkit/transforms/tools.py +++ b/src/earthkit/transforms/tools.py @@ -1,4 +1,5 @@ import functools +import typing as T import numpy as np import pandas as pd @@ -27,7 +28,7 @@ } -def ensure_list(thing): +def ensure_list(thing) -> list[T.Any]: if isinstance(thing, list): return thing try: @@ -77,7 +78,7 @@ def wrapper( return wrapper -GROUPBY_KWARGS = ["frequency", "bin_widths", "squeeze"] +GROUPBY_KWARGS = ["frequency", "bin_widths"] def groupby_kwargs_decorator(func): @@ -379,7 +380,7 @@ def get_spatial_info(dataarray, lat_key=None, lon_key=None): def _pandas_frequency_and_bins( frequency: str, -) -> tuple: +) -> tuple[str, int | None]: freq = frequency.lstrip("0123456789") bins = int(frequency[: -len(freq)]) or None freq = _PANDAS_FREQUENCIES.get(freq.lstrip(" "), frequency) @@ -390,7 +391,6 @@ def groupby_time( dataarray: xr.Dataset | xr.DataArray, frequency: str | None = None, bin_widths: int | None = None, - squeeze: bool = False, time_dim: str = "time", ): if frequency is None: @@ -404,10 +404,10 @@ def groupby_time( bin_widths = bin_widths or possible_bins if bin_widths is not None: - grouped_data = groupby_bins(dataarray, frequency, bin_widths, squeeze, time_dim=time_dim) + grouped_data = groupby_bins(dataarray, frequency, bin_widths, time_dim=time_dim) else: try: - grouped_data = dataarray.groupby(f"{time_dim}.{frequency}", squeeze=squeeze) + grouped_data = dataarray.groupby(f"{time_dim}.{frequency}") except AttributeError: raise ValueError( f"Invalid frequency '{frequency}' - see xarray documentation for " @@ -420,15 +420,14 @@ def groupby_time( def groupby_bins( dataarray: xr.Dataset | xr.DataArray, frequency: str, - bin_widths: int = 1, - squeeze: bool = False, + bin_widths: list[int] | int = 1, time_dim: str = "time", ): if not isinstance(bin_widths, (list, tuple)): max_value = _BIN_MAXES[frequency] bin_widths = list(range(0, max_value + 1, bin_widths)) try: - grouped_data = dataarray.groupby_bins(f"{time_dim}.{frequency}", bin_widths, squeeze=squeeze) + grouped_data = dataarray.groupby_bins(f"{time_dim}.{frequency}", bin_widths) except AttributeError: raise ValueError( f"Invalid frequency '{frequency}' - see xarray documentation for " diff --git a/tests/test_10_tools.py b/tests/test_10_tools.py index b45a17d..6d842fd 100644 --- a/tests/test_10_tools.py +++ b/tests/test_10_tools.py @@ -103,7 +103,7 @@ def test_groupby_kwargs_decorator_none(): # Test case for the decorator when groupby_kwargs is provided def test_groupby_kwargs_decorator_provided(): # Prepare groupby_kwargs and other kwargs - groupby_kwargs = {"frequency": "day", "bin_widths": 1, "squeeze": True} + groupby_kwargs = {"frequency": "day", "bin_widths": 1} other_kwargs = {"method": "linear", "fill_value": 0} # Call the decorated function with groupby_kwargs provided @@ -133,7 +133,7 @@ def test_groupby_kwargs_decorator_partial_provided(): # Test case for the decorator when groupby_kwargs is provided def test_groupby_kwargs_decorator_override(): # Prepare groupby_kwargs and other kwargs - groupby_kwargs = {"frequency": "day", "bin_widths": 1, "squeeze": True} + groupby_kwargs = {"frequency": "day", "bin_widths": 1} other_kwargs = {"method": "linear", "fill_value": 0} override_groupby_kwargs = {"frequency": "hour"} diff --git a/tests/test_30_spatial.py b/tests/test_30_spatial.py index 1740e0d..5d0bb4b 100644 --- a/tests/test_30_spatial.py +++ b/tests/test_30_spatial.py @@ -19,6 +19,20 @@ ek_data.settings.set("cache-policy", "user") +SAMPLE_ARRAY = xr.DataArray( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [3, 3, 3, 3], + ], + dims=["latitude", "longitude"], + coords={ + "latitude": [0, 60, 90], # Chosen for latitude weight tests + "longitude": [0, 30, 60, 90], + }, +) + + class dummy_class: def __init__(self): self.to_pandas = pd.DataFrame @@ -82,6 +96,13 @@ def test_spatial_reduce_no_geometry(era5_data, expected_result_type): assert list(reduced_data.dims) == ["forecast_reference_time"] +def test_spatial_reduce_no_geometry_result(): + reduced_data = spatial.reduce(SAMPLE_ARRAY, how="mean") + assert reduced_data.values == 2.0 + reduced_data = spatial.reduce(SAMPLE_ARRAY, how="mean", weights="latitude") + assert np.isclose(reduced_data.values, 1 + (1.0 / 3)) + + @pytest.mark.skipif( not rasterio_available, reason="rasterio is not available",