Skip to content

Commit 6af547c

Browse files
authored
Handle .oindex and .vindex for the PandasMultiIndexingAdapter and PandasIndexingAdapter (#8869)
1 parent 7c3d2dd commit 6af547c

File tree

1 file changed

+89
-14
lines changed

1 file changed

+89
-14
lines changed

xarray/core/indexing.py

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,11 +1680,65 @@ def _convert_scalar(self, item):
16801680
# a NumPy array.
16811681
return to_0d_array(item)
16821682

1683-
def _oindex_get(self, indexer: OuterIndexer):
1684-
return self.__getitem__(indexer)
1683+
def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]:
1684+
if isinstance(key, tuple) and len(key) == 1:
1685+
# unpack key so it can index a pandas.Index object (pandas.Index
1686+
# objects don't like tuples)
1687+
(key,) = key
16851688

1686-
def _vindex_get(self, indexer: VectorizedIndexer):
1687-
return self.__getitem__(indexer)
1689+
return key
1690+
1691+
def _handle_result(
1692+
self, result: Any
1693+
) -> (
1694+
PandasIndexingAdapter
1695+
| NumpyIndexingAdapter
1696+
| np.ndarray
1697+
| np.datetime64
1698+
| np.timedelta64
1699+
):
1700+
if isinstance(result, pd.Index):
1701+
return type(self)(result, dtype=self.dtype)
1702+
else:
1703+
return self._convert_scalar(result)
1704+
1705+
def _oindex_get(
1706+
self, indexer: OuterIndexer
1707+
) -> (
1708+
PandasIndexingAdapter
1709+
| NumpyIndexingAdapter
1710+
| np.ndarray
1711+
| np.datetime64
1712+
| np.timedelta64
1713+
):
1714+
key = self._prepare_key(indexer.tuple)
1715+
1716+
if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
1717+
indexable = NumpyIndexingAdapter(np.asarray(self))
1718+
return indexable.oindex[indexer]
1719+
1720+
result = self.array[key]
1721+
1722+
return self._handle_result(result)
1723+
1724+
def _vindex_get(
1725+
self, indexer: VectorizedIndexer
1726+
) -> (
1727+
PandasIndexingAdapter
1728+
| NumpyIndexingAdapter
1729+
| np.ndarray
1730+
| np.datetime64
1731+
| np.timedelta64
1732+
):
1733+
key = self._prepare_key(indexer.tuple)
1734+
1735+
if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
1736+
indexable = NumpyIndexingAdapter(np.asarray(self))
1737+
return indexable.vindex[indexer]
1738+
1739+
result = self.array[key]
1740+
1741+
return self._handle_result(result)
16881742

16891743
def __getitem__(
16901744
self, indexer: ExplicitIndexer
@@ -1695,22 +1749,15 @@ def __getitem__(
16951749
| np.datetime64
16961750
| np.timedelta64
16971751
):
1698-
key = indexer.tuple
1699-
if isinstance(key, tuple) and len(key) == 1:
1700-
# unpack key so it can index a pandas.Index object (pandas.Index
1701-
# objects don't like tuples)
1702-
(key,) = key
1752+
key = self._prepare_key(indexer.tuple)
17031753

17041754
if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
17051755
indexable = NumpyIndexingAdapter(np.asarray(self))
1706-
return apply_indexer(indexable, indexer)
1756+
return indexable[indexer]
17071757

17081758
result = self.array[key]
17091759

1710-
if isinstance(result, pd.Index):
1711-
return type(self)(result, dtype=self.dtype)
1712-
else:
1713-
return self._convert_scalar(result)
1760+
return self._handle_result(result)
17141761

17151762
def transpose(self, order) -> pd.Index:
17161763
return self.array # self.array should be always one-dimensional
@@ -1766,6 +1813,34 @@ def _convert_scalar(self, item):
17661813
item = item[idx]
17671814
return super()._convert_scalar(item)
17681815

1816+
def _oindex_get(
1817+
self, indexer: OuterIndexer
1818+
) -> (
1819+
PandasIndexingAdapter
1820+
| NumpyIndexingAdapter
1821+
| np.ndarray
1822+
| np.datetime64
1823+
| np.timedelta64
1824+
):
1825+
result = super()._oindex_get(indexer)
1826+
if isinstance(result, type(self)):
1827+
result.level = self.level
1828+
return result
1829+
1830+
def _vindex_get(
1831+
self, indexer: VectorizedIndexer
1832+
) -> (
1833+
PandasIndexingAdapter
1834+
| NumpyIndexingAdapter
1835+
| np.ndarray
1836+
| np.datetime64
1837+
| np.timedelta64
1838+
):
1839+
result = super()._vindex_get(indexer)
1840+
if isinstance(result, type(self)):
1841+
result.level = self.level
1842+
return result
1843+
17691844
def __getitem__(self, indexer: ExplicitIndexer):
17701845
result = super().__getitem__(indexer)
17711846
if isinstance(result, type(self)):

0 commit comments

Comments
 (0)