
Commit ced3a8b

alexeykudinkin and zhaoch23 authored and committed
[Data] Fixing Optimizer to apply rules until plan stabilize; (ray-project#52663)
Fixing tests

## Why are these changes needed?

1. Fixing `Optimizer` to continue applying rules until plans stop changing
2. Fixing tests to avoid making assumptions about the number of blocks / data being sorted
3. Minor cleanups

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Signed-off-by: zhaoch23 <c233zhao@uwaterloo.ca>
1 parent cd5275c · commit ced3a8b

File tree

9 files changed: +49 −52 lines changed

python/ray/data/_internal/execution/operators/map_transformer.py

Lines changed: 2 additions & 7 deletions

```diff
@@ -90,9 +90,7 @@ def output_block_size_option(self):
         return self._output_block_size_option
 
     def set_target_max_block_size(self, target_max_block_size: int):
-        assert (
-            self._output_block_size_option is None and target_max_block_size is not None
-        )
+        assert target_max_block_size is not None
         self._output_block_size_option = OutputBlockSizeOption(
             target_max_block_size=target_max_block_size
         )
@@ -105,10 +103,7 @@ def target_max_block_size(self):
         return self._output_block_size_option.target_max_block_size
 
     def set_target_num_rows_per_block(self, target_num_rows_per_block: int):
-        assert (
-            self._output_block_size_option is None
-            and target_num_rows_per_block is not None
-        )
+        assert target_num_rows_per_block is not None
         self._output_block_size_option = OutputBlockSizeOption(
             target_num_rows_per_block=target_num_rows_per_block
         )
```
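The assertions were relaxed to only check that the new value is non-None; previously they also required `self._output_block_size_option is None`, so calling a setter a second time would trip the assert. That matters once optimizer rules can be re-applied to the same plan (see the `Optimizer` change below). A minimal sketch of the behavioral difference, using a simplified stand-in class rather than the real `MapTransformer`:

```python
# Simplified stand-in for illustration only; not Ray Data's MapTransformer.
class BlockSizer:
    def __init__(self):
        self._output_block_size_option = None

    def set_target_max_block_size(self, target_max_block_size: int):
        # The old assert also required self._output_block_size_option to be
        # None, so a repeated call (e.g. from a re-applied rule) would fail.
        assert target_max_block_size is not None
        self._output_block_size_option = {
            "target_max_block_size": target_max_block_size
        }


sizer = BlockSizer()
sizer.set_target_max_block_size(128 * 1024 * 1024)
sizer.set_target_max_block_size(64 * 1024 * 1024)  # now allowed; asserted before
```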

python/ray/data/_internal/logical/interfaces/operator.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-from typing import Iterator, List, Callable
+from typing import Callable, Iterator, List
 
 
 class Operator:
```

python/ray/data/_internal/logical/interfaces/optimizer.py

Lines changed: 10 additions & 2 deletions

```diff
@@ -34,6 +34,14 @@ def rules(self) -> List[Rule]:
 
     def optimize(self, plan: Plan) -> Plan:
         """Optimize operators with a list of rules."""
-        for rule in self.rules:
-            plan = rule.apply(plan)
+        # Apply rules until the plan is not changed
+        previous_plan = plan
+        while True:
+            for rule in self.rules:
+                plan = rule.apply(plan)
+            # TODO: Eventually we should implement proper equality.
+            # Using str to check equality seems brittle
+            if plan.dag.dag_str == previous_plan.dag.dag_str:
+                break
+            previous_plan = plan
         return plan
```
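The rewritten `optimize` is a fixed-point iteration: each pass applies every rule in order, and the loop exits only when a full pass leaves the plan unchanged (using the DAG's string form as an admittedly brittle equality proxy). A self-contained sketch of the same pattern, with toy stand-ins for `Rule` and `Plan` rather than Ray Data's actual interfaces:

```python
# Fixed-point rule application, sketched with simplified stand-in types.
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class Plan:
    dag_str: str  # string form used as a (brittle) equality proxy


def optimize(plan: Plan, rules: List[Callable[[Plan], Plan]]) -> Plan:
    previous = plan
    while True:
        for rule in rules:
            plan = rule(plan)
        # Stop once a full pass over all rules produced no change.
        if plan.dag_str == previous.dag_str:
            return plan
        previous = plan


# Toy rules: each rewrites the plan string once, then becomes a no-op.
fuse = lambda p: Plan(p.dag_str.replace("Map->Map", "FusedMap"))
prune = lambda p: Plan(p.dag_str.replace("Noop->", ""))

print(optimize(Plan("Read->Noop->Map->Map"), [fuse, prune]).dag_str)
# Read->FusedMap
```

Note how `prune` enables `fuse` only on the second pass in some orderings; a single sweep over the rules (the old behavior) could leave such opportunities on the table.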

python/ray/data/_internal/util.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -785,13 +785,14 @@ def find_partition_index(
         # is an index into the ascending order of ``col_vals``, so we need
         # to subtract it from ``len(col_vals)`` to get the index in the
         # original descending order of ``col_vals``.
+        sorter = np.arange(len(col_vals) - 1, -1, -1)
         left = prevleft + (
             len(col_vals)
             - np.searchsorted(
                 col_vals,
                 desired_val,
                 side="right",
-                sorter=np.arange(len(col_vals) - 1, -1, -1),
+                sorter=sorter,
             )
         )
         right = prevleft + (
@@ -800,7 +801,7 @@ def find_partition_index(
                 col_vals,
                 desired_val,
                 side="left",
-                sorter=np.arange(len(col_vals) - 1, -1, -1),
+                sorter=sorter,
             )
         )
     else:
```
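This change is a pure refactor: the reversed-index `sorter` array is now built once and shared by both `np.searchsorted` calls instead of being recomputed. For context, `np.searchsorted` requires ascending input, so a descending `col_vals` is searched through a `sorter` permutation; a small standalone example with made-up values:

```python
import numpy as np

# np.searchsorted assumes ascending order; for a descending array we pass
# ``sorter``, a permutation of indices that puts the array in ascending
# order. For strictly descending values that is the reversed index range.
col_vals = np.array([9, 7, 5, 3, 1])  # descending
sorter = np.arange(len(col_vals) - 1, -1, -1)  # [4, 3, 2, 1, 0]

# Insertion point for 5 in the *ascending* view of col_vals:
asc_idx = np.searchsorted(col_vals, 5, side="right", sorter=sorter)
# Convert back to an index into the original descending order:
desc_idx = len(col_vals) - asc_idx
print(asc_idx, desc_idx)  # 3 2  -> col_vals[2] == 5
```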

python/ray/data/read_api.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -108,7 +108,6 @@
 
 from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions
 
-
 T = TypeVar("T")
 
 logger = logging.getLogger(__name__)
```

python/ray/data/tests/test_context.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,5 +1,6 @@
 import pytest
 
+
 import ray
 
 
```

python/ray/data/tests/test_execution_optimizer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -68,7 +68,7 @@ def _check_valid_plan_and_result(
     expected_physical_plan_ops=None,
 ):
     assert ds.take_all() == expected_result
-    assert str(ds._plan._logical_plan.dag) == expected_plan
+    assert ds._plan._logical_plan.dag.dag_str == expected_plan
 
     expected_physical_plan_ops = expected_physical_plan_ops or []
     for op in expected_physical_plan_ops:
```

python/ray/data/tests/test_json.py

Lines changed: 24 additions & 23 deletions

```diff
@@ -37,7 +37,7 @@ def test_json_read_partitioning(ray_start_regular_shared, tmp_path):
 
     ds = ray.data.read_json(path)
 
-    assert ds.take() == [
+    assert sorted(ds.take(), key=lambda row: row["number"]) == [
         {"number": 0, "string": "foo", "country": "us"},
         {"number": 1, "string": "bar", "country": "us"},
     ]
@@ -103,7 +103,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json(path, filesystem=fs)
     df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = ds.to_pandas().sort_values(by=["one", "two"]).reset_index(drop=True)
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -136,7 +136,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     )
     ds = ray.data.read_json([path1, path2], filesystem=fs)
     df = pd.concat([df1, df2, df3], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = ds.to_pandas().sort_values(by=["one", "two"]).reset_index(drop=True)
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path1)
@@ -159,7 +159,7 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     df2.to_json(path2, orient="records", lines=True, storage_options=storage_options)
     ds = ray.data.read_json([dir_path, path2], filesystem=fs)
     df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = ds.to_pandas().sort_values(by=["one", "two"]).reset_index(drop=True)
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(dir_path)
@@ -189,9 +189,8 @@ def test_json_read(ray_start_regular_shared, fs, data_path, endpoint_url):
     )
 
     ds = ray.data.read_json(path, filesystem=fs)
-    assert ds._plan.initial_num_blocks() == 2
     df = pd.concat([df1, df2], ignore_index=True)
-    dsdf = ds.to_pandas()
+    dsdf = ds.to_pandas().sort_values(by=["one", "two"]).reset_index(drop=True)
     assert df.equals(dsdf)
     if fs is None:
         shutil.rmtree(path)
@@ -410,7 +409,9 @@ def test_json_read_with_parse_options(
         (lazy_fixture("s3_fs"), lazy_fixture("s3_path"), lazy_fixture("s3_server")),
     ],
 )
+@pytest.mark.parametrize("style", [PartitionStyle.HIVE, PartitionStyle.DIRECTORY])
 def test_json_read_partitioned_with_filter(
+    style,
     ray_start_regular_shared,
     fs,
     data_path,
@@ -476,12 +477,13 @@ def skip_unpartitioned(kv_dict):
     ray.get(skipped_file_counter.reset.remote())
 
 
-def test_jsonl_mixed_types(ray_start_regular_shared, tmp_path):
+@pytest.mark.parametrize("override_num_blocks", [None, 1, 3])
+def test_jsonl_lists(ray_start_regular_shared, tmp_path, override_num_blocks):
     """Test JSONL with mixed types and schemas."""
     data = [
-        {"a": 1, "b": {"c": 2}},  # Nested dict
-        {"a": 1, "b": {"c": 3}},  # Nested dict
-        {"a": 1, "b": {"c": {"hello": "world"}}},  # Mixed Schema
+        ["ray", "rocks", "hello"],
+        ["oh", "no"],
+        ["rocking", "with", "ray"],
     ]
 
     path = os.path.join(tmp_path, "test.jsonl")
@@ -490,21 +492,20 @@ def test_jsonl_lists(ray_start_regular_shared, tmp_path, override_num_blocks):
         json.dump(record, f)
         f.write("\n")
 
-    ds = ray.data.read_json(path, lines=True)
+    ds = ray.data.read_json(path, lines=True, override_num_blocks=override_num_blocks)
     result = ds.take_all()
 
-    assert result[0] == data[0]  # Dict stays as is
-    assert result[1] == data[1]
-    assert result[2] == data[2]
+    assert result[0] == {"0": "ray", "1": "rocks", "2": "hello"}
+    assert result[1] == {"0": "oh", "1": "no", "2": None}
+    assert result[2] == {"0": "rocking", "1": "with", "2": "ray"}
 
 
-@pytest.mark.parametrize("override_num_blocks", [None, 1, 3])
-def test_jsonl_lists(ray_start_regular_shared, tmp_path, override_num_blocks):
+def test_jsonl_mixed_types(ray_start_regular_shared, tmp_path):
     """Test JSONL with mixed types and schemas."""
     data = [
-        ["ray", "rocks", "hello"],
-        ["oh", "no"],
-        ["rocking", "with", "ray"],
+        {"a": 1, "b": {"c": 2}},  # Nested dict
+        {"a": 1, "b": {"c": 3}},  # Nested dict
+        {"a": 1, "b": {"c": {"hello": "world"}}},  # Mixed Schema
     ]
 
     path = os.path.join(tmp_path, "test.jsonl")
@@ -513,12 +514,12 @@ def test_jsonl_lists(ray_start_regular_shared, tmp_path, override_num_blocks):
         json.dump(record, f)
         f.write("\n")
 
-    ds = ray.data.read_json(path, lines=True, override_num_blocks=override_num_blocks)
+    ds = ray.data.read_json(path, lines=True)
     result = ds.take_all()
 
-    assert result[0] == {"0": "ray", "1": "rocks", "2": "hello"}
-    assert result[1] == {"0": "oh", "1": "no", "2": None}
-    assert result[2] == {"0": "rocking", "1": "with", "2": "ray"}
+    assert result[0] == data[0]  # Dict stays as is
+    assert result[1] == data[1]
+    assert result[2] == data[2]
 
 
 @pytest.mark.parametrize(
```
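The `test_json.py` edits follow one pattern: because the optimizer fix means tests can no longer assume a particular block count or output order, assertions normalize row order before comparing. A minimal sketch of that normalization, with hypothetical toy data:

```python
import pandas as pd

# Hypothetical rows that may arrive in any block order.
expected = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
got = pd.DataFrame({"one": [3, 1, 2], "two": ["c", "a", "b"]})

# Normalize row order (and the index) before comparing.
normalized = got.sort_values(by=["one", "two"]).reset_index(drop=True)
assert expected.equals(normalized)
```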

python/ray/data/tests/test_webdataset.py

Lines changed: 7 additions & 15 deletions

```diff
@@ -39,7 +39,7 @@ def test_webdataset_read(ray_start_2_cpus, tmp_path):
             tf.write(f"{i}.b", str(i**2).encode("utf-8"))
     assert os.path.exists(path)
     assert len(glob.glob(f"{tmp_path}/*.tar")) == 1
-    ds = ray.data.read_webdataset(paths=[str(tmp_path)], override_num_blocks=1)
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)])
     samples = ds.take(100)
     assert len(samples) == 100
     for i, sample in enumerate(samples):
@@ -92,18 +92,14 @@ def test_webdataset_suffixes(ray_start_2_cpus, tmp_path):
     assert len(glob.glob(f"{tmp_path}/*.tar")) == 1
 
     # test simple suffixes
-    ds = ray.data.read_webdataset(
-        paths=[str(tmp_path)], override_num_blocks=1, suffixes=["txt", "cls"]
-    )
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)], suffixes=["txt", "cls"])
     samples = ds.take(100)
     assert len(samples) == 100
     for i, sample in enumerate(samples):
         assert set(sample.keys()) == {"__url__", "__key__", "txt", "cls"}
 
     # test fnmatch patterns for suffixes
-    ds = ray.data.read_webdataset(
-        paths=[str(tmp_path)], override_num_blocks=1, suffixes=["*.txt", "*.cls"]
-    )
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)], suffixes=["*.txt", "*.cls"])
     samples = ds.take(100)
     assert len(samples) == 100
     for i, sample in enumerate(samples):
@@ -113,9 +109,7 @@ def test_webdataset_suffixes(ray_start_2_cpus, tmp_path):
     def select(name):
         return name.endswith("txt")
 
-    ds = ray.data.read_webdataset(
-        paths=[str(tmp_path)], override_num_blocks=1, suffixes=select
-    )
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)], suffixes=select)
     samples = ds.take(100)
     assert len(samples) == 100
     for i, sample in enumerate(samples):
@@ -127,9 +121,7 @@ def renamer(name):
         print("***", name, result)
         return result
 
-    ds = ray.data.read_webdataset(
-        paths=[str(tmp_path)], override_num_blocks=1, filerename=renamer
-    )
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)], filerename=renamer)
     samples = ds.take(100)
     assert len(samples) == 100
     for i, sample in enumerate(samples):
@@ -198,7 +190,7 @@ def test_webdataset_coding(ray_start_2_cpus, tmp_path):
     assert len(paths) == 1
     path = paths[0]
     assert os.path.exists(path)
-    ds = ray.data.read_webdataset(paths=[str(tmp_path)], override_num_blocks=1)
+    ds = ray.data.read_webdataset(paths=[str(tmp_path)])
     samples = ds.take(1)
     assert len(samples) == 1
     for sample in samples:
@@ -218,7 +210,7 @@ def test_webdataset_coding(ray_start_2_cpus, tmp_path):
 
     # test the format argument to the default decoder and multiple decoders
     ds = ray.data.read_webdataset(
-        paths=[str(tmp_path)], override_num_blocks=1, decoder=["PIL", custom_decoder]
+        paths=[str(tmp_path)], decoder=["PIL", custom_decoder]
     )
     samples = ds.take(1)
     assert len(samples) == 1
```
