Support map-reduce pattern in concurrency utils #17

@adivekar-utexas

Description

Right now, a lot of the orchestration around dispatch_executor and dispatch has to be done manually, even though many parallel-processing use-cases fit into a common pattern:

  1. Get a sequence of data
  2. Process each item in the sequence
  3. (Optional) Aggregate the results.

This is essentially MapReduce (where the Reduce step is optional). We should have a way to batch-submit and retrieve results iteratively.
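
For context, the manual orchestration today looks roughly like the sketch below. This is a rough illustration using the stdlib concurrent.futures as a stand-in, since the exact dispatch_executor call signature isn't reproduced in this issue; process_query_df and retrieval_dataset are borrowed from the examples further down:

>>> from concurrent.futures import ProcessPoolExecutor, as_completed
>>>
>>> ## 1. Get a sequence of data:
>>> groups = retrieval_dataset.groupby("query_id")
>>>
>>> ## 2. Process each item, managing the pool and futures by hand:
>>> with ProcessPoolExecutor(max_workers=20) as executor:
>>>     futures = [
>>>         executor.submit(process_query_df, query_id, query_df)
>>>         for query_id, query_df in groups
>>>     ]
>>>     results = [fut.result() for fut in as_completed(futures)]
>>>
>>> ## 3. (Optional) Aggregate:
>>> total = sum(results)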

Some version of this already exists in dispatch_apply, but it is not fully fleshed out. Instead, we want an interface that can be used a bit like this:

Example Usage (parallel map only):

>>> def process_query_df(query_id, query_df):
>>>     query_df: pd.DataFrame = set_ranks(query_df, sort_col="example_id")
>>>     query_df['product_text'] = query_df['product_text'].apply(clean_text)
>>>     return query_df['product_text'].apply(len).mean()
>>>
>>> for mean_query_doc_lens in map_reduce(
>>>     retrieval_dataset.groupby("query_id"),
>>>     fn=process_query_df,
>>>     parallelize='processes',
>>>     max_workers=20,
>>>     pbar=dict(miniters=1),
>>>     batch_size=30,
>>>     iter=True,
>>> ):
>>>     ## Prints the output of each call to process_query_df, which is the mean length of
>>>     ## product_text for each query_df:
>>>     print(mean_query_doc_lens)
>>> 1171.090909090909
>>> 1317.7931034482758
>>> 2051.945945945946
>>> 1249.9375
>>> ...

Example Usage (parallel map and reduce):

>>> def process_query_df(query_id, query_df):
>>>     query_df: pd.DataFrame = set_ranks(query_df, sort_col="example_id")
>>>     query_df['product_text'] = query_df['product_text'].apply(clean_text)
>>>     return query_df['product_text'].apply(len).sum()
>>>
>>> def reduce_query_df(l):
>>>     ## Applied to every batch of outputs from process_query_df
>>>     ## and then again to the final list of reduced outputs:
>>>     return sum(l)
>>>
>>> print(map_reduce(
>>>     retrieval_dataset.groupby("query_id"),
>>>     fn=process_query_df,
>>>     parallelize='processes',
>>>     max_workers=20,
>>>     pbar=True,
>>>     batch_size=30,
>>>     reduce_fn=reduce_query_df,
>>> ))
>>> 374453878
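
For concreteness, here is a minimal sketch of how map_reduce could be built on top of the stdlib concurrent.futures. It only covers parallelize='processes', ignores pbar, and skips the real dispatch_executor integration; the per-batch reduce semantics follow the comment in reduce_query_df above. This is illustrative, not a final implementation:

>>> from concurrent.futures import ProcessPoolExecutor
>>>
>>> def map_reduce(data, fn, parallelize='processes', max_workers=20,
>>>                batch_size=30, pbar=False, reduce_fn=None, iter=False):
>>>     ## Sketch only: assumes parallelize='processes' and ignores pbar.
>>>     def batches():
>>>         with ProcessPoolExecutor(max_workers=max_workers) as executor:
>>>             batch = []
>>>             for item in data:
>>>                 ## groupby yields (key, group) tuples; unpack into fn:
>>>                 batch.append(executor.submit(fn, *item))
>>>                 if len(batch) == batch_size:
>>>                     yield [fut.result() for fut in batch]
>>>                     batch = []
>>>             if batch:
>>>                 yield [fut.result() for fut in batch]
>>>     if reduce_fn is None:
>>>         flat = (out for b in batches() for out in b)
>>>         ## iter=True lazily yields each map output as its batch completes:
>>>         return flat if iter else list(flat)
>>>     ## reduce_fn is applied to every batch of outputs, then once more to
>>>     ## the list of per-batch results:
>>>     return reduce_fn([reduce_fn(b) for b in batches()])

Retrieving results batch-by-batch is what keeps memory bounded on large inputs, and it is also what enables the iter=True streaming mode shown in the first example.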
