|
20 | 20 | from ... import opcodes as OperandDef
|
21 | 21 | from ...core.operand import OperandStage, MapReduceOperand
|
22 | 22 | from ...utils import lazy_import
|
23 |
| -from ...serialization.serializables import Int32Field, ListField, StringField, BoolField |
| 23 | +from ...serialization.serializables import ( |
| 24 | + AnyField, |
| 25 | + Int32Field, |
| 26 | + ListField, |
| 27 | + StringField, |
| 28 | + BoolField, |
| 29 | +) |
24 | 30 | from ...tensor.base.psrs import PSRSOperandMixin
|
25 | 31 | from ..core import IndexValue, OutputType
|
26 | 32 | from ..utils import standardize_range_index, parse_index, is_cudf
|
@@ -48,6 +54,23 @@ def __gt__(self, other):
|
48 | 54 | _largest = _Largest()
|
49 | 55 |
|
50 | 56 |
|
| 57 | +class _ReversedValue: |
| 58 | + def __init__(self, value): |
| 59 | + self._value = value |
| 60 | + |
| 61 | + def __lt__(self, other): |
| 62 | + if type(other) is _ReversedValue: |
| 63 | + # may happen when call searchsorted |
| 64 | + return self._value >= other._value |
| 65 | + return self._value >= other |
| 66 | + |
| 67 | + def __gt__(self, other): |
| 68 | + return self._value <= other |
| 69 | + |
| 70 | + def __repr__(self): |
| 71 | + return repr(self._value) |
| 72 | + |
| 73 | + |
51 | 74 | class DataFramePSRSOperandMixin(DataFrameOperandMixin, PSRSOperandMixin):
|
52 | 75 | @classmethod
|
53 | 76 | def _collect_op_properties(cls, op):
|
@@ -377,90 +400,23 @@ def execute_sort_index(data, op, inplace=None):
|
377 | 400 |
|
378 | 401 | class DataFramePSRSChunkOperand(DataFrameOperand):
|
379 | 402 | # sort type could be 'sort_values' or 'sort_index'
|
380 |
| - _sort_type = StringField("sort_type") |
| 403 | + sort_type = StringField("sort_type") |
381 | 404 |
|
382 |
| - _axis = Int32Field("axis") |
383 |
| - _by = ListField("by") |
384 |
| - _ascending = BoolField("ascending") |
385 |
| - _inplace = BoolField("inplace") |
386 |
| - _kind = StringField("kind") |
387 |
| - _na_position = StringField("na_position") |
| 405 | + axis = Int32Field("axis") |
| 406 | + by = ListField("by", default=None) |
| 407 | + ascending = AnyField("ascending") |
| 408 | + inplace = BoolField("inplace") |
| 409 | + kind = StringField("kind") |
| 410 | + na_position = StringField("na_position") |
388 | 411 |
|
389 | 412 | # for sort_index
|
390 |
| - _level = ListField("level") |
391 |
| - _sort_remaining = BoolField("sort_remaining") |
392 |
| - |
393 |
| - _n_partition = Int32Field("n_partition") |
394 |
| - |
395 |
| - def __init__( |
396 |
| - self, |
397 |
| - sort_type=None, |
398 |
| - by=None, |
399 |
| - axis=None, |
400 |
| - ascending=None, |
401 |
| - inplace=None, |
402 |
| - kind=None, |
403 |
| - na_position=None, |
404 |
| - level=None, |
405 |
| - sort_remaining=None, |
406 |
| - n_partition=None, |
407 |
| - output_types=None, |
408 |
| - **kw |
409 |
| - ): |
410 |
| - super().__init__( |
411 |
| - _sort_type=sort_type, |
412 |
| - _by=by, |
413 |
| - _axis=axis, |
414 |
| - _ascending=ascending, |
415 |
| - _inplace=inplace, |
416 |
| - _kind=kind, |
417 |
| - _na_position=na_position, |
418 |
| - _level=level, |
419 |
| - _sort_remaining=sort_remaining, |
420 |
| - _n_partition=n_partition, |
421 |
| - _output_types=output_types, |
422 |
| - **kw |
423 |
| - ) |
| 413 | + level = ListField("level") |
| 414 | + sort_remaining = BoolField("sort_remaining") |
424 | 415 |
|
425 |
| - @property |
426 |
| - def sort_type(self): |
427 |
| - return self._sort_type |
| 416 | + n_partition = Int32Field("n_partition") |
428 | 417 |
|
429 |
| - @property |
430 |
| - def axis(self): |
431 |
| - return self._axis |
432 |
| - |
433 |
| - @property |
434 |
| - def by(self): |
435 |
| - return self._by |
436 |
| - |
437 |
| - @property |
438 |
| - def ascending(self): |
439 |
| - return self._ascending |
440 |
| - |
441 |
| - @property |
442 |
| - def inplace(self): |
443 |
| - return self._inplace |
444 |
| - |
445 |
| - @property |
446 |
| - def kind(self): |
447 |
| - return self._kind |
448 |
| - |
449 |
| - @property |
450 |
| - def na_position(self): |
451 |
| - return self._na_position |
452 |
| - |
453 |
| - @property |
454 |
| - def level(self): |
455 |
| - return self._level |
456 |
| - |
457 |
| - @property |
458 |
| - def sort_remaining(self): |
459 |
| - return self._sort_remaining |
460 |
| - |
461 |
| - @property |
462 |
| - def n_partition(self): |
463 |
| - return self._n_partition |
| 418 | + def __init__(self, output_types=None, **kw): |
| 419 | + super().__init__(_output_types=output_types, **kw) |
464 | 420 |
|
465 | 421 |
|
466 | 422 | class DataFramePSRSSortRegularSample(DataFramePSRSChunkOperand, DataFrameOperandMixin):
|
@@ -564,99 +520,49 @@ def execute(cls, ctx, op):
|
564 | 520 | class DataFramePSRSShuffle(MapReduceOperand, DataFrameOperandMixin):
|
565 | 521 | _op_type_ = OperandDef.PSRS_SHUFFLE
|
566 | 522 |
|
567 |
| - _sort_type = StringField("sort_type") |
| 523 | + sort_type = StringField("sort_type") |
568 | 524 |
|
569 | 525 | # for shuffle map
|
570 |
| - _axis = Int32Field("axis") |
571 |
| - _by = ListField("by") |
572 |
| - _ascending = BoolField("ascending") |
573 |
| - _inplace = BoolField("inplace") |
574 |
| - _na_position = StringField("na_position") |
575 |
| - _n_partition = Int32Field("n_partition") |
| 526 | + axis = Int32Field("axis") |
| 527 | + by = ListField("by") |
| 528 | + ascending = AnyField("ascending") |
| 529 | + inplace = BoolField("inplace") |
| 530 | + na_position = StringField("na_position") |
| 531 | + n_partition = Int32Field("n_partition") |
576 | 532 |
|
577 | 533 | # for sort_index
|
578 |
| - _level = ListField("level") |
579 |
| - _sort_remaining = BoolField("sort_remaining") |
| 534 | + level = ListField("level") |
| 535 | + sort_remaining = BoolField("sort_remaining") |
580 | 536 |
|
581 | 537 | # for shuffle reduce
|
582 |
| - _kind = StringField("kind") |
583 |
| - |
584 |
| - def __init__( |
585 |
| - self, |
586 |
| - sort_type=None, |
587 |
| - by=None, |
588 |
| - axis=None, |
589 |
| - ascending=None, |
590 |
| - n_partition=None, |
591 |
| - na_position=None, |
592 |
| - inplace=None, |
593 |
| - kind=None, |
594 |
| - level=None, |
595 |
| - sort_remaining=None, |
596 |
| - output_types=None, |
597 |
| - **kw |
598 |
| - ): |
599 |
| - super().__init__( |
600 |
| - _sort_type=sort_type, |
601 |
| - _by=by, |
602 |
| - _axis=axis, |
603 |
| - _ascending=ascending, |
604 |
| - _n_partition=n_partition, |
605 |
| - _na_position=na_position, |
606 |
| - _inplace=inplace, |
607 |
| - _kind=kind, |
608 |
| - _level=level, |
609 |
| - _sort_remaining=sort_remaining, |
610 |
| - _output_types=output_types, |
611 |
| - **kw |
612 |
| - ) |
613 |
| - |
614 |
| - @property |
615 |
| - def sort_type(self): |
616 |
| - return self._sort_type |
617 |
| - |
618 |
| - @property |
619 |
| - def by(self): |
620 |
| - return self._by |
621 |
| - |
622 |
| - @property |
623 |
| - def axis(self): |
624 |
| - return self._axis |
625 |
| - |
626 |
| - @property |
627 |
| - def ascending(self): |
628 |
| - return self._ascending |
| 538 | + kind = StringField("kind") |
629 | 539 |
|
630 |
| - @property |
631 |
| - def inplace(self): |
632 |
| - return self._inplace |
633 |
| - |
634 |
| - @property |
635 |
| - def na_position(self): |
636 |
| - return self._na_position |
637 |
| - |
638 |
| - @property |
639 |
| - def level(self): |
640 |
| - return self._level |
641 |
| - |
642 |
| - @property |
643 |
| - def sort_remaining(self): |
644 |
| - return self._sort_remaining |
645 |
| - |
646 |
| - @property |
647 |
| - def n_partition(self): |
648 |
| - return self._n_partition |
649 |
| - |
650 |
| - @property |
651 |
| - def kind(self): |
652 |
| - return self._kind |
| 540 | + def __init__(self, output_types=None, **kw): |
| 541 | + super().__init__(_output_types=output_types, **kw) |
653 | 542 |
|
654 | 543 | @property
|
655 | 544 | def output_limit(self):
|
656 | 545 | return 1
|
657 | 546 |
|
658 | 547 | @staticmethod
|
659 | 548 | def _calc_poses(src_cols, pivots, ascending=True):
|
| 549 | + if isinstance(ascending, list): |
| 550 | + for asc, col in zip(ascending, pivots.columns): |
| 551 | + # Make pivots available to use ascending order when mixed order specified |
| 552 | + if not asc: |
| 553 | + if pd.api.types.is_numeric_dtype(pivots.dtypes[col]): |
| 554 | + # for numeric dtypes, convert to negative is more efficient |
| 555 | + pivots[col] = -pivots[col] |
| 556 | + src_cols[col] = -src_cols[col] |
| 557 | + else: |
| 558 | + # for other types, convert to ReversedValue |
| 559 | + pivots[col] = pivots[col].map( |
| 560 | + lambda x: x |
| 561 | + if type(x) is _ReversedValue |
| 562 | + else _ReversedValue(x) |
| 563 | + ) |
| 564 | + ascending = True |
| 565 | + |
660 | 566 | records = src_cols.to_records(index=False)
|
661 | 567 | p_records = pivots.to_records(index=False)
|
662 | 568 | if ascending:
|
|
0 commit comments