
Commit 0807fd0

#84 and #91 refactoring event loop policy
Parent: e3e93e0

File tree: 2 files changed (+31, -19 lines)


bulk_chain/api.py (29 additions, 15 deletions)
@@ -15,11 +15,15 @@


 INFER_MODES = {
-    "single": lambda llm, batch: [llm.ask(prompt) for prompt in batch],
-    "single_stream": lambda llm, batch: [llm.ask_stream(prompt) for prompt in batch],
-    "batch": lambda llm, batch: llm.ask(batch),
-    "batch_async": lambda llm, batch: AsyncioService.run_tasks(batch=batch, async_handler=llm.ask_async),
-    "batch_stream_async": lambda llm, batch: AsyncioService.run_tasks(batch=batch, async_handler=llm.ask_stream_async),
+    "single": lambda llm, batch, **kwargs: [llm.ask(prompt) for prompt in batch],
+    "single_stream": lambda llm, batch, **kwargs: [llm.ask_stream(prompt) for prompt in batch],
+    "batch": lambda llm, batch, **kwargs: llm.ask(batch),
+    "batch_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
+        batch=batch, async_handler=llm.ask_async, event_loop=kwargs.get("event_loop")
+    ),
+    "batch_stream_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
+        batch=batch, async_handler=llm.ask_stream_async, event_loop=kwargs.get("event_loop")
+    ),
 }


@@ -35,7 +39,7 @@ def _iter_batch_prompts(c, batch_content_it, **kwargs):
         yield ind_in_batch, content


-def __handle_agen_to_gen(gen, **kwargs):
+def __handle_agen_to_gen(handle, batch, event_loop):
     """ This handler provides conversion of the async generator to generator (sync).
     """

@@ -45,16 +49,17 @@ async def wrapper(index, agen):
                 yield index, item
         return [wrapper(i, agen) for i, agen in enumerate(async_gens)]

+    agen_list = handle(batch, event_loop=event_loop)
+
     it = AsyncioService.async_gen_to_iter(
-        gen=AsyncioService.merge_generators(*__wrap_with_index(gen)),
-        loop=asyncio.get_event_loop()
-    )
+        gen=AsyncioService.merge_generators(*__wrap_with_index(agen_list)),
+        loop=event_loop)

     for ind_in_batch, chunk in it:
         yield ind_in_batch, str(chunk)


-def __handle_gen(gen, **kwargs):
+def __handle_gen(handle, batch, event_loop):
     """ This handler deals with the iteration of each individual element of the batch.
     """

@@ -67,15 +72,16 @@ def _iter_entry_content(entry):
         else:
             raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")

-    for ind_in_batch, entry in enumerate(gen):
+    for ind_in_batch, entry in enumerate(handle(batch, event_loop=event_loop)):
         for chunk in _iter_entry_content(entry=entry):
             yield ind_in_batch, chunk


 def _iter_chunks(p_column, batch_content_it, **kwargs):
     handler = __handle_agen_to_gen if kwargs["infer_mode"] == "batch_stream_async" else __handle_gen
     p_batch = [item[p_column] for item in batch_content_it]
-    for ind_in_batch, chunk in handler(kwargs["handle_batch_func"](p_batch), **kwargs):
+    it = handler(handle=kwargs["handle_batch_func"], batch=p_batch, event_loop=kwargs["event_loop"])
+    for ind_in_batch, chunk in it:
         yield ind_in_batch, chunk
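The handlers above turn async streaming generators into a plain synchronous iterator driven by a caller-supplied loop. Below is a minimal standalone sketch of that conversion pattern using only the standard library; the names fake_stream and agen_to_iter are illustrative and are not part of bulk_chain or AsyncioService.

# Sketch: drive an async generator on an explicit event loop, yielding items synchronously.
# Illustrative only; not the library's implementation.
import asyncio


async def fake_stream(prompt):
    # Stand-in for an LLM streaming handler such as ask_stream_async.
    for chunk in (prompt, " ...", " done"):
        await asyncio.sleep(0)
        yield chunk


def agen_to_iter(agen, loop):
    # Pull items one by one; each __anext__() call is run on the supplied loop.
    while True:
        try:
            yield loop.run_until_complete(agen.__anext__())
        except StopAsyncIteration:
            break


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    try:
        for chunk in agen_to_iter(fake_stream("hello"), loop):
            print(chunk)
    finally:
        loop.close()

In the diff itself, AsyncioService.merge_generators and async_gen_to_iter play this role for a whole batch of streams at once, with each chunk tagged by its index in the batch.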

@@ -124,7 +130,8 @@ def _infer_batch(batch, batch_ind, schema, return_mode, cols=None, **kwargs):


 def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
-                 infer_mode="batch", return_mode="batch", attempts=1, **kwargs):
+                 infer_mode="batch", return_mode="batch", attempts=1, event_loop=None,
+                 **kwargs):
     """ This method represent Python API aimed at application of `llm` towards
         iterator of input_dicts via cache_target that refers to the SQLite using
         the given `schema`
@@ -133,6 +140,10 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
     assert (return_mode in ["batch", "chunk", "record"])
     assert (isinstance(llm, BaseLM))

+    # Setup event loop.
+    event_loop = asyncio.get_event_loop_policy().get_event_loop() \
+        if event_loop is None else event_loop
+
     # Quick initialization of the schema.
     if isinstance(schema, str):
         schema = JsonService.read(schema)
@@ -144,8 +155,10 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
         input_dicts_it
     )

-    handle_batch_func = lambda batch: INFER_MODES[infer_mode](
-        llm, DataService.limit_prompts(batch, limit=limit_prompt)
+    handle_batch_func = lambda batch, **handle_kwargs: INFER_MODES[infer_mode](
+        llm,
+        DataService.limit_prompts(batch, limit=limit_prompt),
+        **handle_kwargs
     )

     # Optional wrapping into attempts.
@@ -166,6 +179,7 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
             handle_missed_value_func=lambda *_: None,
             return_mode=return_mode,
             schema=schema,
+            event_loop=event_loop,
             **kwargs)
         for batch_ind, batch in enumerate(BatchIterator(prompts_it, batch_size=batch_size)))
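Taken together, the api.py changes resolve one event loop up front in iter_content and thread it through **kwargs into whichever INFER_MODES handler is selected. A rough sketch of that wiring under the same fallback rule (the policy's loop when the caller passes nothing); MODES, _gather and run() below are stand-ins, not bulk_chain code.

# Sketch of the wiring introduced here: resolve a loop once, pass it via **kwargs.
import asyncio


async def _gather(batch):
    # Stand-in for an async batch handler: one coroutine per item, results in order.
    async def one(item):
        await asyncio.sleep(0)
        return item.upper()
    return await asyncio.gather(*(one(item) for item in batch))


MODES = {
    "batch": lambda batch, **kwargs: [item.upper() for item in batch],
    "batch_async": lambda batch, **kwargs: kwargs.get("event_loop").run_until_complete(_gather(batch)),
}


def run(batch, mode, event_loop=None):
    # Same fallback as iter_content: use the policy's loop when none is supplied.
    loop = asyncio.get_event_loop_policy().get_event_loop() if event_loop is None else event_loop
    return MODES[mode](batch, event_loop=loop)


if __name__ == "__main__":
    own_loop = asyncio.new_event_loop()
    try:
        print(run(["a", "b"], mode="batch_async", event_loop=own_loop))  # ['A', 'B']
    finally:
        own_loop.close()

The kwargs.get("event_loop") lookup keeps the synchronous modes unaffected: they accept the extra keyword and simply ignore it.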

bulk_chain/core/service_asyncio.py (2 additions, 4 deletions)
@@ -19,10 +19,8 @@ async def _run_generator(gen, output_queue, idx):


     @staticmethod
-    def run_tasks(**tasks_kwargs):
-        return asyncio.get_event_loop().run_until_complete(
-            AsyncioService._run_tasks_async(**tasks_kwargs)
-        )
+    def run_tasks(event_loop, **tasks_kwargs):
+        return event_loop.run_until_complete(AsyncioService._run_tasks_async(**tasks_kwargs))

     @staticmethod
     async def merge_generators(*gens: AsyncGenerator[Any, None]) -> AsyncGenerator[Any, None]:
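With this change the helper no longer reaches for asyncio.get_event_loop() itself; the caller owns the loop and injects it. A small sketch of that calling convention; _run_tasks_async below is a stand-in that only reuses the keyword arguments visible in the diff (batch, async_handler) and is not the library's implementation.

# Sketch of the injected-loop convention; stand-in code, not service_asyncio.py.
import asyncio


async def _run_tasks_async(batch, async_handler):
    # Run the handler once per batch item and gather results in batch order.
    return await asyncio.gather(*(async_handler(item) for item in batch))


def run_tasks(event_loop, **tasks_kwargs):
    # Mirrors the refactored signature: the loop comes from the caller.
    return event_loop.run_until_complete(_run_tasks_async(**tasks_kwargs))


async def echo(prompt):
    await asyncio.sleep(0)
    return f"answer: {prompt}"


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    try:
        print(run_tasks(loop, batch=["q1", "q2"], async_handler=echo))
    finally:
        loop.close()

Passing the loop explicitly avoids relying on the implicit current-loop lookup, whose behavior differs across Python versions and threads.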

0 commit comments
