1
1
import asyncio
2
+ from fastapi .responses import StreamingResponse
2
3
import hashlib
3
4
import inspect
4
5
import json
5
6
import logging
6
7
from functools import partial
7
8
from typing import Any , Callable , Optional
9
+ import inspect
8
10
9
11
from fastapi import FastAPI , status
10
12
from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
26
28
schema_to_base_model ,
27
29
)
28
30
from unstructured_platform_plugins .schema .usage import UsageData
31
+ import concurrent .futures
29
32
30
33
logger = logging .getLogger ("uvicorn.error" )
31
34
@@ -46,12 +49,29 @@ def log_func_and_body(func: Callable, body: Optional[str] = None) -> None:
46
49
logger .log (level = logger .level , msg = msg )
47
50
48
51
52
+ async def run_generator_in_executor (generator_func , ** kwargs ):
53
+ loop = asyncio .get_event_loop ()
54
+ # Create a future that will yield the generator values one by one
55
+ gen = generator_func (** kwargs )
56
+ while True :
57
+ result = await loop .run_in_executor (None , next , gen , None )
58
+ if result is None :
59
+ break
60
+ yield result
61
+
62
+
49
63
async def invoke_func (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> Any :
50
64
kwargs = kwargs or {}
51
- if inspect .iscoroutinefunction (func ):
52
- return await func (** kwargs )
65
+ if inspect .isasyncgenfunction (func ):
66
+ async for val in func (** kwargs ):
67
+ yield val
68
+ elif inspect .isgeneratorfunction (func ):
69
+ async for val in run_generator_in_executor (func , ** kwargs ):
70
+ yield val
71
+ elif inspect .iscoroutinefunction (func ):
72
+ yield await func (** kwargs )
53
73
else :
54
- return await asyncio .get_event_loop ().run_in_executor (None , partial (func , ** kwargs ))
74
+ yield await asyncio .get_event_loop ().run_in_executor (None , partial (func , ** kwargs ))
55
75
56
76
57
77
def check_precheck_func (precheck_func : Callable ):
@@ -95,7 +115,7 @@ def generate_fast_api(
95
115
96
116
logger .debug (f"set static id response to: { plugin_id } " )
97
117
98
- fastapi_app = FastAPI ()
118
+ fastapi_app = FastAPI (debug = True )
99
119
100
120
response_type = get_output_sig (func )
101
121
@@ -118,8 +138,14 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
118
138
else :
119
139
logger .warning ("usage data not an expected parameter, omitting" )
120
140
try :
121
- output = await invoke_func (func = func , kwargs = request_dict )
122
- return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
141
+
142
+ async def _stream_response ():
143
+ async for output in invoke_func (func = func , kwargs = request_dict ):
144
+ yield InvokeResponse (
145
+ usage = usage , status_code = status .HTTP_200_OK , output = output
146
+ ).model_dump_json () + "\n "
147
+
148
+ return StreamingResponse (_stream_response (), media_type = "application/x-ndjson" )
123
149
except Exception as invoke_error :
124
150
logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
125
151
return InvokeResponse (
0 commit comments