Skip to content

Commit 2391c6f

Browse files
committed
streaming invoke responses poc
1 parent 7afdef3 commit 2391c6f

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

unstructured_platform_plugins/etl_uvicorn/api_generator.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
2+
from fastapi.responses import StreamingResponse
23
import hashlib
34
import inspect
45
import json
56
import logging
67
from functools import partial
78
from typing import Any, Callable, Optional
9+
import inspect
810

911
from fastapi import FastAPI, status
1012
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@@ -26,6 +28,7 @@
2628
schema_to_base_model,
2729
)
2830
from unstructured_platform_plugins.schema.usage import UsageData
31+
import concurrent.futures
2932

3033
logger = logging.getLogger("uvicorn.error")
3134

@@ -46,12 +49,29 @@ def log_func_and_body(func: Callable, body: Optional[str] = None) -> None:
4649
logger.log(level=logger.level, msg=msg)
4750

4851

52+
async def run_generator_in_executor(generator_func, **kwargs):
53+
loop = asyncio.get_event_loop()
54+
# Create a future that will yield the generator values one by one
55+
gen = generator_func(**kwargs)
56+
while True:
57+
result = await loop.run_in_executor(None, next, gen, None)
58+
if result is None:
59+
break
60+
yield result
61+
62+
4963
async def invoke_func(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> Any:
5064
kwargs = kwargs or {}
51-
if inspect.iscoroutinefunction(func):
52-
return await func(**kwargs)
65+
if inspect.isasyncgenfunction(func):
66+
async for val in func(**kwargs):
67+
yield val
68+
elif inspect.isgeneratorfunction(func):
69+
async for val in run_generator_in_executor(func, **kwargs):
70+
yield val
71+
elif inspect.iscoroutinefunction(func):
72+
yield await func(**kwargs)
5373
else:
54-
return await asyncio.get_event_loop().run_in_executor(None, partial(func, **kwargs))
74+
yield await asyncio.get_event_loop().run_in_executor(None, partial(func, **kwargs))
5575

5676

5777
def check_precheck_func(precheck_func: Callable):
@@ -95,7 +115,7 @@ def generate_fast_api(
95115

96116
logger.debug(f"set static id response to: {plugin_id}")
97117

98-
fastapi_app = FastAPI()
118+
fastapi_app = FastAPI(debug=True)
99119

100120
response_type = get_output_sig(func)
101121

@@ -118,8 +138,14 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
118138
else:
119139
logger.warning("usage data not an expected parameter, omitting")
120140
try:
121-
output = await invoke_func(func=func, kwargs=request_dict)
122-
return InvokeResponse(usage=usage, status_code=status.HTTP_200_OK, output=output)
141+
142+
async def _stream_response():
143+
async for output in invoke_func(func=func, kwargs=request_dict):
144+
yield InvokeResponse(
145+
usage=usage, status_code=status.HTTP_200_OK, output=output
146+
).model_dump_json() + "\n"
147+
148+
return StreamingResponse(_stream_response(), media_type="application/x-ndjson")
123149
except Exception as invoke_error:
124150
logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True)
125151
return InvokeResponse(

0 commit comments

Comments
 (0)