diff --git a/CHANGELOG.md b/CHANGELOG.md index 98cc5d7..6c6332c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.0.7 + +* **Improve code separation to help with unit tests** + ## 0.0.6 * **Support streaming response types for /invoke if callable is async generator** diff --git a/unstructured_platform_plugins/__version__.py b/unstructured_platform_plugins/__version__.py index 62c73e6..5382d2c 100644 --- a/unstructured_platform_plugins/__version__.py +++ b/unstructured_platform_plugins/__version__.py @@ -1 +1 @@ -__version__ = "0.0.6" # pragma: no cover +__version__ = "0.0.7" # pragma: no cover diff --git a/unstructured_platform_plugins/etl_uvicorn/api_generator.py b/unstructured_platform_plugins/etl_uvicorn/api_generator.py index c19cbb2..2536e66 100644 --- a/unstructured_platform_plugins/etl_uvicorn/api_generator.py +++ b/unstructured_platform_plugins/etl_uvicorn/api_generator.py @@ -67,30 +67,11 @@ def check_precheck_func(precheck_func: Callable): raise ValueError(f"no output should exist for precheck function, found: {outputs}") -def generate_fast_api( - app: str, - method_name: Optional[str] = None, - id_str: Optional[str] = None, - id_method: Optional[str] = None, - precheck_str: Optional[str] = None, - precheck_method: Optional[str] = None, -) -> FastAPI: - instance = import_from_string(app) - func = get_func(instance, method_name) - if id_str: - id_ref = import_from_string(id_str) - plugin_id = get_plugin_id(instance=id_ref, method_name=id_method) - else: - plugin_id = hashlib.sha256( - json.dumps(get_schema_dict(func), sort_keys=True).encode() - ).hexdigest()[:32] - - precheck_func = None - if precheck_str: - precheck_instance = import_from_string(precheck_str) - precheck_func = get_func(precheck_instance, precheck_method) - elif precheck_method: - precheck_func = get_func(instance, precheck_method) +def wrap_in_fastapi( + func: Callable, + plugin_id: str, + precheck_func: Optional[Callable] = None, +): if precheck_func is not None: check_precheck_func(precheck_func=precheck_func) @@ -210,3 +191,31 @@ async def get_id() -> str: ) return fastapi_app + + +def generate_fast_api( + app: str, + method_name: Optional[str] = None, + id_str: Optional[str] = None, + id_method: Optional[str] = None, + precheck_str: Optional[str] = None, + precheck_method: Optional[str] = None, +) -> FastAPI: + instance = import_from_string(app) + func = get_func(instance, method_name) + if id_str: + id_ref = import_from_string(id_str) + plugin_id = get_plugin_id(instance=id_ref, method_name=id_method) + else: + plugin_id = hashlib.sha256( + json.dumps(get_schema_dict(func), sort_keys=True).encode() + ).hexdigest()[:32] + + precheck_func = None + if precheck_str: + precheck_instance = import_from_string(precheck_str) + precheck_func = get_func(precheck_instance, precheck_method) + elif precheck_method: + precheck_func = get_func(instance, precheck_method) + + return wrap_in_fastapi(func=func, plugin_id=plugin_id, precheck_func=precheck_func)