
Commit 3d26408

alexeykudinkin authored and zhaoch23 committed
[Data] Fixing FuseOperators rule to properly handle the case of transformations drastically changing the size of the dataset (ray-project#52570)
These changes are needed to make sure `FuseOperators` appropriately handles the potential impact of transformations on dataset sizes when deciding whether fusion should occur. For example, consider the following scenarios:

```
ds.filter(...).map_batches(..., batch_size=1024)
```

This cannot be fused, as fusing it could violate batching semantics: the fused operator would first collect 1024 rows, then apply the filter and the subsequent transformation (which might expect exactly 1024 rows to be provided in a batch).

```
read_parquet(...).map_batches(..., batch_size=1024)
```

This also cannot be fused, as fusing these two operations could drastically reduce the parallelism of the read operation: the fused operator would first batch 1024 rows, then apply the combined read->map transformation.

This change:

1. Cleans up the `can_modify_num_rows` method
2. Makes sure `Read` overrides `can_modify_num_rows` as well
3. Avoids fusion with ops that could drastically modify the dataset size
4. Cleans up the `FuseOperators` rule
5. Adds telemetry

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Signed-off-by: zhaoch23 <c233zhao@uwaterloo.ca>
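To make the second scenario concrete, here is a minimal sketch against the public Ray Data API (the file path and the identity transformation are placeholders, not part of this commit):

```python
import ray

# Hypothetical pipeline matching the second scenario above. Without this fix,
# fusing the read with the batched map would force the fused operator to
# accumulate 1024 rows before running the combined read->map transformation,
# capping the read's parallelism.
ds = ray.data.read_parquet("s3://my-bucket/data.parquet")  # placeholder path
ds = ds.map_batches(lambda batch: batch, batch_size=1024)  # identity placeholder
```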
1 parent ced3a8b commit 3d26408

File tree

9 files changed: +282 additions, -75 deletions

python/ray/data/_internal/execution/operators/map_operator.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -564,6 +564,10 @@ def __init__(self, min_rows_per_bundle: Optional[int]):
                 bundle up to this target, but only exceed it if not doing so would
                 result in an empty bundle.
         """
+        assert (
+            min_rows_per_bundle is None or min_rows_per_bundle >= 0
+        ), "Min rows per bundle has to be non-negative"
+
         self._min_rows_per_bundle = min_rows_per_bundle
         self._bundle_buffer: List[RefBundle] = []
         self._bundle_buffer_size = 0
```
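The new guard is plain argument validation. A minimal standalone sketch (the `_Bundler` class is an illustrative stand-in, not the actual Ray operator) of what it now accepts and rejects:

```python
from typing import Optional

class _Bundler:  # illustrative stand-in for the operator's bundling state
    def __init__(self, min_rows_per_bundle: Optional[int]):
        assert (
            min_rows_per_bundle is None or min_rows_per_bundle >= 0
        ), "Min rows per bundle has to be non-negative"
        self._min_rows_per_bundle = min_rows_per_bundle

_Bundler(1024)  # OK: positive minimum
_Bundler(None)  # OK: no minimum requested
try:
    _Bundler(-1)  # rejected eagerly instead of failing later during bundling
except AssertionError as e:
    print(e)  # "Min rows per bundle has to be non-negative"
```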

python/ray/data/_internal/logical/interfaces/logical_operator.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -82,3 +82,7 @@ def is_lineage_serializable(self) -> bool:
         objects aren't available on the deserialized machine.
         """
         return True
+
+    @classmethod
+    def is_read_op(cls):
+        return False
```

python/ray/data/_internal/logical/operators/map_operator.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -37,8 +37,8 @@ def __init__(
                 inspecting the logical plan of a Dataset.
             input_op: The operator preceding this operator in the plan DAG. The outputs
                 of `input_op` will be the inputs to this operator.
-            min_rows_per_bundled_input: The target number of rows to pass to
-                ``MapOperator._add_bundled_input()``.
+            min_rows_per_bundled_input: Min number of rows a single bundle of blocks
+                passed on to the task must possess.
             ray_remote_args: Args to provide to :func:`ray.remote`.
             ray_remote_args_fn: A function that returns a dictionary of remote args
                 passed to each map worker. The purpose of this argument is to generate
@@ -177,7 +177,6 @@ def __init__(
         self._batch_format = batch_format
         self._zero_copy_batch = zero_copy_batch

-    @property
     def can_modify_num_rows(self) -> bool:
         return False

@@ -210,7 +209,6 @@ def __init__(
             ray_remote_args=ray_remote_args,
         )

-    @property
     def can_modify_num_rows(self) -> bool:
         return False

@@ -249,7 +247,6 @@ def __init__(
             ray_remote_args=ray_remote_args,
         )

-    @property
     def can_modify_num_rows(self) -> bool:
         return True

@@ -285,7 +282,6 @@ def cols(self) -> Optional[List[str]]:
     def cols_rename(self) -> Optional[Dict[str, str]]:
         return self._cols_rename

-    @property
     def can_modify_num_rows(self) -> bool:
         return False

@@ -318,7 +314,6 @@ def __init__(
             ray_remote_args=ray_remote_args,
         )

-    @property
     def can_modify_num_rows(self) -> bool:
         return True

@@ -342,6 +337,5 @@ def __init__(
     def target_num_rows_per_block(self) -> int:
         return self._target_num_rows_per_block

-    @property
     def can_modify_num_rows(self) -> bool:
         return False
```

python/ray/data/_internal/logical/operators/one_to_one_operator.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -1,4 +1,3 @@
-import abc
 from typing import Optional

 from ray.data._internal.logical.interfaces import LogicalOperator
@@ -29,11 +28,10 @@ def __init__(
     def input_dependency(self) -> LogicalOperator:
         return self._input_dependencies[0]

-    @property
-    @abc.abstractmethod
     def can_modify_num_rows(self) -> bool:
         """Whether this operator can modify the number of rows,
         i.e. number of input rows != number of output rows."""
+        ...


 class Limit(AbstractOneToOne):
@@ -50,7 +48,6 @@ def __init__(
         )
         self._limit = limit

-    @property
     def can_modify_num_rows(self) -> bool:
         return True
```

python/ray/data/_internal/logical/operators/read_operator.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -93,3 +93,12 @@ def _cached_output_metadata(self) -> BlockMetadata:
             input_files=input_files,
             exec_stats=None,
         )
+
+    @classmethod
+    def is_read_op(cls):
+        return True
+
+    def can_modify_num_rows(self) -> bool:
+        # NOTE: Returns true, since most readers expand their input
+        # and produce many rows for every single input row
+        return True
```
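Together with the base-class default added above, `is_read_op` lets callers detect read operators without importing `Read` (and thus without `isinstance` checks or potential circular imports). A standalone sketch of the dispatch, using simplified stand-ins rather than the actual Ray classes:

```python
class LogicalOperator:  # simplified stand-in
    @classmethod
    def is_read_op(cls) -> bool:
        return False

class Read(LogicalOperator):  # simplified stand-in
    @classmethod
    def is_read_op(cls) -> bool:
        return True

op = Read()
print(op.is_read_op())  # True -- no isinstance(op, Read) or import of Read needed
```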

python/ray/data/_internal/logical/rules/limit_pushdown.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@ def _apply_limit_pushdown(self, op: LogicalOperator) -> LogicalOperator:
         while (
             isinstance(new_input_into_limit, AbstractOneToOne)
             and not isinstance(new_input_into_limit, Read)
-            and not getattr(new_input_into_limit, "can_modify_num_rows", False)
+            and not new_input_into_limit.can_modify_num_rows()
         ):
             new_input_into_limit_copy = copy.copy(new_input_into_limit)
             ops_between_new_input_and_limit.append(new_input_into_limit_copy)
```
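The direct call also fixes a latent pitfall: now that `can_modify_num_rows` is a plain method rather than a `@property`, the old `getattr(...)` pattern would return the bound method itself, which is always truthy. A small standalone sketch of the difference:

```python
class Op:  # simplified stand-in; the real operators live in Ray Data internals
    def can_modify_num_rows(self) -> bool:
        return False

op = Op()
# Old pattern: returns the bound method object, which is always truthy:
print(bool(getattr(op, "can_modify_num_rows", False)))  # True (wrong answer)
# New pattern: actually invokes the method:
print(op.can_modify_num_rows())  # False (correct answer)
```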

python/ray/data/_internal/logical/rules/operator_fusion.py

Lines changed: 79 additions & 33 deletions
```diff
@@ -1,4 +1,5 @@
 import itertools
+import logging
 from typing import List, Optional, Tuple

 from ray.data._internal.compute import (
@@ -38,6 +39,9 @@
 INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


+logger = logging.getLogger(__name__)
+
+
 class FuseOperators(Rule):
     """Fuses linear chains of compatible physical operators."""

@@ -175,35 +179,22 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
             (
                 isinstance(up_logical_op, AbstractMap)
                 and isinstance(down_logical_op, AbstractMap)
+                and self._can_fuse_map_ops(up_logical_op, down_logical_op)
             )
             or (
                 isinstance(up_logical_op, AbstractMap)
                 and isinstance(down_logical_op, RandomShuffle)
             )
+            # Do not fuse Repartition operator if shuffle is disabled
+            # (i.e. using split shuffle).
             or (
                 isinstance(up_logical_op, AbstractMap)
                 and isinstance(down_logical_op, Repartition)
+                and down_logical_op._shuffle
             )
         ):
             return False

-        # Do not fuse Repartition operator if shuffle is disabled
-        # (i.e. using split shuffle).
-        if isinstance(down_logical_op, Repartition) and not down_logical_op._shuffle:
-            return False
-
-        if isinstance(down_logical_op, AbstractMap) and isinstance(
-            up_logical_op, AbstractMap
-        ):
-            if (
-                self._fuse_compute_strategy(
-                    up_logical_op._compute,
-                    down_logical_op._compute,
-                )
-                is None
-            ):
-                return False
-
         # Only fuse if the ops' remote arguments are compatible.
         if not _are_remote_args_compatible(
             getattr(up_logical_op, "_ray_remote_args", {}),
@@ -228,8 +219,9 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
         # Otherwise, ops are compatible for fusion.
         return True

+    @classmethod
     def _fuse_compute_strategy(
-        self, up_compute: ComputeStrategy, down_compute: ComputeStrategy
+        cls, up_compute: ComputeStrategy, down_compute: ComputeStrategy
     ) -> Optional[ComputeStrategy]:
         """Fuse the compute strategies of the upstream and downstream operators.
         Returns None if they are not compatible.
@@ -302,21 +294,10 @@ def _get_fused_map_operator(
         assert isinstance(down_logical_op, AbstractMap)
         assert isinstance(up_logical_op, AbstractMap)

-        # Merge minimum block sizes.
-        down_min_rows_per_bundled_input = down_logical_op._min_rows_per_bundled_input
-        up_min_rows_per_bundled_input = up_logical_op._min_rows_per_bundled_input
-
-        if (
-            down_min_rows_per_bundled_input is not None
-            and up_min_rows_per_bundled_input is not None
-        ):
-            min_rows_per_bundled_input = max(
-                down_min_rows_per_bundled_input, up_min_rows_per_bundled_input
-            )
-        elif up_min_rows_per_bundled_input is not None:
-            min_rows_per_bundled_input = up_min_rows_per_bundled_input
-        else:
-            min_rows_per_bundled_input = down_min_rows_per_bundled_input
+        # Derive min num rows per input bundle
+        min_rows_per_bundled_input = self._derive_bundle_min_num_rows(
+            down_logical_op, up_logical_op
+        )

         target_max_block_size = self._get_merged_target_max_block_size(
             up_op.target_max_block_size, down_op.target_max_block_size
@@ -389,6 +370,27 @@ def _get_fused_map_operator(
         # Return the fused physical operator.
         return op

+    @classmethod
+    def _derive_bundle_min_num_rows(
+        cls,
+        down_logical_op: AbstractMap,
+        up_logical_op: AbstractMap,
+    ) -> Optional[int]:
+        us_bundle_min_rows_req = up_logical_op._min_rows_per_bundled_input
+        ds_bundle_min_rows_req = down_logical_op._min_rows_per_bundled_input
+
+        # In case neither of the ops specifies `min_rows_per_bundled_input`,
+        # return None
+        if us_bundle_min_rows_req is None and ds_bundle_min_rows_req is None:
+            return None
+
+        # The target min bundle size is the max of the upstream and downstream
+        # ones, such that it satisfies both of their requirements
+        return max(
+            ds_bundle_min_rows_req or 0,
+            us_bundle_min_rows_req or 0,
+        )
+
     def _get_fused_all_to_all_operator(
         self, down_op: AllToAllOperator, up_op: MapOperator
     ) -> AllToAllOperator:
@@ -460,6 +462,50 @@ def fused_all_to_all_transform_fn(
         # Return the fused physical operator.
         return op

+    @classmethod
+    def _can_fuse_map_ops(
+        cls,
+        upstream_op: AbstractMap,
+        downstream_op: AbstractMap,
+    ) -> bool:
+        if (
+            cls._fuse_compute_strategy(
+                upstream_op._compute,
+                downstream_op._compute,
+            )
+            is None
+        ):
+            return False
+
+        # Do not fuse Map operators in case:
+        #
+        #   - Upstream could (potentially) drastically modify the number of rows, while
+        #   - Downstream has `min_rows_per_input_bundle` specified
+        #
+        # Fusing such transformations is not desirable, as it could
+        #
+        #   - Drastically reduce parallelism of the upstream op (for example,
+        #     fusing ``Read->MapBatches(batch_size=...)`` with a large enough
+        #     batch size could drastically reduce the parallelism level of the
+        #     Read op)
+        #
+        #   - Potentially violate batching semantics by fusing
+        #     ``Filter->MapBatches(batch_size=...)``
+        #
+        if (
+            upstream_op.can_modify_num_rows()
+            and downstream_op._min_rows_per_bundled_input is not None
+        ):
+            logger.debug(
+                f"Upstream operator '{upstream_op}' could be modifying # of input "
+                f"rows, while downstream operator '{downstream_op}' expects at least "
+                f"{downstream_op._min_rows_per_bundled_input} rows in a batch. "
+                f"Skipping fusion"
+            )
+
+            return False
+
+        return True
+

 def _are_remote_args_compatible(prev_args, next_args):
     """Check if Ray remote arguments are compatible for merging."""
```

python/ray/data/_internal/plan.py

Lines changed: 16 additions & 9 deletions
```diff
@@ -10,8 +10,6 @@
 from ray.data._internal.execution.interfaces import RefBundle
 from ray.data._internal.logical.interfaces.logical_operator import LogicalOperator
 from ray.data._internal.logical.interfaces.logical_plan import LogicalPlan
-from ray.data._internal.logical.operators.from_operators import AbstractFrom
-from ray.data._internal.logical.operators.input_data_operator import InputData
 from ray.data._internal.logical.operators.read_operator import Read
 from ray.data._internal.stats import DatasetStats
 from ray.data._internal.util import unify_block_metadata_schema
@@ -136,7 +134,7 @@ def generate_logical_plan_string(
         ):
             """Traverse (DFS) the LogicalPlan DAG and
             return a string representation of the operators."""
-            if isinstance(op, (Read, InputData, AbstractFrom)):
+            if not op.input_dependencies or op.is_read_op():
                 return curr_str, depth

             curr_max_depth = depth
@@ -168,19 +166,27 @@ def generate_logical_plan_string(
                 count = self._snapshot_metadata.num_rows
             else:
                 # This plan hasn't executed any operators.
-                sources = self._logical_plan.sources()
+                has_n_ary_operator = False
+                dag = self._logical_plan.dag
+
+                while not dag.is_read_op() and dag.input_dependencies:
+                    if len(dag.input_dependencies) > 1:
+                        has_n_ary_operator = True
+                        break
+
+                    dag = dag.input_dependencies[0]
+
                 # TODO(@bveeramani): Handle schemas for n-ary operators like `Union`.
-                if len(sources) > 1:
-                    # Multiple sources, cannot determine schema.
+                if has_n_ary_operator:
                     schema = None
                     count = None
                 else:
-                    assert len(sources) == 1
+                    assert dag.is_read_op() or not dag.input_dependencies, dag
                     plan = ExecutionPlan(
                         DatasetStats(metadata={}, parent=None),
                         self._context,
                     )
-                    plan.link_logical_plan(LogicalPlan(sources[0], plan._context))
+                    plan.link_logical_plan(LogicalPlan(dag, plan._context))
                     schema = plan.schema()
                     count = plan.meta_count()
             else:
@@ -587,7 +593,8 @@ def is_read_only(self, root_op: Optional[LogicalOperator] = None) -> bool:
         the LogicalPlan is used."""
         if root_op is None:
             root_op = self._logical_plan.dag
-        return isinstance(root_op, Read) and len(root_op.input_dependencies) == 0
+
+        return root_op.is_read_op()

     def has_computed_output(self) -> bool:
         """Whether this plan has a computed snapshot for the final operator, i.e. for
```
