Skip to content

Commit afb84fa

Browse files
committed
Make the NewRecord model's typing more prescriptive instead of using Any for its contents
1 parent a9c4fa6 commit afb84fa

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

unstructured_platform_plugins/etl_uvicorn/api_generator.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import json
55
import logging
66
from functools import partial
7-
from typing import Any, Callable, Optional
7+
from typing import Any, Callable, Optional, Union
88

99
from fastapi import FastAPI, status
1010
from fastapi.responses import StreamingResponse
1111
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
12-
from pydantic import BaseModel
12+
from pydantic import BaseModel, Field, create_model
1313
from starlette.responses import RedirectResponse
1414
from uvicorn.config import LOG_LEVELS
1515
from uvicorn.importer import import_from_string
@@ -23,7 +23,7 @@
2323
get_schema_dict,
2424
map_inputs,
2525
)
26-
from unstructured_platform_plugins.schema import FileDataMeta, UsageData
26+
from unstructured_platform_plugins.schema import FileDataMeta, NewRecord, UsageData
2727
from unstructured_platform_plugins.schema.json_schema import (
2828
schema_to_base_model,
2929
)
@@ -67,6 +67,30 @@ def check_precheck_func(precheck_func: Callable):
6767
raise ValueError(f"no output should exist for precheck function, found: {outputs}")
6868

6969

70+
def is_optional(t: Any) -> bool:
    """Return True when *t* is a union type that includes ``None``.

    Recognises both the ``typing`` spellings (``Optional[X]``,
    ``Union[X, None]``) and, on Python 3.10+, the PEP 604 spelling
    (``X | None``). Anything that is not a union, including plain classes,
    other parameterised generics and ``None`` itself, yields False.

    Args:
        t: The type annotation to inspect.

    Returns:
        True if *t* is a union with ``NoneType`` among its members.
    """
    import types
    from typing import get_args, get_origin

    origin = get_origin(t)
    # ``types.UnionType`` only exists on 3.10+; on older interpreters the
    # sentinel default can never match a real origin.
    pep604_union = getattr(types, "UnionType", Union)
    if origin is not Union and origin is not pep604_union:
        return False
    return type(None) in get_args(t)
77+
78+
79+
def update_filedata_model(new_type: Any) -> type[BaseModel]:
    """Build a ``FileDataMeta`` subclass whose new records have typed contents.

    A ``NewRecord`` subclass is created with its ``contents`` field
    re-declared as *new_type*, and a ``FileDataMeta`` subclass is created
    with ``new_records`` re-declared as a list of that record model.

    Args:
        new_type: The type to use for ``NewRecord.contents``. When it is an
            optional type (per ``is_optional``), the field defaults to None.

    Returns:
        The dynamically created ``FileDataMeta`` subclass (a model class,
        not an instance).
    """
    from copy import copy

    # Work on a copy: the FieldInfo in NewRecord.model_fields belongs to the
    # base model, so setting ``default`` on it directly would change
    # NewRecord itself and every model derived from it afterwards.
    field_info = copy(NewRecord.model_fields["contents"])
    if is_optional(new_type):
        field_info.default = None
    new_record_model = create_model(
        NewRecord.__name__, contents=(new_type, field_info), __base__=NewRecord
    )
    new_filedata_model = create_model(
        FileDataMeta.__name__,
        new_records=(list[new_record_model], Field(default_factory=list)),
        __base__=FileDataMeta,
    )
    return new_filedata_model
92+
93+
7094
def wrap_in_fastapi(
7195
func: Callable,
7296
plugin_id: str,
@@ -80,11 +104,12 @@ def wrap_in_fastapi(
80104
fastapi_app = FastAPI()
81105

82106
response_type = get_output_sig(func)
107+
filedata_meta_model = update_filedata_model(response_type)
83108

84109
class InvokeResponse(BaseModel):
85110
usage: list[UsageData]
86111
status_code: int
87-
filedata_meta: FileDataMeta
112+
filedata_meta: filedata_meta_model
88113
status_code_text: Optional[str] = None
89114
output: Optional[response_type] = None
90115

@@ -113,7 +138,9 @@ async def _stream_response():
113138
async for output in func(**(request_dict or {})):
114139
yield InvokeResponse(
115140
usage=usage,
116-
filedata_meta=filedata_meta,
141+
filedata_meta=filedata_meta_model.model_validate(
142+
filedata_meta.model_dump()
143+
),
117144
status_code=status.HTTP_200_OK,
118145
output=output,
119146
).model_dump_json() + "\n"
@@ -123,15 +150,15 @@ async def _stream_response():
123150
output = await invoke_func(func=func, kwargs=request_dict)
124151
return InvokeResponse(
125152
usage=usage,
126-
filedata_meta=filedata_meta,
153+
filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()),
127154
status_code=status.HTTP_200_OK,
128155
output=output,
129156
)
130157
except Exception as invoke_error:
131158
logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True)
132159
return InvokeResponse(
133160
usage=usage,
134-
filedata_meta=filedata_meta,
161+
filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()),
135162
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
136163
status_code_text=f"failed to invoke plugin: "
137164
f"[{invoke_error.__class__.__name__}] {invoke_error}",

unstructured_platform_plugins/etl_uvicorn/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from unstructured_platform_plugins.type_hints import get_type_hints
1717

1818

19+
def is_optional(t: Any) -> bool:
    """Return True when *t* is a union type that includes ``None``.

    The previous check (``t.__origin__ is not None``) was true for every
    parameterised generic, e.g. ``List[int]`` or ``Union[int, str]``. This
    version only accepts unions that actually contain ``NoneType``, in
    either the ``typing`` spelling (``Optional[X]``, ``Union[X, None]``) or,
    on Python 3.10+, the PEP 604 spelling (``X | None``).

    Args:
        t: The type annotation to inspect.

    Returns:
        True if *t* is a union with ``NoneType`` among its members.
    """
    import types
    from typing import Union, get_args, get_origin

    origin = get_origin(t)
    # ``types.UnionType`` only exists on 3.10+; on older interpreters the
    # sentinel default can never match a real origin.
    pep604_union = getattr(types, "UnionType", Union)
    if origin is not Union and origin is not pep604_union:
        return False
    return type(None) in get_args(t)
21+
22+
1923
def get_func(instance: Any, method_name: Optional[str] = None) -> Callable:
2024
method_name = method_name or "__call__"
2125
if inspect.isfunction(instance):

0 commit comments

Comments
 (0)