Skip to content

Commit e4f94a7

Browse files
authored
Support streaming invoke responses (#23)
1 parent 7afdef3 commit e4f94a7

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.0.6
2+
3+
* **Support streaming response types for /invoke if callable is async generator**
4+
15
## 0.0.5
26

37
* **Improve logging to hide body in case of sensitive data unless TRACE level**
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.5" # pragma: no cover
1+
__version__ = "0.0.6" # pragma: no cover

unstructured_platform_plugins/etl_uvicorn/api_generator.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Callable, Optional
88

99
from fastapi import FastAPI, status
10+
from fastapi.responses import StreamingResponse
1011
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
1112
from pydantic import BaseModel
1213
from starlette.responses import RedirectResponse
@@ -110,16 +111,29 @@ class InvokeResponse(BaseModel):
110111

111112
logging.getLogger("etl_uvicorn.fastapi")
112113

113-
async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> InvokeResponse:
114+
ResponseType = StreamingResponse if inspect.isasyncgenfunction(func) else InvokeResponse
115+
116+
async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> ResponseType:
114117
usage: list[UsageData] = []
115118
request_dict = kwargs if kwargs else {}
116119
if "usage" in inspect.signature(func).parameters:
117120
request_dict["usage"] = usage
118121
else:
119122
logger.warning("usage data not an expected parameter, omitting")
120123
try:
121-
output = await invoke_func(func=func, kwargs=request_dict)
122-
return InvokeResponse(usage=usage, status_code=status.HTTP_200_OK, output=output)
124+
if inspect.isasyncgenfunction(func):
125+
# Stream response if function is an async generator
126+
127+
async def _stream_response():
128+
async for output in func(**(request_dict or {})):
129+
yield InvokeResponse(
130+
usage=usage, status_code=status.HTTP_200_OK, output=output
131+
).model_dump_json() + "\n"
132+
133+
return StreamingResponse(_stream_response(), media_type="application/x-ndjson")
134+
else:
135+
output = await invoke_func(func=func, kwargs=request_dict)
136+
return InvokeResponse(usage=usage, status_code=status.HTTP_200_OK, output=output)
123137
except Exception as invoke_error:
124138
logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True)
125139
return InvokeResponse(
@@ -132,7 +146,7 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
132146
if input_schema_model.model_fields:
133147

134148
@fastapi_app.post("/invoke", response_model=InvokeResponse)
135-
async def run_job(request: input_schema_model) -> InvokeResponse:
149+
async def run_job(request: input_schema_model) -> ResponseType:
136150
log_func_and_body(func=func, body=request.json())
137151
# Create dictionary from pydantic model while preserving underlying types
138152
request_dict = {f: getattr(request, f) for f in request.model_fields}
@@ -144,7 +158,7 @@ async def run_job(request: input_schema_model) -> InvokeResponse:
144158
else:
145159

146160
@fastapi_app.post("/invoke", response_model=InvokeResponse)
147-
async def run_job() -> InvokeResponse:
161+
async def run_job() -> ResponseType:
148162
log_func_and_body(func=func)
149163
return await wrap_fn(
150164
func=func,

0 commit comments

Comments
 (0)