Skip to content

Commit 2175678

Browse files
authored
Fix incorrect MutliIndex.names (#3355)
1 parent c36c53f commit 2175678

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

mars/dataframe/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ def min_max(self):
378378
def name(self):
379379
return getattr(self._index_value, "_name", None)
380380

381+
@property
382+
def names(self):
383+
return getattr(self._index_value, "_names", [self.name])
384+
381385
@property
382386
def inferred_type(self):
383387
return self._index_value.inferred_type

mars/dataframe/datasource/index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __call__(self, shape=None, chunk_size=None, inp=None, name=None, names=None)
6868
elif hasattr(inp, "index_value"):
6969
# get index from Mars DataFrame, Series or Index
7070
name = name if name is not None else inp.index_value.name
71-
names = names if names is not None else [name]
71+
names = names if names is not None else inp.index_value.names
7272
if inp.index_value.has_value():
7373
self.data = data = inp.index_value.to_pandas()
7474
return self.new_index(

mars/dataframe/datasource/tests/test_datasource.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ....core import tile
2828
from ....tests.core import require_ray
2929
from ....utils import lazy_import
30-
from ...core import IndexValue, DatetimeIndex, Int64Index, Float64Index
30+
from ...core import IndexValue, DatetimeIndex, Int64Index, Float64Index, MultiIndex
3131
from ..core import merge_small_files
3232
from ..dataframe import from_pandas as from_pandas_df
3333
from ..date_range import date_range
@@ -156,6 +156,16 @@ def test_from_pandas_dataframe():
156156
assert len([ns for ns in df.nsplits[1] if ns == 0]) == 0
157157

158158

159+
def test_from_pandas_dataframe_with_multi_index():
160+
index = pd.MultiIndex.from_tuples([("k1", "v1")], names=["X", "Y"])
161+
data = np.random.randint(0, 100, size=(1, 3))
162+
pdf = pd.DataFrame(data, columns=["A", "B", "C"], index=index)
163+
df = from_pandas_df(pdf, chunk_size=4)
164+
assert isinstance(df.index, MultiIndex)
165+
assert df.index.names == ["X", "Y"]
166+
assert df.index.name is None
167+
168+
159169
def test_from_pandas_series():
160170
data = pd.Series(np.random.rand(10), name="a")
161171
series = from_pandas_series(data, chunk_size=4)

0 commit comments

Comments
 (0)