Skip to content

Commit 4ecb6b8

Browse files
committed
better
1 parent 263fc28 commit 4ecb6b8

File tree

2 files changed

+22
-25
lines changed

2 files changed

+22
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.0.6
22

3-
* **Support streaming response types for /invoke**
3+
* **Support streaming response types for /invoke if callable is async generator**
44

55
## 0.0.5
66

unstructured_platform_plugins/etl_uvicorn/api_generator.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,15 @@ def log_func_and_body(func: Callable, body: Optional[str] = None) -> None:
4747
logger.log(level=logger.level, msg=msg)
4848

4949

50-
async def run_generator_in_executor(generator_func, **kwargs):
51-
loop = asyncio.get_event_loop()
52-
# Create a future that will yield the generator values one by one
53-
gen = generator_func(**kwargs)
54-
while True:
55-
result = await loop.run_in_executor(None, next, gen, None)
56-
if result is None:
57-
break
58-
yield result
50+
async def invoke_async_gen_func(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> Any:
51+
kwargs = kwargs or {}
52+
async for val in func(**kwargs):
53+
yield val
5954

6055

6156
async def invoke_func(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> Any:
6257
kwargs = kwargs or {}
63-
if inspect.isasyncgenfunction(func):
64-
async for val in func(**kwargs):
65-
yield val
66-
elif inspect.isgeneratorfunction(func):
67-
async for val in run_generator_in_executor(func, **kwargs):
68-
yield val
69-
elif inspect.iscoroutinefunction(func):
58+
if inspect.iscoroutinefunction(func):
7059
yield await func(**kwargs)
7160
else:
7261
yield await asyncio.get_event_loop().run_in_executor(None, partial(func, **kwargs))
@@ -136,14 +125,22 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
136125
else:
137126
logger.warning("usage data not an expected parameter, omitting")
138127
try:
139-
140-
async def _stream_response():
141-
async for output in invoke_func(func=func, kwargs=request_dict):
142-
yield InvokeResponse(
143-
usage=usage, status_code=status.HTTP_200_OK, output=output
144-
).model_dump_json() + "\n"
145-
146-
return StreamingResponse(_stream_response(), media_type="application/x-ndjson")
128+
if inspect.isasyncgenfunction(func):
129+
# Stream response if function is an async generator
130+
131+
async def _stream_response():
132+
async for output in invoke_async_gen_func(func=func, kwargs=request_dict):
133+
yield InvokeResponse(
134+
usage=usage, status_code=status.HTTP_200_OK, output=output
135+
).model_dump_json() + "\n"
136+
137+
return StreamingResponse(_stream_response(), media_type="application/x-ndjson")
138+
else:
139+
return InvokeResponse(
140+
usage=usage,
141+
status_code=status.HTTP_200_OK,
142+
output=await invoke_func(func, request_dict),
143+
)
147144
except Exception as invoke_error:
148145
logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True)
149146
return InvokeResponse(

0 commit comments

Comments
 (0)