@@ -111,7 +111,9 @@ class InvokeResponse(BaseModel):
111
111
112
112
logging .getLogger ("etl_uvicorn.fastapi" )
113
113
114
- async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> InvokeResponse :
114
+ ResponseType = StreamingResponse if inspect .isasyncgenfunction (func ) else InvokeResponse
115
+
116
+ async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> ResponseType :
115
117
usage : list [UsageData ] = []
116
118
request_dict = kwargs if kwargs else {}
117
119
if "usage" in inspect .signature (func ).parameters :
@@ -144,7 +146,7 @@ async def _stream_response():
144
146
if input_schema_model .model_fields :
145
147
146
148
@fastapi_app .post ("/invoke" , response_model = InvokeResponse )
147
- async def run_job (request : input_schema_model ) -> InvokeResponse :
149
+ async def run_job (request : input_schema_model ) -> ResponseType :
148
150
log_func_and_body (func = func , body = request .json ())
149
151
# Create dictionary from pydantic model while preserving underlying types
150
152
request_dict = {f : getattr (request , f ) for f in request .model_fields }
@@ -156,7 +158,7 @@ async def run_job(request: input_schema_model) -> InvokeResponse:
156
158
else :
157
159
158
160
@fastapi_app .post ("/invoke" , response_model = InvokeResponse )
159
- async def run_job () -> InvokeResponse :
161
+ async def run_job () -> ResponseType :
160
162
log_func_and_body (func = func )
161
163
return await wrap_fn (
162
164
func = func ,
0 commit comments