From 71959e3860339469e92bf83ce210b0680d1aa863 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 1 Sep 2022 16:18:18 -0400 Subject: [PATCH 1/4] fix passing of curvefit kwargs --- xarray/core/dataarray.py | 4 ++-- xarray/core/dataset.py | 2 +- xarray/tests/test_dataarray.py | 6 +++++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9dfdb6603e6..04f7c7275c2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5066,7 +5066,7 @@ def curvefit( p0: dict[str, Any] | None = None, bounds: dict[str, Any] | None = None, param_names: Sequence[str] | None = None, - kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> Dataset: """ Curve fitting optimization for arbitrary functions. @@ -5132,7 +5132,7 @@ def curvefit( p0=p0, bounds=bounds, param_names=param_names, - kwargs=kwargs, + **kwargs, ) def drop_duplicates( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c3717190df6..e8d04d4a3ae 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8199,7 +8199,7 @@ def curvefit( p0: dict[str, Any] | None = None, bounds: dict[str, Any] | None = None, param_names: Sequence[str] | None = None, - kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> T_Dataset: """ Curve fitting optimization for arbitrary functions. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 298840f3f66..515afc22a20 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4219,7 +4219,11 @@ def exp_decay(t, n0, tau=1): da = da.chunk({"x": 1}) fit = da.curvefit( - coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": [2, 6]} + coords=[da.t], + func=exp_decay, + p0={"n0": 4}, + bounds={"tau": [2, 6]}, + maxfev=1000, ) assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) From d8915fe0b5af08b52212ebda80b2c556821455cb Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 2 Sep 2022 19:32:08 -0400 Subject: [PATCH 2/4] prevent regression for kwargs passed as dict --- xarray/core/dataset.py | 3 ++- xarray/tests/test_dataarray.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e8d04d4a3ae..dff4f9697d3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8269,7 +8269,8 @@ def curvefit( bounds = {} if kwargs is None: kwargs = {} - + elif "kwargs" in kwargs: + kwargs = {**kwargs.pop("kwargs"), **kwargs} if not reduce_dims: reduce_dims_ = [] elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 515afc22a20..72a434863c4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4228,7 +4228,13 @@ def exp_decay(t, n0, tau=1): assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) da = da.compute() - fit = da.curvefit(coords="t", func=np.power, reduce_dims="x", param_names=["a"]) + fit = da.curvefit( + coords="t", + func=np.power, + reduce_dims="x", + param_names=["a"], + kwargs={"maxfev": 1000}, + ) assert "a" in fit.param assert "x" not in fit.dims From 4ca591b59e7e3b5a6f9170ace17cec1a42962fd5 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 23 Sep 2022 09:45:13 -0400 Subject: [PATCH 3/4] update user-guide example to new kwargs method --- doc/user-guide/computation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index dc9748af80b..215fc76d5bf 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -542,7 +542,7 @@ two gaussian peaks: coords=["x", "y"], func=multi_peak, param_names=names, - kwargs={"maxfev": 10000}, + maxfev=10000, ) .. note:: From a2e07b0a1bf2a487e01b717776b2f2ec3fbb6366 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 23 Sep 2022 09:59:01 -0400 Subject: [PATCH 4/4] add whats-new and fix docs typo --- doc/whats-new.rst | 3 +++ xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 10b59ba1c0c..6af6914e5e6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,6 +75,9 @@ Bug fixes - Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler (:issue:`7013`, :pull:`7040`). By `Francesco Nattino `_. +- Allow passing additional kwargs directly to ``scipy`` in + :py:meth:`Dataset.curvefit`. (:issue:`6891`, :pull:`6978`) + By `Sam Levang `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1e5ca4a615f..d7323e7c871 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5109,7 +5109,7 @@ def curvefit( should be manually supplied when fitting a function that takes a variable number of parameters. **kwargs : optional - Additional keyword arguments to passed to scipy curve_fit. + Additional keyword arguments passed to scipy curve_fit. Returns ------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 22847e343bd..9158ebf4e03 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8247,7 +8247,7 @@ def curvefit( should be manually supplied when fitting a function that takes a variable number of parameters. **kwargs : optional - Additional keyword arguments to passed to scipy curve_fit. + Additional keyword arguments passed to scipy curve_fit. Returns -------