11import asyncio
2+ from fastapi .responses import StreamingResponse
23import hashlib
34import inspect
45import json
56import logging
67from functools import partial
78from typing import Any , Callable , Optional
9+ import inspect
810
911from fastapi import FastAPI , status
1012from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
2628 schema_to_base_model ,
2729)
2830from unstructured_platform_plugins .schema .usage import UsageData
31+ import concurrent .futures
2932
3033logger = logging .getLogger ("uvicorn.error" )
3134
@@ -46,12 +49,29 @@ def log_func_and_body(func: Callable, body: Optional[str] = None) -> None:
4649 logger .log (level = logger .level , msg = msg )
4750
4851
52+ async def run_generator_in_executor (generator_func , ** kwargs ):
53+ loop = asyncio .get_event_loop ()
54+ # Create a future that will yield the generator values one by one
55+ gen = generator_func (** kwargs )
56+ while True :
57+ result = await loop .run_in_executor (None , next , gen , None )
58+ if result is None :
59+ break
60+ yield result
61+
62+
4963async def invoke_func (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> Any :
5064 kwargs = kwargs or {}
51- if inspect .iscoroutinefunction (func ):
52- return await func (** kwargs )
65+ if inspect .isasyncgenfunction (func ):
66+ async for val in func (** kwargs ):
67+ yield val
68+ elif inspect .isgeneratorfunction (func ):
69+ async for val in run_generator_in_executor (func , ** kwargs ):
70+ yield val
71+ elif inspect .iscoroutinefunction (func ):
72+ yield await func (** kwargs )
5373 else :
54- return await asyncio .get_event_loop ().run_in_executor (None , partial (func , ** kwargs ))
74+ yield await asyncio .get_event_loop ().run_in_executor (None , partial (func , ** kwargs ))
5575
5676
5777def check_precheck_func (precheck_func : Callable ):
@@ -95,7 +115,7 @@ def generate_fast_api(
95115
96116 logger .debug (f"set static id response to: { plugin_id } " )
97117
98- fastapi_app = FastAPI ()
118+ fastapi_app = FastAPI (debug = True )
99119
100120 response_type = get_output_sig (func )
101121
@@ -118,8 +138,14 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
118138 else :
119139 logger .warning ("usage data not an expected parameter, omitting" )
120140 try :
121- output = await invoke_func (func = func , kwargs = request_dict )
122- return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
141+
142+ async def _stream_response ():
143+ async for output in invoke_func (func = func , kwargs = request_dict ):
144+ yield InvokeResponse (
145+ usage = usage , status_code = status .HTTP_200_OK , output = output
146+ ).model_dump_json () + "\n "
147+
148+ return StreamingResponse (_stream_response (), media_type = "application/x-ndjson" )
123149 except Exception as invoke_error :
124150 logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
125151 return InvokeResponse (
0 commit comments