Skip to content

Commit b491595

Browse files
Merge pull request #352 from PyPSA/groupby-df-with-name-column
groupby: fix pandas dataframe with column `name` as grouper
2 parents 705194a + 9e26384 commit b491595

File tree

3 files changed

+59
-14
lines changed

3 files changed

+59
-14
lines changed

doc/release_notes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
* The group dimension when grouping by a pandas dataframe is now always `group`. This fixes the case that the dataframe contains a column named `name`.
8+
79
Version 0.3.14
810
--------------
911

linopy/expressions.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,17 @@ def groupby(self) -> xarray.core.groupby.DatasetGroupBy:
148148
xarray.core.groupby.DataArrayGroupBy
149149
The groupby object.
150150
"""
151-
if isinstance(self.group, (pd.Series, pd.DataFrame)):
151+
if isinstance(self.group, pd.DataFrame):
152152
raise ValueError(
153-
"Grouping by pandas objects is only supported in sum function."
153+
"Grouping by a DataFrame only supported for `sum` operation with `use_fallback=False`."
154154
)
155+
if isinstance(self.group, pd.Series):
156+
group_name = self.group.name or "group"
157+
group = DataArray(self.group, name=group_name)
158+
else:
159+
group = self.group # type: ignore
155160

156-
return self.data.groupby(group=self.group, **self.kwargs)
161+
return self.data.groupby(group=group, **self.kwargs)
157162

158163
def map(
159164
self, func: Callable, shortcut: bool = False, args: tuple[()] = (), **kwargs
@@ -210,7 +215,11 @@ def sum(self, use_fallback: bool = False, **kwargs) -> LinearExpression:
210215
non_fallback_types = (pd.Series, pd.DataFrame, xr.DataArray)
211216
if isinstance(self.group, non_fallback_types) and not use_fallback:
212217
group: pd.Series | pd.DataFrame | xr.DataArray = self.group
213-
group_name = getattr(group, "name", "group") or "group"
218+
if isinstance(group, pd.DataFrame):
219+
# dataframes do not have a name, so we need to set it
220+
group_name = "group"
221+
else:
222+
group_name = getattr(group, "name", "group") or "group"
214223

215224
if isinstance(group, DataArray):
216225
group = group.to_pandas()
@@ -224,7 +233,9 @@ def sum(self, use_fallback: bool = False, **kwargs) -> LinearExpression:
224233

225234
group_dim = group.index.name
226235
if group_name == group_dim:
227-
raise ValueError("Group name cannot be the same as group dimension")
236+
raise ValueError(
237+
"Group name cannot be the same as group dimension in non-fallback mode."
238+
)
228239

229240
arrays = [group, group.groupby(group).cumcount()]
230241
idx = pd.MultiIndex.from_arrays(

test/test_linear_expression.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -690,42 +690,74 @@ def test_linear_expression_groupby_with_name(v, use_fallback):
690690
assert grouped.nterm == 10
691691

692692

693-
def test_linear_expression_groupby_with_series(v):
693+
@pytest.mark.parametrize("use_fallback", [True, False])
694+
def test_linear_expression_groupby_with_series(v, use_fallback):
694695
expr = 1 * v
695696
groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"])
696-
grouped = expr.groupby(groups).sum()
697+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
697698
assert "group" in grouped.dims
698699
assert (grouped.data.group == [1, 2]).all()
699700
assert grouped.nterm == 10
700701

701702

702-
def test_linear_expression_groupby_with_series_false(v):
703+
@pytest.mark.parametrize("use_fallback", [True, False])
704+
def test_linear_expression_groupby_series_with_name(v, use_fallback):
705+
expr = 1 * v
706+
groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes[v.dims[0]], name="my_group")
707+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
708+
assert "my_group" in grouped.dims
709+
assert (grouped.data.my_group == [1, 2]).all()
710+
assert grouped.nterm == 10
711+
712+
713+
@pytest.mark.parametrize("use_fallback", [True, False])
714+
def test_linear_expression_groupby_with_series_false(v, use_fallback):
703715
expr = 1 * v
704716
groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"])
705717
groups.name = "dim_2"
706-
with pytest.raises(ValueError):
707-
expr.groupby(groups).sum()
718+
if not use_fallback:
719+
with pytest.raises(ValueError):
720+
expr.groupby(groups).sum(use_fallback=use_fallback)
721+
return
722+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
723+
assert "dim_2" in grouped.dims
724+
assert (grouped.data.dim_2 == [1, 2]).all()
725+
assert grouped.nterm == 10
708726

709727

710-
def test_linear_expression_groupby_with_dataframe(v):
728+
@pytest.mark.parametrize("use_fallback", [True, False])
729+
def test_linear_expression_groupby_with_dataframe(v, use_fallback):
711730
expr = 1 * v
712731
groups = pd.DataFrame(
713732
{"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"]
714733
)
715-
grouped = expr.groupby(groups).sum()
734+
if use_fallback:
735+
with pytest.raises(ValueError):
736+
expr.groupby(groups).sum(use_fallback=use_fallback)
737+
return
738+
739+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
716740
index = pd.MultiIndex.from_frame(groups)
717741
assert "group" in grouped.dims
718742
assert set(grouped.data.group.values) == set(index.values)
719743
assert grouped.nterm == 3
720744

721745

722-
def test_linear_expression_groupby_with_dataarray(v):
746+
@pytest.mark.parametrize("use_fallback", [True, False])
747+
def test_linear_expression_groupby_with_dataarray(v, use_fallback):
723748
expr = 1 * v
724749
df = pd.DataFrame(
725750
{"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"]
726751
)
727752
groups = xr.DataArray(df)
728-
grouped = expr.groupby(groups).sum()
753+
754+
# this should not be the case, see https://github.com/PyPSA/linopy/issues/351
755+
if use_fallback:
756+
with pytest.raises(KeyError):
757+
expr.groupby(groups).sum(use_fallback=use_fallback)
758+
return
759+
760+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
729761
index = pd.MultiIndex.from_frame(df)
730762
assert "group" in grouped.dims
731763
assert set(grouped.data.group.values) == set(index.values)

0 commit comments

Comments
 (0)