4
4
import json
5
5
import logging
6
6
from functools import partial
7
- from typing import Any , Callable , Optional
7
+ from typing import Any , Callable , Optional , Union
8
8
9
9
from fastapi import FastAPI , status
10
10
from fastapi .responses import StreamingResponse
11
11
from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
12
- from pydantic import BaseModel
12
+ from pydantic import BaseModel , Field , create_model
13
13
from starlette .responses import RedirectResponse
14
14
from uvicorn .config import LOG_LEVELS
15
15
from uvicorn .importer import import_from_string
23
23
get_schema_dict ,
24
24
map_inputs ,
25
25
)
26
- from unstructured_platform_plugins .schema import FileDataMeta , UsageData
26
+ from unstructured_platform_plugins .schema import FileDataMeta , NewRecord , UsageData
27
27
from unstructured_platform_plugins .schema .json_schema import (
28
28
schema_to_base_model ,
29
29
)
@@ -67,6 +67,30 @@ def check_precheck_func(precheck_func: Callable):
67
67
raise ValueError (f"no output should exist for precheck function, found: { outputs } " )
68
68
69
69
70
+ def is_optional (t : Any ) -> bool :
71
+ return (
72
+ hasattr (t , "__origin__" )
73
+ and t .__origin__ is Union
74
+ and hasattr (t , "__args__" )
75
+ and type (None ) in t .__args__
76
+ )
77
+
78
+
79
+ def update_filedata_model (new_type ) -> BaseModel :
80
+ field_info = NewRecord .model_fields ["contents" ]
81
+ if is_optional (new_type ):
82
+ field_info .default = None
83
+ new_record_model = create_model (
84
+ NewRecord .__name__ , contents = (new_type , field_info ), __base__ = NewRecord
85
+ )
86
+ new_filedata_model = create_model (
87
+ FileDataMeta .__name__ ,
88
+ new_records = (list [new_record_model ], Field (default_factory = list )),
89
+ __base__ = FileDataMeta ,
90
+ )
91
+ return new_filedata_model
92
+
93
+
70
94
def wrap_in_fastapi (
71
95
func : Callable ,
72
96
plugin_id : str ,
@@ -80,11 +104,12 @@ def wrap_in_fastapi(
80
104
fastapi_app = FastAPI ()
81
105
82
106
response_type = get_output_sig (func )
107
+ filedata_meta_model = update_filedata_model (response_type )
83
108
84
109
class InvokeResponse (BaseModel ):
85
110
usage : list [UsageData ]
86
111
status_code : int
87
- filedata_meta : FileDataMeta
112
+ filedata_meta : filedata_meta_model
88
113
status_code_text : Optional [str ] = None
89
114
output : Optional [response_type ] = None
90
115
@@ -113,7 +138,9 @@ async def _stream_response():
113
138
async for output in func (** (request_dict or {})):
114
139
yield InvokeResponse (
115
140
usage = usage ,
116
- filedata_meta = filedata_meta ,
141
+ filedata_meta = filedata_meta_model .model_validate (
142
+ filedata_meta .model_dump ()
143
+ ),
117
144
status_code = status .HTTP_200_OK ,
118
145
output = output ,
119
146
).model_dump_json () + "\n "
@@ -123,15 +150,15 @@ async def _stream_response():
123
150
output = await invoke_func (func = func , kwargs = request_dict )
124
151
return InvokeResponse (
125
152
usage = usage ,
126
- filedata_meta = filedata_meta ,
153
+ filedata_meta = filedata_meta_model . model_validate ( filedata_meta . model_dump ()) ,
127
154
status_code = status .HTTP_200_OK ,
128
155
output = output ,
129
156
)
130
157
except Exception as invoke_error :
131
158
logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
132
159
return InvokeResponse (
133
160
usage = usage ,
134
- filedata_meta = filedata_meta ,
161
+ filedata_meta = filedata_meta_model . model_validate ( filedata_meta . model_dump ()) ,
135
162
status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
136
163
status_code_text = f"failed to invoke plugin: "
137
164
f"[{ invoke_error .__class__ .__name__ } ] { invoke_error } " ,
0 commit comments