7
7
from typing import Any , Callable , Optional
8
8
9
9
from fastapi import FastAPI , status
10
+ from fastapi .responses import StreamingResponse
10
11
from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
11
12
from pydantic import BaseModel
12
13
from starlette .responses import RedirectResponse
@@ -110,16 +111,29 @@ class InvokeResponse(BaseModel):
110
111
111
112
logging .getLogger ("etl_uvicorn.fastapi" )
112
113
113
- async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> InvokeResponse :
114
+ ResponseType = StreamingResponse if inspect .isasyncgenfunction (func ) else InvokeResponse
115
+
116
+ async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> ResponseType :
114
117
usage : list [UsageData ] = []
115
118
request_dict = kwargs if kwargs else {}
116
119
if "usage" in inspect .signature (func ).parameters :
117
120
request_dict ["usage" ] = usage
118
121
else :
119
122
logger .warning ("usage data not an expected parameter, omitting" )
120
123
try :
121
- output = await invoke_func (func = func , kwargs = request_dict )
122
- return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
124
+ if inspect .isasyncgenfunction (func ):
125
+ # Stream response if function is an async generator
126
+
127
+ async def _stream_response ():
128
+ async for output in func (** (request_dict or {})):
129
+ yield InvokeResponse (
130
+ usage = usage , status_code = status .HTTP_200_OK , output = output
131
+ ).model_dump_json () + "\n "
132
+
133
+ return StreamingResponse (_stream_response (), media_type = "application/x-ndjson" )
134
+ else :
135
+ output = await invoke_func (func = func , kwargs = request_dict )
136
+ return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
123
137
except Exception as invoke_error :
124
138
logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
125
139
return InvokeResponse (
@@ -132,7 +146,7 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
132
146
if input_schema_model .model_fields :
133
147
134
148
@fastapi_app .post ("/invoke" , response_model = InvokeResponse )
135
- async def run_job (request : input_schema_model ) -> InvokeResponse :
149
+ async def run_job (request : input_schema_model ) -> ResponseType :
136
150
log_func_and_body (func = func , body = request .json ())
137
151
# Create dictionary from pydantic model while preserving underlying types
138
152
request_dict = {f : getattr (request , f ) for f in request .model_fields }
@@ -144,7 +158,7 @@ async def run_job(request: input_schema_model) -> InvokeResponse:
144
158
else :
145
159
146
160
@fastapi_app .post ("/invoke" , response_model = InvokeResponse )
147
- async def run_job () -> InvokeResponse :
161
+ async def run_job () -> ResponseType :
148
162
log_func_and_body (func = func )
149
163
return await wrap_fn (
150
164
func = func ,
0 commit comments