Skip to content

Commit e550ae4

Browse files
wjsihekaisheng
andauthored
[BACKPORT] Fix incorrect result for df.sort_values when specifying multiple ascending (#2984) (#3006)
Co-authored-by: He Kaisheng <heks93@163.com>
1 parent 85331e8 commit e550ae4

File tree

8 files changed

+124
-169
lines changed

8 files changed

+124
-169
lines changed

mars/core/operand/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def __init__(self: OperandType, *args, **kwargs):
147147
extra_names = (
148148
set(kwargs) - set(self._FIELDS) - set(SchedulingHint.all_hint_names)
149149
)
150-
extras = AttributeDict((k, kwargs.pop(k)) for k in extra_names)
151-
kwargs["extra_params"] = kwargs.pop("extra_params", extras)
150+
extras = dict((k, kwargs.pop(k)) for k in extra_names)
151+
kwargs["extra_params"] = AttributeDict(kwargs.pop("extra_params", extras))
152152
self._extract_scheduling_hint(kwargs)
153153
super().__init__(*args, **kwargs)
154154

mars/core/operand/tests/test_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class MyOperand5(MyOperand4):
7373

7474

7575
def test_execute():
76+
op = MyOperand(extra_params={"my_extra_params": 1})
77+
assert op.extra_params["my_extra_params"] == 1
7678
MyOperand.register_executor(lambda *_: 2)
7779
assert execute(dict(), MyOperand(_key="1")) == 2
7880
assert execute(dict(), MyOperand2(_key="1")) == 2

mars/dataframe/sort/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
from ...core.operand import OperandStage
2020
from ...serialization.serializables import (
2121
FieldTypes,
22+
AnyField,
23+
BoolField,
2224
Int32Field,
2325
Int64Field,
24-
StringField,
2526
ListField,
26-
BoolField,
27+
StringField,
2728
)
2829
from ...utils import ceildiv
2930
from ..operands import DataFrameOperand
@@ -32,7 +33,7 @@
3233

3334
class DataFrameSortOperand(DataFrameOperand):
3435
_axis = Int32Field("axis")
35-
_ascending = BoolField("ascending")
36+
_ascending = AnyField("ascending")
3637
_inplace = BoolField("inplace")
3738
_kind = StringField("kind")
3839
_na_position = StringField("na_position")

mars/dataframe/sort/psrs.py

Lines changed: 65 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@
2020
from ... import opcodes as OperandDef
2121
from ...core.operand import OperandStage, MapReduceOperand
2222
from ...utils import lazy_import
23-
from ...serialization.serializables import Int32Field, ListField, StringField, BoolField
23+
from ...serialization.serializables import (
24+
AnyField,
25+
Int32Field,
26+
ListField,
27+
StringField,
28+
BoolField,
29+
)
2430
from ...tensor.base.psrs import PSRSOperandMixin
2531
from ..core import IndexValue, OutputType
2632
from ..utils import standardize_range_index, parse_index, is_cudf
@@ -48,6 +54,23 @@ def __gt__(self, other):
4854
_largest = _Largest()
4955

5056

57+
class _ReversedValue:
58+
def __init__(self, value):
59+
self._value = value
60+
61+
def __lt__(self, other):
62+
if type(other) is _ReversedValue:
63+
# may happen when call searchsorted
64+
return self._value >= other._value
65+
return self._value >= other
66+
67+
def __gt__(self, other):
68+
return self._value <= other
69+
70+
def __repr__(self):
71+
return repr(self._value)
72+
73+
5174
class DataFramePSRSOperandMixin(DataFrameOperandMixin, PSRSOperandMixin):
5275
@classmethod
5376
def _collect_op_properties(cls, op):
@@ -377,90 +400,23 @@ def execute_sort_index(data, op, inplace=None):
377400

378401
class DataFramePSRSChunkOperand(DataFrameOperand):
379402
# sort type could be 'sort_values' or 'sort_index'
380-
_sort_type = StringField("sort_type")
403+
sort_type = StringField("sort_type")
381404

382-
_axis = Int32Field("axis")
383-
_by = ListField("by")
384-
_ascending = BoolField("ascending")
385-
_inplace = BoolField("inplace")
386-
_kind = StringField("kind")
387-
_na_position = StringField("na_position")
405+
axis = Int32Field("axis")
406+
by = ListField("by", default=None)
407+
ascending = AnyField("ascending")
408+
inplace = BoolField("inplace")
409+
kind = StringField("kind")
410+
na_position = StringField("na_position")
388411

389412
# for sort_index
390-
_level = ListField("level")
391-
_sort_remaining = BoolField("sort_remaining")
392-
393-
_n_partition = Int32Field("n_partition")
394-
395-
def __init__(
396-
self,
397-
sort_type=None,
398-
by=None,
399-
axis=None,
400-
ascending=None,
401-
inplace=None,
402-
kind=None,
403-
na_position=None,
404-
level=None,
405-
sort_remaining=None,
406-
n_partition=None,
407-
output_types=None,
408-
**kw
409-
):
410-
super().__init__(
411-
_sort_type=sort_type,
412-
_by=by,
413-
_axis=axis,
414-
_ascending=ascending,
415-
_inplace=inplace,
416-
_kind=kind,
417-
_na_position=na_position,
418-
_level=level,
419-
_sort_remaining=sort_remaining,
420-
_n_partition=n_partition,
421-
_output_types=output_types,
422-
**kw
423-
)
413+
level = ListField("level")
414+
sort_remaining = BoolField("sort_remaining")
424415

425-
@property
426-
def sort_type(self):
427-
return self._sort_type
416+
n_partition = Int32Field("n_partition")
428417

429-
@property
430-
def axis(self):
431-
return self._axis
432-
433-
@property
434-
def by(self):
435-
return self._by
436-
437-
@property
438-
def ascending(self):
439-
return self._ascending
440-
441-
@property
442-
def inplace(self):
443-
return self._inplace
444-
445-
@property
446-
def kind(self):
447-
return self._kind
448-
449-
@property
450-
def na_position(self):
451-
return self._na_position
452-
453-
@property
454-
def level(self):
455-
return self._level
456-
457-
@property
458-
def sort_remaining(self):
459-
return self._sort_remaining
460-
461-
@property
462-
def n_partition(self):
463-
return self._n_partition
418+
def __init__(self, output_types=None, **kw):
419+
super().__init__(_output_types=output_types, **kw)
464420

465421

466422
class DataFramePSRSSortRegularSample(DataFramePSRSChunkOperand, DataFrameOperandMixin):
@@ -564,99 +520,49 @@ def execute(cls, ctx, op):
564520
class DataFramePSRSShuffle(MapReduceOperand, DataFrameOperandMixin):
565521
_op_type_ = OperandDef.PSRS_SHUFFLE
566522

567-
_sort_type = StringField("sort_type")
523+
sort_type = StringField("sort_type")
568524

569525
# for shuffle map
570-
_axis = Int32Field("axis")
571-
_by = ListField("by")
572-
_ascending = BoolField("ascending")
573-
_inplace = BoolField("inplace")
574-
_na_position = StringField("na_position")
575-
_n_partition = Int32Field("n_partition")
526+
axis = Int32Field("axis")
527+
by = ListField("by")
528+
ascending = AnyField("ascending")
529+
inplace = BoolField("inplace")
530+
na_position = StringField("na_position")
531+
n_partition = Int32Field("n_partition")
576532

577533
# for sort_index
578-
_level = ListField("level")
579-
_sort_remaining = BoolField("sort_remaining")
534+
level = ListField("level")
535+
sort_remaining = BoolField("sort_remaining")
580536

581537
# for shuffle reduce
582-
_kind = StringField("kind")
583-
584-
def __init__(
585-
self,
586-
sort_type=None,
587-
by=None,
588-
axis=None,
589-
ascending=None,
590-
n_partition=None,
591-
na_position=None,
592-
inplace=None,
593-
kind=None,
594-
level=None,
595-
sort_remaining=None,
596-
output_types=None,
597-
**kw
598-
):
599-
super().__init__(
600-
_sort_type=sort_type,
601-
_by=by,
602-
_axis=axis,
603-
_ascending=ascending,
604-
_n_partition=n_partition,
605-
_na_position=na_position,
606-
_inplace=inplace,
607-
_kind=kind,
608-
_level=level,
609-
_sort_remaining=sort_remaining,
610-
_output_types=output_types,
611-
**kw
612-
)
613-
614-
@property
615-
def sort_type(self):
616-
return self._sort_type
617-
618-
@property
619-
def by(self):
620-
return self._by
621-
622-
@property
623-
def axis(self):
624-
return self._axis
625-
626-
@property
627-
def ascending(self):
628-
return self._ascending
538+
kind = StringField("kind")
629539

630-
@property
631-
def inplace(self):
632-
return self._inplace
633-
634-
@property
635-
def na_position(self):
636-
return self._na_position
637-
638-
@property
639-
def level(self):
640-
return self._level
641-
642-
@property
643-
def sort_remaining(self):
644-
return self._sort_remaining
645-
646-
@property
647-
def n_partition(self):
648-
return self._n_partition
649-
650-
@property
651-
def kind(self):
652-
return self._kind
540+
def __init__(self, output_types=None, **kw):
541+
super().__init__(_output_types=output_types, **kw)
653542

654543
@property
655544
def output_limit(self):
656545
return 1
657546

658547
@staticmethod
659548
def _calc_poses(src_cols, pivots, ascending=True):
549+
if isinstance(ascending, list):
550+
for asc, col in zip(ascending, pivots.columns):
551+
# Make pivots available to use ascending order when mixed order specified
552+
if not asc:
553+
if pd.api.types.is_numeric_dtype(pivots.dtypes[col]):
554+
# for numeric dtypes, convert to negative is more efficient
555+
pivots[col] = -pivots[col]
556+
src_cols[col] = -src_cols[col]
557+
else:
558+
# for other types, convert to ReversedValue
559+
pivots[col] = pivots[col].map(
560+
lambda x: x
561+
if type(x) is _ReversedValue
562+
else _ReversedValue(x)
563+
)
564+
ascending = True
565+
660566
records = src_cols.to_records(index=False)
661567
p_records = pivots.to_records(index=False)
662568
if ascending:

mars/dataframe/sort/sort_values.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ def dataframe_sort_values(
252252
raise NotImplementedError("Only support sort on axis 0")
253253
psrs_kinds = _validate_sort_psrs_kinds(psrs_kinds)
254254
by = by if isinstance(by, (list, tuple)) else [by]
255+
if isinstance(ascending, list): # pragma: no cover
256+
if all(ascending):
257+
# all are True, convert to True
258+
ascending = True
259+
elif not any(ascending):
260+
# all are False, convert to False
261+
ascending = False
255262
op = DataFrameSortValues(
256263
by=by,
257264
axis=axis,

mars/dataframe/sort/tests/test_sort_execution.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
"distinct_opt", ["0"] if sys.platform.lower().startswith("win") else ["0", "1"]
2828
)
2929
def test_sort_values_execution(setup, distinct_opt):
30+
ns = np.random.RandomState(0)
3031
os.environ["PSRS_DISTINCT_COL"] = distinct_opt
31-
df = pd.DataFrame(
32-
np.random.rand(100, 10), columns=["a" + str(i) for i in range(10)]
33-
)
32+
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])
3433

