Skip to content

Commit 1df00df

Browse files
qinxuyewjsi
authored and committed
Experimental add support for DataFrame (#351)
* add some properties like is_unique to index * add should_be_monotonic property to all index types in DataFrame * add columns_values to dataframe and its chunk * add max_val_close and min_val_close to index_value * add split_monotonic_index_min_max * add arithmetic to dataframe, not finish yet * add tiling for arithmetic
1 parent 20e59a7 commit 1df00df

28 files changed

+2395
-138
lines changed

mars/_utils.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,12 @@ cdef h_numpy(ob):
157157

158158

159159
cdef h_pandas_index(ob):
    # Hash a pandas Index for tokenization.
    if isinstance(ob, pd.RangeIndex):
        # for range index, there is no need to get the values:
        # (start, stop, step) fully determines the index, so hashing a
        # slice avoids materializing a potentially huge array.
        # NOTE(review): ob._start/_stop/_step are private pandas attributes
        # (public .start/.stop/.step in newer pandas) — confirm the pandas
        # version pin still exposes them.
        return h_iterative([ob.name, getattr(ob, 'names', None),
                            slice(ob._start, ob._stop, ob._step)])
    else:
        return h_iterative([ob.name, getattr(ob, 'names', None), ob.values])
161166

162167

163168
cdef h_pandas_series(ob):

mars/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,11 @@ def to_coarse(self):
394394
new_entity.params.update({'raw_chunk_size': self.nsplits})
395395
return new_entity
396396

397+
def is_sparse(self):
    # Whether the underlying data is sparse is determined by the operand
    # that produces this tileable, so delegate to it.
    return self.op.is_sparse()

issparse = is_sparse  # alias, matching the scipy-style spelling
401+
397402
def tiles(self):
398403
return handler.tiles(self)
399404

mars/dataframe/core.py

Lines changed: 106 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,75 @@
1616

1717
from ..core import ChunkData, Chunk, Entity, TilesableData
1818
from ..serialize import Serializable, ValueType, ProviderType, DataTypeField, AnyField, SeriesField, \
19-
BoolField, Int64Field, Int32Field, ListField, SliceField, OneOfField, ReferenceField
19+
BoolField, Int64Field, Int32Field, StringField, ListField, SliceField, OneOfField, ReferenceField
2020

2121

2222
class IndexValue(Serializable):
2323
__slots__ = ()
2424

25-
class Index(Serializable):
25+
class IndexBase(Serializable):
    """Serializable metadata common to every index type.

    Records lightweight facts about an index (monotonicity, uniqueness,
    min/max bounds and whether those bounds are inclusive) without
    holding the index values themselves.
    """

    _key = StringField('key')  # to identify if the index is the same
    _is_monotonic_increasing = BoolField('is_monotonic_increasing')
    _is_monotonic_decreasing = BoolField('is_monotonic_decreasing')
    _is_unique = BoolField('is_unique')
    # NOTE(review): presumably a hint that downstream operations expect a
    # monotonic index — confirm against the tiling code that sets it.
    _should_be_monotonic = BoolField('should_be_monotonic')
    _max_val = AnyField('max_val')
    # whether the max bound is inclusive (closed)
    _max_val_close = BoolField('max_val_close')
    _min_val = AnyField('min_val')
    # whether the min bound is inclusive (closed)
    _min_val_close = BoolField('min_val_close')

    @property
    def is_monotonic_increasing(self):
        return self._is_monotonic_increasing

    @property
    def is_monotonic_decreasing(self):
        return self._is_monotonic_decreasing

    @property
    def is_unique(self):
        return self._is_unique

    @property
    def should_be_monotonic(self):
        return self._should_be_monotonic

    @property
    def min_val(self):
        return self._min_val

    @property
    def min_val_close(self):
        return self._min_val_close

    @property
    def max_val(self):
        return self._max_val

    @property
    def max_val_close(self):
        return self._max_val_close
67+
68+
class Index(IndexBase):
2669
_name = AnyField('name')
2770
_data = ListField('data')
2871
_dtype = DataTypeField('dtype')
2972

30-
class RangeIndex(Serializable):
73+
class RangeIndex(IndexBase):
3174
_name = AnyField('name')
3275
_slice = SliceField('slice')
3376

34-
class CategoricalIndex(Serializable):
77+
class CategoricalIndex(IndexBase):
3578
_name = AnyField('name')
3679
_categories = ListField('categories')
3780
_ordered = BoolField('ordered')
3881

39-
class IntervalIndex(Serializable):
82+
class IntervalIndex(IndexBase):
4083
_name = AnyField('name')
4184
_data = ListField('data')
4285
_closed = BoolField('closed')
4386

44-
class DatetimeIndex(Serializable):
87+
class DatetimeIndex(IndexBase):
4588
_name = AnyField('name')
4689
_data = ListField('data')
4790
_freq = AnyField('freq')
@@ -53,7 +96,7 @@ class DatetimeIndex(Serializable):
5396
_dayfirst = BoolField('dayfirst')
5497
_yearfirst = BoolField('yearfirst')
5598

56-
class TimedeltaIndex(Serializable):
99+
class TimedeltaIndex(IndexBase):
57100
_name = AnyField('name')
58101
_data = ListField('data')
59102
_unit = AnyField('unit')
@@ -63,7 +106,7 @@ class TimedeltaIndex(Serializable):
63106
_end = AnyField('end')
64107
_closed = AnyField('closed')
65108

66-
class PeriodIndex(Serializable):
109+
class PeriodIndex(IndexBase):
67110
_name = AnyField('name')
68111
_data = ListField('data')
69112
_freq = AnyField('freq')
@@ -80,25 +123,24 @@ class PeriodIndex(Serializable):
80123
_tz = AnyField('tz')
81124
_dtype = DataTypeField('dtype')
82125

83-
class Int64Index(Serializable):
126+
class Int64Index(IndexBase):
84127
_name = AnyField('name')
85128
_data = ListField('data')
86129
_dtype = DataTypeField('dtype')
87130

88-
class UInt64Index(Serializable):
131+
class UInt64Index(IndexBase):
89132
_name = AnyField('name')
90133
_data = ListField('data')
91134
_dtype = DataTypeField('dtype')
92135

93-
class Float64Index(Serializable):
136+
class Float64Index(IndexBase):
94137
_name = AnyField('name')
95138
_data = ListField('data')
96139
_dtype = DataTypeField('dtype')
97140

98-
class MultiIndex(Serializable):
141+
class MultiIndex(IndexBase):
99142
_names = ListField('name')
100-
_levels = ListField('levels')
101-
_labels = ListField('labels')
143+
_data = ListField('data')
102144
_sortorder = Int32Field('sortorder')
103145

104146
_index_value = OneOfField('index_value', index=Index,
@@ -113,6 +155,42 @@ def __mars_tokenize__(self):
113155
v = self._index_value
114156
return [type(v).__name__] + [getattr(v, f, None) for f in v.__slots__]
115157

158+
@property
159+
def value(self):
160+
return self._index_value
161+
162+
@property
163+
def is_monotonic_increasing(self):
164+
return self._index_value.is_monotonic_increasing
165+
166+
@property
167+
def is_monotonic_decreasing(self):
168+
return self._index_value.is_monotonic_decreasing
169+
170+
@property
171+
def is_monotonic_increasing_or_decreasing(self):
172+
return self.is_monotonic_increasing or self.is_monotonic_decreasing
173+
174+
@property
175+
def is_unique(self):
176+
return self._index_value.is_unique
177+
178+
@property
179+
def min_val(self):
180+
return self._index_value.min_val
181+
182+
@property
183+
def min_val_close(self):
184+
return self._index_value.min_val_close
185+
186+
@property
187+
def max_val(self):
188+
return self._index_value.max_val
189+
190+
@property
191+
def max_val_close(self):
192+
return self._index_value.max_val_close
193+
116194

117195
class IndexChunkData(ChunkData):
118196
__slots__ = ()
@@ -224,18 +302,22 @@ class DataFrameChunkData(ChunkData):
224302
# optional field
225303
_dtypes = SeriesField('dtypes')
226304
_index_value = ReferenceField('index_value', IndexValue)
305+
_columns_value = ReferenceField('columns_value', IndexValue)
227306

228307
@property
229308
def dtypes(self):
230-
return getattr(self, '_dtypes', None) or getattr(self.op, 'dtypes', None)
309+
dt = getattr(self, '_dtypes', None)
310+
if dt is not None:
311+
return dt
312+
return getattr(self.op, 'dtypes', None)
231313

232314
@property
233315
def index_value(self):
234316
return self._index_value
235317

236318
@property
237319
def columns(self):
238-
return self._columns
320+
return self._columns_value
239321

240322

241323
class DataFrameChunk(Chunk):
@@ -249,27 +331,33 @@ class DataFrameData(TilesableData):
249331
# optional field
250332
_dtypes = SeriesField('dtypes')
251333
_index_value = ReferenceField('index_value', IndexValue)
334+
_columns_value = ReferenceField('columns_value', IndexValue)
252335
_chunks = ListField('chunks', ValueType.reference(DataFrameChunkData),
253336
on_serialize=lambda x: [it.data for it in x] if x is not None else x,
254337
on_deserialize=lambda x: [DataFrameChunk(it) for it in x] if x is not None else x)
255338

256339
@property
257340
def dtypes(self):
258-
return getattr(self, '_dtypes', None) or getattr(self.op, 'dtypes', None)
341+
dt = getattr(self, '_dtypes', None)
342+
if dt is not None:
343+
return dt
344+
return getattr(self.op, 'dtypes', None)
259345

260346
@property
261347
def index_value(self):
262348
return self._index_value
263349

264350
@property
265351
def columns(self):
266-
return self._columns
352+
return self._columns_value
267353

268354

269355
class DataFrame(Entity):
270356
__slots__ = ()
271357
_allow_data_type_ = (DataFrameData,)
272358

273359

360+
INDEX_TYPE = (Index, IndexData)
361+
SERIES_TYPE = (Series, SeriesData)
274362
DATAFRAME_TYPE = (DataFrame, DataFrameData)
275363
CHUNK_TYPE = (DataFrameChunk, DataFrameChunkData)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 1999-2018 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import operator
16+
import itertools
17+
import hashlib
18+
import functools
19+
20+
try:
21+
import pandas as pd
22+
except ImportError: # pragma: no cover
23+
pass
24+
25+
from ..expressions.arithmetic.core import DataFrameIndexAlignMap, DataFrameIndexAlignReduce
26+
from ..expressions.arithmetic import DataFrameAdd
27+
28+
29+
def _hash(index, size):
30+
def func(x, size):
31+
return int(hashlib.md5(bytes(x)).hexdigest(), 16) % size
32+
33+
f = functools.partial(func, size=size)
34+
grouped = sorted(index.groupby(index.map(f)).items(),
35+
key=operator.itemgetter(0))
36+
return [g[1] for g in grouped]
37+
38+
39+
def _index_align_map(ctx, chunk):
40+
# TODO(QIN): add GPU support here
41+
df = ctx[chunk.inputs[0].key]
42+
43+
filters = [[], []]
44+
45+
if chunk.op.index_shuffle_size is None:
46+
# no shuffle on index
47+
op = operator.ge if chunk.op.index_min_close else operator.gt
48+
index_cond = op(df.index, chunk.op.index_min)
49+
op = operator.le if chunk.op.index_max_close else operator.lt
50+
index_cond = index_cond & op(df.index, chunk.op.index_max)
51+
filters[0].append(index_cond)
52+
else:
53+
# shuffle on index
54+
shuffle_size = chunk.op.index_shuffle_size
55+
filters[0].extend(_hash(df.index, shuffle_size))
56+
57+
if chunk.op.column_shuffle_size is None:
58+
# no shuffle on columns
59+
op = operator.ge if chunk.op.column_min_close else operator.gt
60+
columns_cond = op(df.columns, chunk.op.column_min)
61+
op = operator.le if chunk.op.column_max_close else operator.ge
62+
columns_cond = columns_cond & op(df.columns, chunk.op.column_max)
63+
filters[1].append(columns_cond)
64+
else:
65+
# shuffle on columns
66+
shuffle_size = chunk.op.column_shuffle_size
67+
filters[1].extend(_hash(df.columns, shuffle_size))
68+
69+
if all(len(it) == 1 for it in filters):
70+
# no shuffle
71+
ctx[chunk.key] = df.loc[filters[0][0], filters[1][0]]
72+
return
73+
elif len(filters[0]) == 1:
74+
# shuffle on columns
75+
for column_idx, column_filter in enumerate(filters[1]):
76+
group_key = ','.join([str(chunk.index[0]), str(column_idx)])
77+
ctx[(chunk.key, group_key)] = df.loc[filters[0][0], column_filter]
78+
elif len(filters[1]) == 1:
79+
# shuffle on index
80+
for index_idx, index_filter in enumerate(filters[0]):
81+
group_key = ','.join([str(index_idx), str(chunk.index[1])])
82+
ctx[(chunk.key, group_key)] = df.loc[index_filter, filters[1][0]]
83+
else:
84+
# full shuffle
85+
shuffle_index_size = chunk.op.index_shuffle_size
86+
shuffle_column_size = chunk.op.column_shuffle_size
87+
out_idxes = itertools.product(range(shuffle_index_size), range(shuffle_column_size))
88+
out_index_columns = itertools.product(*filters)
89+
for out_idx, out_index_column in zip(out_idxes, out_index_columns):
90+
index_filter, column_filter = out_index_column
91+
group_key = ','.join(str(i) for i in out_idx)
92+
ctx[(chunk.key, group_key)] = df.loc[index_filter, column_filter]
93+
94+
95+
def _index_align_reduce(ctx, chunk):
96+
input_idx_to_df = {inp.index: ctx[inp.key, ','.join(str(idx) for idx in chunk.index)]
97+
for inp in chunk.inputs[0].inputs}
98+
row_idxes = sorted({idx[0] for idx in input_idx_to_df})
99+
col_idxes = sorted({idx[1] for idx in input_idx_to_df})
100+
101+
res = None
102+
for row_idx in row_idxes:
103+
row_df = None
104+
for col_idx in col_idxes:
105+
df = input_idx_to_df[row_idx, col_idx]
106+
if row_df is None:
107+
row_df = df
108+
else:
109+
row_df = pd.concat([row_df, df], axis=1)
110+
111+
if res is None:
112+
res = row_df
113+
else:
114+
res = pd.concat([res, row_df], axis=0)
115+
116+
ctx[chunk.key] = res
117+
118+
119+
def _add(ctx, chunk):
120+
left, right = ctx[chunk.inputs[0].key], ctx[chunk.inputs[1].key]
121+
ctx[chunk.key] = left.add(right, axis=chunk.op.axis,
122+
level=chunk.op.level, fill_value=chunk.op.fill_value)
123+
124+
125+
def register_arithmetic_handler():
    """Hook the chunk execution handlers above into the Mars executor."""
    # NOTE(review): local import presumably avoids a circular dependency
    # between this module and the executor — confirm.
    from ...executor import register

    register(DataFrameIndexAlignMap, _index_align_map)
    register(DataFrameIndexAlignReduce, _index_align_reduce)
    register(DataFrameAdd, _add)

0 commit comments

Comments
 (0)