Skip to content

Commit 4c3c22b

Browse files
Add support for coordinate inputs in polyfit. (#9369)
* Update polyfit to work with coordinate inputs. * Test whether polyfit properly handles coordinate inputs. * Document polyfit coordinate fix in whats-new.rst. * Update get_clean_interp_index's use_coordinate parameter to take a hashable type. * Replace call to get_clean_interp_index with inline coversion code in polyfit. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Declare x as Any type in polyfit. * Add polyfit output test. * Use floatize_x to convert coords to floats in polyfit. * Update dataset.py Use "raise from" when dimensions aren't castable to float in polyfit. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f01096f commit 4c3c22b

File tree

4 files changed

+38
-6
lines changed

4 files changed

+38
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ Bug fixes
149149
date "0001-01-01". (:issue:`9108`, :pull:`9116`) By `Spencer Clark
150150
<https://github.com/spencerkclark>`_ and `Deepak Cherian
151151
<https://github.com/dcherian>`_.
152+
- Fix issue where polyfit wouldn't handle non-dimension coordinates. (:issue:`4375`, :pull:`9369`)
153+
By `Karl Krauth <https://github.com/Karl-Krauth>`_.
152154
- Fix issue with passing parameters to ZarrStore.open_store when opening
153155
datatree in zarr format (:issue:`9376`, :pull:`9377`).
154156
By `Alfonso Ladino <https://github.com/aladinor>`_

xarray/core/dataset.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
merge_coordinates_without_align,
8888
merge_core,
8989
)
90-
from xarray.core.missing import get_clean_interp_index
90+
from xarray.core.missing import _floatize_x
9191
from xarray.core.options import OPTIONS, _get_keep_attrs
9292
from xarray.core.types import (
9393
Bins,
@@ -9054,7 +9054,16 @@ def polyfit(
90549054
variables = {}
90559055
skipna_da = skipna
90569056

9057-
x = get_clean_interp_index(self, dim, strict=False)
9057+
x: Any = self.coords[dim].variable
9058+
x = _floatize_x((x,), (x,))[0][0]
9059+
9060+
try:
9061+
x = x.values.astype(np.float64)
9062+
except TypeError as e:
9063+
raise TypeError(
9064+
f"Dim {dim!r} must be castable to float64, got {type(x).__name__}."
9065+
) from e
9066+
90589067
xname = f"{self[dim].name}_"
90599068
order = int(deg) + 1
90609069
lhs = np.vander(x, order)
@@ -9093,8 +9102,11 @@ def polyfit(
90939102
)
90949103
variables[sing.name] = sing
90959104

9105+
# If we have a coordinate get its underlying dimension.
9106+
true_dim = self.coords[dim].dims[0]
9107+
90969108
for name, da in self.data_vars.items():
9097-
if dim not in da.dims:
9109+
if true_dim not in da.dims:
90989110
continue
90999111

91009112
if is_duck_dask_array(da.data) and (
@@ -9106,11 +9118,11 @@ def polyfit(
91069118
elif skipna is None:
91079119
skipna_da = bool(np.any(da.isnull()))
91089120

9109-
dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
9121+
dims_to_stack = [dimname for dimname in da.dims if dimname != true_dim]
91109122
stacked_coords: dict[Hashable, DataArray] = {}
91119123
if dims_to_stack:
91129124
stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
9113-
rhs = da.transpose(dim, *dims_to_stack).stack(
9125+
rhs = da.transpose(true_dim, *dims_to_stack).stack(
91149126
{stacked_dim: dims_to_stack}
91159127
)
91169128
stacked_coords = {stacked_dim: rhs[stacked_dim]}

xarray/core/missing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):
227227

228228

229229
def get_clean_interp_index(
230-
arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True
230+
arr, dim: Hashable, use_coordinate: Hashable | bool = True, strict: bool = True
231231
):
232232
"""Return index to use for x values in interpolation or curve fitting.
233233

xarray/tests/test_dataset.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6694,6 +6694,24 @@ def test_polyfit_weighted(self) -> None:
66946694
ds.polyfit("dim2", 2, w=np.arange(ds.sizes["dim2"]))
66956695
xr.testing.assert_identical(ds, ds_copy)
66966696

6697+
def test_polyfit_coord(self) -> None:
6698+
# Make sure polyfit works when given a non-dimension coordinate.
6699+
ds = create_test_data(seed=1)
6700+
6701+
out = ds.polyfit("numbers", 2, full=False)
6702+
assert "var3_polyfit_coefficients" in out
6703+
assert "dim1" in out
6704+
assert "dim2" not in out
6705+
assert "dim3" not in out
6706+
6707+
def test_polyfit_coord_output(self) -> None:
6708+
da = xr.DataArray(
6709+
[1, 3, 2], dims=["x"], coords=dict(x=["a", "b", "c"], y=("x", [0, 1, 2]))
6710+
)
6711+
out = da.polyfit("y", deg=1)["polyfit_coefficients"]
6712+
assert out.sel(degree=0).item() == pytest.approx(1.5)
6713+
assert out.sel(degree=1).item() == pytest.approx(0.5)
6714+
66976715
def test_polyfit_warnings(self) -> None:
66986716
ds = create_test_data(seed=1)
66996717

0 commit comments

Comments
 (0)