3534
# test one chunk
3635
mdf = DataFrame(df)
@@ -67,6 +66,38 @@ def test_sort_values_execution(setup, distinct_opt):
6766

6867
pd.testing.assert_frame_equal(result, expected)
6968

69+
# test ascending is a list
70+
result = (
71+
mdf.sort_values(["a3", "a4", "a5", "a6"], ascending=[False, True, True, False])
72+
.execute()
73+
.fetch()
74+
)
75+
expected = df.sort_values(
76+
["a3", "a4", "a5", "a6"], ascending=[False, True, True, False]
77+
)
78+
pd.testing.assert_frame_equal(result, expected)
79+
80+
in_df = pd.DataFrame(
81+
{
82+
"col1": ns.choice([f"a{i}" for i in range(5)], size=(100,)),
83+
"col2": ns.choice([f"b{i}" for i in range(5)], size=(100,)),
84+
"col3": ns.choice([f"c{i}" for i in range(5)], size=(100,)),
85+
"col4": ns.randint(10, 20, size=(100,)),
86+
}
87+
)
88+
mdf = DataFrame(in_df, chunk_size=10)
89+
result = (
90+
mdf.sort_values(
91+
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
92+
)
93+
.execute()
94+
.fetch()
95+
)
96+
expected = in_df.sort_values(
97+
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
98+
)
99+
pd.testing.assert_frame_equal(result, expected)
100+
70101
# test multiindex
71102
df2 = df.copy(deep=True)
72103
df2.columns = pd.MultiIndex.from_product([list("AB"), list("CDEFG")])

mars/oscar/backends/mars/pool.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,14 @@ async def kill_sub_pool(
238238
await asyncio.to_thread(process.join, 5)
239239

240240
async def is_sub_pool_alive(self, process: multiprocessing.Process):
241-
return await asyncio.to_thread(process.is_alive)
241+
try:
242+
return await asyncio.to_thread(process.is_alive)
243+
except RuntimeError as ex: # pragma: no cover
244+
if "shutdown" not in str(ex):
245+
# when atexit is triggered, the default pool might be shutdown
246+
# and to_thread will fail
247+
raise
248+
return process.is_alive()
242249

243250
async def recover_sub_pool(self, address: str):
244251
process_index = self._config.get_process_index(address)

0 commit comments

Comments
 (0)