Commit b7dae2a

[Data] Revisiting make_async_gen to address issues with concurrency control for sequences of varying lengths (#51661)

## Why are these changes needed?

This change addresses potential deadlocks inside `make_async_gen` when it is used with functions producing sequences of wildly varying lengths.

Fundamentally, `make_async_gen` was trying to solve 2 problems whose respective solutions never actually overlapped:

1. Implement parallel processing that transforms an input *iterator* into an output one while preserving back-pressure semantics, i.e. the input iterator should not outpace consumption of the output iterator.

2. Implement parallel processing that preserves the ordering of the input iterator.

These requirements, coupled with the fact that the transformation is expected to receive and produce *iterators*, are what led to the erroneous conclusion that both could be satisfied by a single implementation:

- Transforming iterators is very different from bijective mapping: we don't actually know how many input elements will result in a single output element (i.e. the transformation is a black box that could be anything from 1-to-1 to many-to-many).
- Preserving the ordering of a transformation of *iterators* requires N input and N output queues (1 per worker), with both producer and consumer filling/draining these queues in the same consistent order (without skipping!).
- Because there can be no skipping (to preserve the order), some input AND output queues can become full at the same time, leaving both producer and consumer stuck and unable to make progress.

To resolve this problem fundamentally, we decouple the 2 use cases (a minimal sketch of the second mode follows below):

1. Preserving order: N input and N output queues, with the input queues uncapped (while output queues remain capped at `queue_buffer_size`), meaning the incoming iterator will be unrolled eagerly by the producer (until exhaustion).

2. Not preserving order: *1* input queue and N output queues, with both input and output queues capped in size based on the `queue_buffer_size` configuration. This allows implementing back-pressure semantics where consumption speed limits production speed (and the amount of buffered data).

## Changes

- Added a stress test successfully reproducing deadlocks in the current implementation
- Added the `preserve_ordering` param
- Adjusted semantics to handle the preserve_ordering=True/False scenarios separately
- Beefed up existing tests
- Tidied up

## Related issue number

<!-- For example: "Closes #1234" -->

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent ad86b59 commit b7dae2a
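
To make the decoupled design concrete, here is a minimal, self-contained sketch of the *unordered* mode (a single bounded input queue feeding N workers). This is an illustration under stated assumptions, not Ray's implementation: `async_map_unordered`, `explode`, and all parameter values are hypothetical, and the real `make_async_gen` uses per-worker capped output queues and interruptible queues rather than the single shared output queue used here.

```python
import queue
import threading
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")

_SENTINEL = object()


def async_map_unordered(
    base_iterator: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    num_workers: int = 4,
    buffer_size: int = 1,
) -> Iterator[U]:
    # Single shared input queue, capped so the producer blocks (back-pressure)
    # once consumers fall behind; mirrors the (buffer_size + 1) * num_workers
    # cap described above.
    input_queue = queue.Queue(maxsize=(buffer_size + 1) * num_workers)
    # One shared output queue keeps the sketch short; the real implementation
    # uses one capped output queue per worker.
    output_queue = queue.Queue(maxsize=buffer_size * num_workers)

    def producer():
        for item in base_iterator:
            input_queue.put(item)
        # One sentinel per worker signals end of input (error handling elided).
        for _ in range(num_workers):
            input_queue.put(_SENTINEL)

    def worker():
        # Drain the shared queue until this worker sees a sentinel.
        for result in fn(iter(input_queue.get, _SENTINEL)):
            output_queue.put(result)
        output_queue.put(_SENTINEL)

    threading.Thread(target=producer, daemon=True).start()
    for _ in range(num_workers):
        threading.Thread(target=worker, daemon=True).start()

    # Yield until every worker has reported completion.
    finished = 0
    while finished < num_workers:
        item = output_queue.get()
        if item is _SENTINEL:
            finished += 1
        else:
            yield item


# A 1-to-many transformation with varying output lengths.
def explode(items: Iterator[int]) -> Iterator[int]:
    for x in items:
        yield from range(x)


print(sorted(async_map_unordered(iter([1, 5, 2, 4]), explode)))
```

The capped `input_queue` is what provides the back-pressure: when consumers fall behind, `put` blocks the producer instead of letting it unroll the whole input.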

File tree

4 files changed (+225, -77 lines)


python/ray/data/_internal/block_batching/iter_batches.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -169,7 +169,10 @@ def _async_iter_batches(
     # Run everything in a separate thread to not block the main thread when waiting
     # for streaming results.
     async_batch_iter = make_async_gen(
-        ref_bundles, fn=_async_iter_batches, num_workers=1
+        ref_bundles,
+        fn=_async_iter_batches,
+        num_workers=1,
+        preserve_ordering=False,
     )
 
     while True:
@@ -223,6 +226,7 @@ def threadpool_computations_format_collate(
         collated_iter = make_async_gen(
             base_iterator=batch_iter,
             fn=threadpool_computations_format_collate,
+            preserve_ordering=False,
             num_workers=num_threadpool_workers,
         )
     else:
```
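
Both call sites above opt into the new unordered mode. A hedged usage sketch (assuming the updated signature and the `ray.data._internal.util` module path shown in this commit; `double` is an illustrative transformation): with `num_workers=1` there is a single input and a single output queue, so in practice items still come out in order even though `preserve_ordering=False` formally leaves ordering unspecified.

```python
from ray.data._internal.util import make_async_gen


def double(items):
    for x in items:
        yield x * 2


out = list(
    make_async_gen(iter(range(5)), fn=double, num_workers=1, preserve_ordering=False)
)
# Ordering is formally unspecified with preserve_ordering=False, so only
# assert on the multiset of results.
assert sorted(out) == [0, 2, 4, 6, 8]
```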

python/ray/data/_internal/util.py

Lines changed: 60 additions & 33 deletions

```diff
@@ -914,29 +914,39 @@ def put(self, item, block=True, timeout=None):
 def make_async_gen(
     base_iterator: Iterator[T],
     fn: Callable[[Iterator[T]], Iterator[U]],
+    preserve_ordering: bool,
     num_workers: int = 1,
-    queue_buffer_size: int = 2,
+    buffer_size: int = 1,
 ) -> Generator[U, None, None]:
-
-    gen_id = random.randint(0, 2**31 - 1)
-
     """Returns a generator (iterator) mapping items from the
     provided iterator applying provided transformation in parallel (using a
     thread-pool).
 
-    NOTE: Even though the mapping is performed in parallel across N
-    threads, this method provides crucial guarantee of preserving the
-    ordering of the source iterator, ie that
+    NOTE: There are some important constraints that need to be carefully
+    understood before using this method:
+
+    1. If `preserve_ordering` is True
+        a. This method will unroll the input iterator eagerly (irrespective
+           of the speed at which the resulting generator is consumed). This
+           is necessary as we cannot guarantee liveness of the algorithm AND
+           preservation of the original ordering at the same time.
+
+        b. The resulting ordering of the output will "match" the ordering of
+           the input, i.e.:
+               iterator = [A1, A2, ... An]
+               output iterator = [map(A1), map(A2), ..., map(An)]
 
-    iterator = [A1, A2, ... An]
-    mapped iterator = [map(A1), map(A2), ..., map(An)]
+    2. If `preserve_ordering` is False
+        a. No more than `num_workers * (buffer_size + 1)` elements will be
+           fetched from the iterator.
 
-    Preserving ordering is crucial to eliminate non-determinism in producing
-    content of the blocks.
+        b. The resulting ordering of the output is unspecified (and
+           non-deterministic).
 
     Args:
         base_iterator: Iterator yielding elements to map
         fn: Transformation to apply to each element
+        preserve_ordering: Whether ordering has to be preserved
         num_workers: The number of threads to use in the threadpool (defaults to 1)
         buffer_size: Number of objects to be buffered in its input/output
             queues (per queue; defaults to 2). Total number of objects held
```
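
A short sketch contrasting the two documented modes (same assumptions as the earlier usage example; `square` is illustrative):

```python
from ray.data._internal.util import make_async_gen


def square(items):
    for x in items:
        yield x * x


inputs = list(range(8))

# preserve_ordering=True: output matches input order, but the producer may
# unroll the input iterator eagerly.
ordered = list(
    make_async_gen(iter(inputs), fn=square, preserve_ordering=True, num_workers=4)
)
assert ordered == [x * x for x in inputs]

# preserve_ordering=False: bounded buffering (back-pressure), but output
# order across workers is unspecified.
unordered = list(
    make_async_gen(iter(inputs), fn=square, preserve_ordering=False, num_workers=4)
)
assert sorted(unordered) == [x * x for x in inputs]
```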

```diff
@@ -949,9 +959,14 @@ def make_async_gen(
     elements mapped by provided transformation (while *preserving the ordering*)
     """
 
+    gen_id = random.randint(0, 2**31 - 1)
+
     if num_workers < 1:
         raise ValueError("Size of threadpool must be at least 1.")
 
+    # Signal handler used to interrupt workers when terminating
+    interrupted_event = threading.Event()
+
     # To apply transformations to elements in parallel *and* preserve the ordering
     # following invariants are established:
     # - Every worker is handled by standalone thread
@@ -967,16 +982,26 @@ def make_async_gen(
     # order as input queues) dequeues 1 mapped element at a time from each output
     # queue and yields it
     #
-    # Signal handler used to interrupt workers when terminating
-    interrupted_event = threading.Event()
+    # However, in the case when we're preserving the ordering, we cannot cap the
+    # input queue size, as this could result in deadlocks since transformations
+    # could be producing sequences of arbitrary length.
+    #
+    # Check `test_make_async_gen_varying_seq_length_stress_test` for more context
+    # on this problem.
+    if preserve_ordering:
+        input_queue_buf_size = -1
+        num_input_queues = num_workers
+    else:
+        input_queue_buf_size = (buffer_size + 1) * num_workers
+        num_input_queues = 1
 
     input_queues = [
-        _InterruptibleQueue(queue_buffer_size, interrupted_event)
-        for _ in range(num_workers)
+        _InterruptibleQueue(input_queue_buf_size, interrupted_event)
+        for _ in range(num_input_queues)
     ]
+
     output_queues = [
-        _InterruptibleQueue(queue_buffer_size, interrupted_event)
-        for _ in range(num_workers)
+        _InterruptibleQueue(buffer_size, interrupted_event) for _ in range(num_workers)
     ]
 
     # Filling worker
```
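
For context on the `-1` capacity above: assuming `_InterruptibleQueue` follows the standard `queue.Queue` convention (an assumption; its implementation is not shown in this diff), a `maxsize <= 0` makes the queue unbounded, so `put` never blocks and the producer can always make progress:

```python
import queue

# maxsize <= 0 means unbounded: put_nowait never raises queue.Full.
unbounded = queue.Queue(maxsize=-1)
for i in range(10_000):
    unbounded.put_nowait(i)

# A capped queue is the back-pressure mechanism: once full, put() blocks
# (put_nowait raises instead, which makes the behavior easy to see).
bounded = queue.Queue(maxsize=2)
bounded.put_nowait(1)
bounded.put_nowait(2)
try:
    bounded.put_nowait(3)
except queue.Full:
    print("bounded queue full -- a blocking put() here is the back-pressure")
```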

```diff
@@ -985,11 +1010,16 @@ def _run_filling_worker():
         # First, round-robin elements from the iterator into
         # corresponding input queues (one by one)
         for idx, item in enumerate(base_iterator):
-            input_queues[idx % num_workers].put(item)
-
-        # Enqueue sentinel objects to signal end of the line
+            input_queues[idx % num_input_queues].put(item)
+
+        # NOTE: We have to enqueue sentinel objects for every transforming
+        # worker:
+        #   - In the order-preserving case (``num_input_queues`` ==
+        #     ``num_workers``) we will enqueue 1 sentinel per queue
+        #   - In the NOT order-preserving case all ``num_workers`` sentinels
+        #     will be enqueued into a single queue
         for idx in range(num_workers):
-            input_queues[idx].put(SENTINEL)
+            input_queues[idx % num_input_queues].put(SENTINEL)
 
     except InterruptedError:
         pass
@@ -1004,18 +1034,14 @@ def _run_filling_worker():
             output_queue.put(e)
 
     # Transforming worker
-    def _run_transforming_worker(worker_id: int):
-        input_queue = input_queues[worker_id]
-        output_queue = output_queues[worker_id]
-
+    def _run_transforming_worker(input_queue, output_queue):
         try:
             # Create iterator draining the queue, until it receives sentinel
             #
             # NOTE: `queue.get` is blocking!
             input_queue_iter = iter(input_queue.get, SENTINEL)
 
-            mapped_iter = fn(input_queue_iter)
-            for result in mapped_iter:
+            for result in fn(input_queue_iter):
                 # Enqueue result of the transformation
                 output_queue.put(result)
```
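
The `iter(input_queue.get, SENTINEL)` line above uses the two-argument form of the built-in `iter`: it repeatedly calls the callable and stops as soon as the sentinel value is returned, which is how each worker knows its input is exhausted. A minimal stdlib-only illustration:

```python
import queue

SENTINEL = object()

q = queue.Queue()
for item in ("a", "b", SENTINEL):
    q.put(item)

# Each (blocking) q.get() call yields one element; iteration ends when
# the sentinel comes back.
assert list(iter(q.get, SENTINEL)) == ["a", "b"]
```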

```diff
@@ -1042,11 +1068,11 @@ def _run_transforming_worker(worker_id: int):
     transforming_worker_threads = [
         threading.Thread(
             target=_run_transforming_worker,
-            name=f"map_tp_transforming_worker-{gen_id}-{worker_idx}",
-            args=(worker_idx,),
+            name=f"map_tp_transforming_worker-{gen_id}-{idx}",
+            args=(input_queues[idx % num_input_queues], output_queues[idx]),
             daemon=True,
         )
-        for worker_idx in range(num_workers)
+        for idx in range(num_workers)
     ]
 
     for t in transforming_worker_threads:
@@ -1071,7 +1097,6 @@ def _run_transforming_worker(worker_id: int):
     # order and one single element is dequeued (in a blocking way!) at a
     # time from every individual output queue
     #
-    non_empty_queues = []
     empty_queues = []
 
     # At every iteration only remaining non-empty queues
@@ -1086,10 +1111,12 @@ def _run_transforming_worker(worker_id: int):
             if item is SENTINEL:
                 empty_queues.append(output_queue)
             else:
-                non_empty_queues.append(output_queue)
                 yield item
 
-        remaining_output_queues = non_empty_queues
+        if empty_queues:
+            remaining_output_queues = [
+                q for q in remaining_output_queues if q not in empty_queues
+            ]
 
     finally:
         # Set flag to interrupt workers (to make sure no dangling
```
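
A simplified, runnable sketch of the consumer loop in the hunk above (`drain_round_robin` is an illustrative name, and threads/interrupts are elided): queues are visited in a fixed round-robin order, one blocking `get` per queue per pass, and a queue is dropped once its worker's sentinel arrives.

```python
import queue

SENTINEL = object()


def drain_round_robin(output_queues):
    remaining = list(output_queues)
    while remaining:
        empty = []
        for q in remaining:
            item = q.get()  # blocking, matching the fixed per-queue order
            if item is SENTINEL:
                empty.append(q)
            else:
                yield item
        if empty:
            remaining = [q for q in remaining if q not in empty]


q1, q2 = queue.Queue(), queue.Queue()
for v in (1, 3, SENTINEL):
    q1.put(v)
for v in (2, SENTINEL):
    q2.put(v)

# Round-robin over [q1, q2] interleaves their contents deterministically.
assert list(drain_round_robin([q1, q2])) == [1, 2, 3]
```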

python/ray/data/datasource/file_based_datasource.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -286,6 +286,8 @@ def read_task_fn():
                 iter(read_paths),
                 read_files,
                 num_workers=num_threads,
+                preserve_ordering=True,
+                buffer_size=max(len(read_paths) // num_threads, 1),
             )
         else:
             logger.debug(f"Reading {len(read_paths)} files.")
```
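
A worked example of the `buffer_size` arithmetic above (values illustrative): with `preserve_ordering=True` the input queues are uncapped, so this value caps each worker's output queue, and the floor of 1 guards the case where there are fewer paths than threads.

```python
# 100 paths across 8 read threads -> each output queue buffers up to 12 items.
read_paths = [f"s3://bucket/file-{i}" for i in range(100)]
num_threads = 8

buffer_size = max(len(read_paths) // num_threads, 1)
assert buffer_size == 12

# Fewer paths than threads -> the floor keeps the buffer at 1.
assert max(3 // num_threads, 1) == 1
```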
