INFER_MODES = {
-    "single": lambda llm, batch: [llm.ask(prompt) for prompt in batch],
-    "single_stream": lambda llm, batch: [llm.ask_stream(prompt) for prompt in batch],
-    "batch": lambda llm, batch: llm.ask(batch),
-    "batch_async": lambda llm, batch: AsyncioService.run_tasks(batch=batch, async_handler=llm.ask_async),
-    "batch_stream_async": lambda llm, batch: AsyncioService.run_tasks(batch=batch, async_handler=llm.ask_stream_async),
+    "single": lambda llm, batch, **kwargs: [llm.ask(prompt) for prompt in batch],
+    "single_stream": lambda llm, batch, **kwargs: [llm.ask_stream(prompt) for prompt in batch],
+    "batch": lambda llm, batch, **kwargs: llm.ask(batch),
+    "batch_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
+        batch=batch, async_handler=llm.ask_async, event_loop=kwargs.get("event_loop")
+    ),
+    "batch_stream_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
+        batch=batch, async_handler=llm.ask_stream_async, event_loop=kwargs.get("event_loop")
+    ),
}
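With every `INFER_MODES` entry accepting `**kwargs`, the async modes can pick up an `event_loop` passed down from `iter_content`. Below is a minimal sketch of that calling convention; `run_tasks` and `EchoLLM` are illustrative stand-ins, not the project's `AsyncioService.run_tasks` or `BaseLM` implementations.

```python
import asyncio

def run_tasks(batch, async_handler, event_loop=None):
    # Stand-in for an AsyncioService.run_tasks-style helper: run one coroutine
    # per prompt on an explicit loop (or a temporary one when none is given).
    loop = event_loop or asyncio.new_event_loop()

    async def _gather_all():
        return await asyncio.gather(*(async_handler(p) for p in batch))

    try:
        return loop.run_until_complete(_gather_all())
    finally:
        if event_loop is None:
            loop.close()

class EchoLLM:
    # Dummy model exposing the ask/ask_async surface assumed by the lambdas.
    def ask(self, prompt):
        return f"echo: {prompt}"

    async def ask_async(self, prompt):
        return f"echo: {prompt}"

MODES = {
    "single": lambda llm, batch, **kwargs: [llm.ask(p) for p in batch],
    "batch_async": lambda llm, batch, **kwargs: run_tasks(
        batch=batch, async_handler=llm.ask_async, event_loop=kwargs.get("event_loop")),
}

loop = asyncio.new_event_loop()
print(MODES["batch_async"](EchoLLM(), ["a", "b"], event_loop=loop))  # ['echo: a', 'echo: b']
loop.close()
```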
@@ -35,7 +39,7 @@ def _iter_batch_prompts(c, batch_content_it, **kwargs):
        yield ind_in_batch, content


-def __handle_agen_to_gen(gen, **kwargs):
+def __handle_agen_to_gen(handle, batch, event_loop):
    """ This handler provides conversion of the async generator to generator (sync).
    """

@@ -45,16 +49,17 @@ async def wrapper(index, agen):
                yield index, item
        return [wrapper(i, agen) for i, agen in enumerate(async_gens)]

+    agen_list = handle(batch, event_loop=event_loop)
+
    it = AsyncioService.async_gen_to_iter(
-        gen=AsyncioService.merge_generators(*__wrap_with_index(gen)),
-        loop=asyncio.get_event_loop()
-    )
+        gen=AsyncioService.merge_generators(*__wrap_with_index(agen_list)),
+        loop=event_loop)

    for ind_in_batch, chunk in it:
        yield ind_in_batch, str(chunk)


-def __handle_gen(gen, **kwargs):
+def __handle_gen(handle, batch, event_loop):
    """ This handler deals with the iteration of each individual element of the batch.
    """

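The hunk above stops calling `asyncio.get_event_loop()` inside the conversion helper and uses the injected `event_loop` instead. The snippet below sketches the underlying async-generator-to-sync-iterator pattern with the standard library only; `agen_to_iter` is an assumed name, not the project's `AsyncioService.async_gen_to_iter`.

```python
import asyncio

def agen_to_iter(agen, loop):
    # Pull items out of an async generator on an explicit loop and yield them
    # synchronously; illustrative sketch of the conversion pattern.
    try:
        while True:
            yield loop.run_until_complete(agen.__anext__())
    except StopAsyncIteration:
        return

async def countdown(n):
    for i in range(n, 0, -1):
        yield i

loop = asyncio.new_event_loop()
print(list(agen_to_iter(countdown(3), loop)))  # [3, 2, 1]
loop.close()
```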
@@ -67,15 +72,16 @@ def _iter_entry_content(entry):
        else:
            raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")

-    for ind_in_batch, entry in enumerate(gen):
+    for ind_in_batch, entry in enumerate(handle(batch, event_loop=event_loop)):
        for chunk in _iter_entry_content(entry=entry):
            yield ind_in_batch, chunk


def _iter_chunks(p_column, batch_content_it, **kwargs):
    handler = __handle_agen_to_gen if kwargs["infer_mode"] == "batch_stream_async" else __handle_gen
    p_batch = [item[p_column] for item in batch_content_it]
-    for ind_in_batch, chunk in handler(kwargs["handle_batch_func"](p_batch), **kwargs):
+    it = handler(handle=kwargs["handle_batch_func"], batch=p_batch, event_loop=kwargs["event_loop"])
+    for ind_in_batch, chunk in it:
        yield ind_in_batch, chunk

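`_iter_chunks` now hands `handle_batch_func` itself to the chosen handler together with the batch and the loop, instead of pre-applying it to the batch. A compact sketch of that dispatch with stand-in names (`handle_gen`, `iter_chunks`, `demo_handle` are illustrative):

```python
def handle_gen(handle, batch, event_loop):
    # Both handlers now share the (handle, batch, event_loop) signature, so the
    # caller no longer has to invoke handle(batch) up front.
    for ind_in_batch, entry in enumerate(handle(batch, event_loop=event_loop)):
        yield ind_in_batch, str(entry)

def iter_chunks(p_batch, handle_batch_func, event_loop=None):
    # In the real code the handler is chosen by infer_mode; here we always use
    # the synchronous one to keep the sketch self-contained.
    it = handle_gen(handle=handle_batch_func, batch=p_batch, event_loop=event_loop)
    yield from it

demo_handle = lambda batch, event_loop=None: [p.upper() for p in batch]
print(list(iter_chunks(["a", "b"], demo_handle)))  # [(0, 'A'), (1, 'B')]
```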
@@ -124,7 +130,8 @@ def _infer_batch(batch, batch_ind, schema, return_mode, cols=None, **kwargs):

def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
-                 infer_mode="batch", return_mode="batch", attempts=1, **kwargs):
+                 infer_mode="batch", return_mode="batch", attempts=1, event_loop=None,
+                 **kwargs):
    """ This method represents a Python API aimed at application of `llm` towards
        an iterator of input_dicts via cache_target that refers to the SQLite using
        the given `schema`
@@ -133,6 +140,10 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
    assert (return_mode in ["batch", "chunk", "record"])
    assert (isinstance(llm, BaseLM))

+    # Setup event loop.
+    event_loop = asyncio.get_event_loop_policy().get_event_loop() \
+        if event_loop is None else event_loop
+
    # Quick initialization of the schema.
    if isinstance(schema, str):
        schema = JsonService.read(schema)
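The new default only consults the event-loop policy when the caller has not supplied a loop. A small standalone check of that fallback (the `resolve_loop` name is an assumption; on recent Python versions the policy call may emit a deprecation warning if it has to create a loop):

```python
import asyncio

def resolve_loop(event_loop=None):
    # Mirror the fallback added in the diff: an explicit loop wins, otherwise
    # ask the current policy for one.
    return asyncio.get_event_loop_policy().get_event_loop() \
        if event_loop is None else event_loop

own_loop = asyncio.new_event_loop()
assert resolve_loop(own_loop) is own_loop                     # explicit loop is reused
assert isinstance(resolve_loop(), asyncio.AbstractEventLoop)  # policy fallback
own_loop.close()
```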
@@ -144,8 +155,10 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
        input_dicts_it
    )

-    handle_batch_func = lambda batch: INFER_MODES[infer_mode](
-        llm, DataService.limit_prompts(batch, limit=limit_prompt)
+    handle_batch_func = lambda batch, **handle_kwargs: INFER_MODES[infer_mode](
+        llm,
+        DataService.limit_prompts(batch, limit=limit_prompt),
+        **handle_kwargs
    )

    # Optional wrapping into attempts.
@@ -166,6 +179,7 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
            handle_missed_value_func=lambda *_: None,
            return_mode=return_mode,
            schema=schema,
+            event_loop=event_loop,
            **kwargs)
        for batch_ind, batch in enumerate(BatchIterator(prompts_it, batch_size=batch_size)))

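End to end, the patch lets the caller own the loop that the async infer modes run on, e.g. `iter_content(..., infer_mode="batch_stream_async", event_loop=loop)`. Since `iter_content` and `BaseLM` are not reproduced here, the runnable sketch below mimics only the caller-owned-loop pattern with stand-ins (`consume_with_loop`, `fake_iter_content` are assumed names):

```python
import asyncio

def consume_with_loop(gen_factory, event_loop=None):
    # Caller-owned loop pattern: create the loop once, hand it to the
    # generator-producing API, and close it only if we created it ourselves.
    loop = event_loop or asyncio.new_event_loop()
    try:
        yield from gen_factory(event_loop=loop)
    finally:
        if event_loop is None:
            loop.close()

async def shout(prompt):
    return prompt.upper()

def fake_iter_content(event_loop):
    # Stand-in for iter_content(..., event_loop=...): run one async call per
    # prompt on the supplied loop and yield results synchronously.
    for prompt in ["hello", "world"]:
        yield event_loop.run_until_complete(shout(prompt))

print(list(consume_with_loop(fake_iter_content)))  # ['HELLO', 'WORLD']
```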