Skip to content

Feature/typeset #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/on-push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,35 @@ jobs:
run: |
make unit-tests COV_REPORT=xml

type-check:
needs: combine-environments
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
with:
name: combined-environments
path: ci
- name: Get current date
id: date
run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}"
- uses: mamba-org/setup-micromamba@v2
with:
environment-file: ci/combined-environment-ci.yml
environment-name: DEVELOP
cache-environment: true
cache-environment-key: environment-${{ steps.date.outputs.date }}
cache-downloads-key: downloads-${{ steps.date.outputs.date }}
create-args: >-
python=3.11
- name: Install package
run: |
python -m pip install --no-deps -e .
- name: Run code quality checks
run: |
make type-check

docs-build:
needs: [combine-environments]
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ rasterio = [
[tool.coverage.run]
branch = true

[tool.mypy]
ignore_missing_imports = true

[tool.ruff]
# Black line length is 88, but black does not format comments.
line-length = 110
Expand Down
8 changes: 4 additions & 4 deletions src/earthkit/transforms/aggregate/climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,11 +715,11 @@ def _anomaly_dataarray(
if var_name in climatology:
climatology_da = climatology[var_name]
else:
potential_clim_vars = [c_var for c_var in climatology.data_vars if var_name in c_var]
potential_clim_vars = [c_var for c_var in climatology.data_vars if str(var_name) in str(c_var)]
if len(potential_clim_vars) == 1:
climatology_da = climatology[potential_clim_vars[0]]
elif var_name + "_" + climatology_how_tag in potential_clim_vars:
climatology_da = climatology[var_name + "_" + climatology_how_tag]
elif f"{var_name}_{climatology_how_tag}" in potential_clim_vars:
climatology_da = climatology[f"{var_name}_{climatology_how_tag}"]
elif len(potential_clim_vars) > 1:
raise KeyError(
"Multiple potential climatologies found in climatology dataset, "
Expand Down Expand Up @@ -854,7 +854,7 @@ def auto_anomaly(
xr.DataArray
"""
# If climate range is defined, use it
if all(c_r is not None for c_r in climatology_range):
if climatology_range is not None and all(c_r is not None for c_r in climatology_range):
selection = dataarray.sel(time=slice(*climatology_range))
else:
selection = dataarray
Expand Down
33 changes: 19 additions & 14 deletions src/earthkit/transforms/aggregate/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def how_label_rename(
# Update variable names, depends on dataset or dataarray format
if isinstance(dataarray, xr.Dataset):
renames = {data_arr: f"{data_arr}_{how_label}" for data_arr in dataarray}
dataarray = dataarray.rename(**renames)
dataarray = dataarray.rename(renames)
else:
dataarray = dataarray.rename(f"{dataarray.name}_{how_label}")

Expand Down Expand Up @@ -62,16 +62,16 @@ def _reduce_dataarray(
if weights is not None:
# Create any standard weights, e.g. latitude
if isinstance(weights, str):
weights = tools.standard_weights(dataarray, weights, **kwargs)
_weights = tools.standard_weights(dataarray, weights, **kwargs)
else:
_weights = weights
# We ensure the callable is always a string
if callable(how):
how = how.__name__
how = weighted_how = how.__name__
# map any alias methods:
how = tools.WEIGHTED_HOW_METHODS.get(how, how)

dataarray = dataarray.weighted(weights)
weighted_how = tools.WEIGHTED_HOW_METHODS.get(how, how)

red_array = dataarray.__getattribute__(how)(**kwargs)
red_array = dataarray.weighted(_weights).__getattribute__(weighted_how)(**kwargs)

else:
# If how is string, fetch function from dictionary:
Expand All @@ -80,7 +80,7 @@ def _reduce_dataarray(
else:
if isinstance(how, str):
how = tools.get_how(how)
assert isinstance(how, T.Callable), f"how method not recognised: {how}"
assert callable(how), f"how method not recognised: {how}"

red_array = dataarray.reduce(how, **kwargs)

Expand Down Expand Up @@ -129,14 +129,19 @@ def reduce(
"""
# handle how as arg or kwarg
kwargs["how"] = _args[0] if _args else kwargs.get("how", "mean")
out = _reduce_dataarray(dataarray, **kwargs)
# Ensure any input attributes are preserved (maybe not necessary)

if isinstance(dataarray, xr.Dataset):
out.attrs.update(dataarray.attrs)
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
for var in dataarray.data_vars:
out_da = _reduce_dataarray(dataarray[var], **kwargs)
out_ds[out_da.name] = out_da
return out_ds

out = _reduce_dataarray(dataarray, **kwargs)
return out


def rolling_reduce(dataarray: xr.Dataset | xr.DataArray, *_args, **kwargs) -> xr.DataArray:
def rolling_reduce(dataarray: xr.Dataset | xr.DataArray, *_args, **kwargs) -> xr.Dataset | xr.DataArray:
"""Return reduced data using a moving window over which to apply the reduction.

Parameters
Expand Down Expand Up @@ -214,7 +219,7 @@ def _rolling_reduce_dataarray(
if isinstance(kwargs.get("dim"), dict):
kwargs.update(kwargs.pop("dim"))

window_dims = [_dim for _dim in list(dataarray.dims) if _dim in list(kwargs)]
window_dims = [str(_dim) for _dim in list(dataarray.dims) if _dim in list(kwargs)]
rolling_kwargs_keys = ["min_periods", "center"] + window_dims
rolling_kwargs_keys = [_kwarg for _kwarg in kwargs if _kwarg in rolling_kwargs_keys]
rolling_kwargs = {_kwarg: kwargs.pop(_kwarg) for _kwarg in rolling_kwargs_keys}
Expand All @@ -225,7 +230,7 @@ def _rolling_reduce_dataarray(
data_rolling = dataarray.rolling(**rolling_kwargs)

reduce_kwargs.setdefault("how", how_reduce)
data_windowed = _reduce_dataarray(data_rolling, **reduce_kwargs)
data_windowed = _reduce_dataarray(data_rolling, **reduce_kwargs) # type: ignore

data_windowed = _dropna(data_windowed, window_dims, how_dropna)

Expand Down
60 changes: 38 additions & 22 deletions src/earthkit/transforms/aggregate/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def rasterize(
coords: xr.core.coordinates.Coordinates,
lat_key: str = "latitude",
lon_key: str = "longitude",
dtype: type = int,
**kwargs,
) -> xr.DataArray:
"""Rasterize a list of geometries onto the given xarray coordinates.
Expand Down Expand Up @@ -70,12 +69,18 @@ def rasterize(

transform = _transform_from_latlon(coords[lat_key], coords[lon_key])
out_shape = (len(coords[lat_key]), len(coords[lon_key]))
raster = features.rasterize(shape_list, out_shape=out_shape, transform=transform, dtype=dtype, **kwargs)
raster = features.rasterize(shape_list, out_shape=out_shape, transform=transform, **kwargs)
spatial_coords = {lat_key: coords[lat_key], lon_key: coords[lon_key]}
return xr.DataArray(raster, coords=spatial_coords, dims=(lat_key, lon_key))


def mask_contains_points(shape_list, coords, lat_key="lat", lon_key="lon", **_kwargs) -> xr.DataArray:
def mask_contains_points(
shape_list: T.List,
coords: xr.core.coordinates.Coordinates,
lat_key: str = "latitude",
lon_key: str = "longitude",
**_kwargs,
) -> xr.DataArray:
"""Return a mask array for the spatial points of data that lie within shapes in shape_list.

Function uses matplotlib.Path, so it can accept a list of points; this is much faster than shapely.
Expand Down Expand Up @@ -360,7 +365,7 @@ def mask(
if union_geometries:
out = masked_arrays[0]
else:
out = xr.concat(masked_arrays, dim=mask_dim_index.name)
out = xr.concat(masked_arrays, dim=mask_dim_index.name) # type: ignore
if chunk:
out = out.chunk({mask_dim_index.name: 1})

Expand All @@ -375,7 +380,6 @@ def reduce(
dataarray: xr.Dataset | xr.DataArray,
geodataframe: gpd.GeoDataFrame | None = None,
mask_arrays: xr.DataArray | list[xr.DataArray] | None = None,
*_args,
**kwargs,
) -> xr.Dataset | xr.DataArray:
"""Apply a shape object to an xarray.DataArray object using the specified 'how' method.
Expand Down Expand Up @@ -426,15 +430,21 @@ def reduce(
Each slice of layer corresponds to a feature in layer.

"""
assert not (
geodataframe is not None and mask_arrays is not None
), "Either a geodataframe or mask arrays must be provided, not both"
if mask_arrays is not None:
mask_arrays = ensure_list(mask_arrays)
_mask_arrays: list[xr.DataArray] | None = ensure_list(mask_arrays)
else:
_mask_arrays = None

if isinstance(dataarray, xr.Dataset):
return_as: str = kwargs.get("return_as", "xarray")
if return_as in ["xarray"]:
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
for var in dataarray.data_vars:
out_da = _reduce_dataarray(
dataarray[var], geodataframe=geodataframe, mask_arrays=mask_arrays, *_args, **kwargs
dataarray[var], geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs
)
out_ds[out_da.name] = out_da
return out_ds
Expand All @@ -446,24 +456,23 @@ def reduce(
if geodataframe is not None:
out = geodataframe
for var in dataarray.data_vars:
out = _reduce_dataarray(dataarray[var], geodataframe=out, *_args, **kwargs)
out = _reduce_dataarray(dataarray[var], geodataframe=out, **kwargs)
else:
out = None
for var in dataarray.data_vars:
_out = _reduce_dataarray(dataarray[var], mask_arrays=mask_arrays, *_args, **kwargs)
_out = _reduce_dataarray(dataarray[var], mask_arrays=_mask_arrays, **kwargs)
if out is None:
out = _out
else:
out = pd.merge(out, _out)
out = pd.merge(out, _out) # type: ignore
return out
else:
raise TypeError("Return as type not recognised or incompatible with inputs")
else:
return _reduce_dataarray(
dataarray, geodataframe=geodataframe, mask_arrays=mask_arrays, *_args, **kwargs
)
return _reduce_dataarray(dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs) # type: ignore


# TODO: split into two functions, one for xarray and one for pandas
def _reduce_dataarray(
dataarray: xr.DataArray,
geodataframe: gpd.GeoDataFrame | None = None,
Expand Down Expand Up @@ -540,19 +549,25 @@ def _reduce_dataarray(
# convert how string to a method to apply
if isinstance(how, str):
how_str = deepcopy(how)
how = get_how(how)
assert isinstance(how, T.Callable), f"how must be a callable: {how}"
how = reduce_how = get_how(how)
# else:
# reduce_how = how
assert callable(how), f"how must be a callable: {how}"
if how_str is None:
# get label from how method
how_str = how.__name__
else:
# Create any standard weights, e.g. latitude.
# TODO: handle kwargs better, currently only lat_key is accepted
if isinstance(weights, str):
weights = standard_weights(dataarray, weights, lat_key=lat_key)
_weights = standard_weights(dataarray, weights, lat_key=lat_key)
else:
_weights = weights
# We ensure the callable is a string
if callable(how):
how = how.__name__
how = weighted_how = how.__name__
else:
weighted_how = how
if how_str is None:
how_str = how

Expand All @@ -561,7 +576,7 @@ def _reduce_dataarray(
comp for comp in [how_str, dataarray.attrs.get("long_name", dataarray.name)] if comp is not None
]
new_long_name = " ".join(new_long_name_components)
new_short_name_components = [comp for comp in [dataarray.name, how_label] if comp is not None]
new_short_name_components = [f"{comp}" for comp in [dataarray.name, how_label] if comp is not None]
new_short_name = "_".join(new_short_name_components)

if isinstance(extra_reduce_dims, str):
Expand Down Expand Up @@ -597,10 +612,10 @@ def _reduce_dataarray(

# If weighted, use xarray weighted arrays which correctly handle missing values etc.
if weights is not None:
this = this.weighted(weights)
reduced_list.append(this.__getattribute__(how)(**reduce_kwargs))
this_weighted = this.weighted(_weights)
reduced_list.append(this_weighted.__getattribute__(weighted_how)(**reduce_kwargs))
else:
reduced = this.reduce(how, **reduce_kwargs).compute()
reduced = this.reduce(reduce_how, **reduce_kwargs).compute()
reduced = reduced.assign_attrs(dataarray.attrs)
reduced_list.append(reduced)

Expand Down Expand Up @@ -646,8 +661,9 @@ def _reduce_dataarray(
# add the reduced data into a new column as a numpy array,
# store the dim information in the attributes

# TODO: fix typing
out_dims = {
dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None
dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None # type: ignore
for dim in reduced_list[0].dims
}
reduce_attrs[f"{new_short_name}"].update({"dims": out_dims})
Expand Down
6 changes: 3 additions & 3 deletions src/earthkit/transforms/aggregate/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def daily_reduce(
# If how is string, fetch function from dictionary:
if isinstance(how, str):
how = tools.get_how(how)
assert isinstance(how, T.Callable), f"how method not recognised: {how}"
assert callable(how), f"how method not recognised: {how}"

red_array = grouped_data.reduce(how, **kwargs)
try:
Expand Down Expand Up @@ -684,7 +684,7 @@ def monthly_reduce(
# If how is string, fetch function from dictionary:
if isinstance(how, str):
how = tools.get_how(how)
assert isinstance(how, T.Callable), f"how method not recognised: {how}"
assert callable(how), f"how method not recognised: {how}"

red_array = grouped_data.reduce(how, **kwargs)
# Remove the year_months coordinate
Expand Down Expand Up @@ -935,5 +935,5 @@ def rolling_reduce(
A dataarray of reduced values with a rolling window applied along the time dimension.
"""
if window_length is not None:
kwargs.update({time_dim: window_length})
kwargs.update({str(time_dim): window_length})
return _rolling_reduce(dataarray, **kwargs)
Loading