Skip to content

Commit 647b74a

Browse files
authored
[Data] fix RandomAccessDataset.multiget returning unexpected values for missing keys (#44769)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> As stated in #44768, the current implementation of `multiget` based on `np.searchsorted` does not check for missing keys. I added the required checks and updated unit test for this case. ## Why are these changes needed? <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number <!-- For example: "Closes #1234" --> Closes #44768 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( Signed-off-by: Wu Yufei <wuyufei.2000@bytedance.com>
1 parent 2a1677c commit 647b74a

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

python/ray/data/random_access_dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,10 @@ def multiget(self, block_indices, keys):
234234
col = block[self.key_field]
235235
indices = np.searchsorted(col, keys)
236236
acc = BlockAccessor.for_block(block)
237-
result = [acc._get_row(i) for i in indices]
238-
# assert result == [self._get(i, k) for i, k in zip(block_indices, keys)]
237+
result = [
238+
acc._get_row(i) if k1.as_py() == k2 else None
239+
for i, k1, k2 in zip(indices, col.take(indices), keys)
240+
]
239241
else:
240242
result = [self._get(i, k) for i, k in zip(block_indices, keys)]
241243
self.total_time += time.perf_counter() - start

python/ray/data/tests/test_random_access.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,28 @@
88
@pytest.mark.parametrize("pandas", [False, True])
99
def test_basic(ray_start_regular_shared, pandas):
1010
ds = ray.data.range(100, override_num_blocks=10)
11+
ds = ds.add_column("key", lambda b: b["id"] * 2)
1112
ds = ds.add_column("embedding", lambda b: b["id"] ** 2)
1213
if not pandas:
1314
ds = ds.map_batches(
1415
lambda df: pyarrow.Table.from_pandas(df), batch_format="pandas"
1516
)
1617

17-
rad = ds.to_random_access_dataset("id", num_workers=1)
18+
rad = ds.to_random_access_dataset("key", num_workers=1)
19+
20+
def expected(i):
21+
return {"id": i, "key": i * 2, "embedding": i**2}
1822

1923
# Test get.
2024
assert ray.get(rad.get_async(-1)) is None
21-
assert ray.get(rad.get_async(100)) is None
25+
assert ray.get(rad.get_async(200)) is None
2226
for i in range(100):
23-
assert ray.get(rad.get_async(i)) == {"id": i, "embedding": i**2}
24-
25-
def expected(i):
26-
return {"id": i, "embedding": i**2}
27+
assert ray.get(rad.get_async(i * 2 + 1)) is None
28+
assert ray.get(rad.get_async(i * 2)) == expected(i)
2729

2830
# Test multiget.
29-
results = rad.multiget([-1] + list(range(10)) + [100])
30-
assert results == [None] + [expected(i) for i in range(10)] + [None]
31+
results = rad.multiget([-1] + list(range(0, 20, 2)) + list(range(1, 21, 2)) + [200])
32+
assert results == [None] + [expected(i) for i in range(10)] + [None] * 10 + [None]
3133

3234

3335
def test_empty_blocks(ray_start_regular_shared):

0 commit comments

Comments
 (0)