diff --git a/.env.example b/.env.example index aac6c7d..074f130 100644 --- a/.env.example +++ b/.env.example @@ -15,4 +15,3 @@ LOAD_FILE_URL=http://localhost:8000/storage/load DELETE_FILE_URL=http://localhost:8000/storage/delete OCR_REQUEST_URL=http://localhost:8000/ocr/request OCR_UPLOAD_URL=http://localhost:8000/ocr/upload - diff --git a/.gitignore b/.gitignore index f9534f3..f41096f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ text_extract_api/__pycache__/* .dvenv .DS_Store storage/* -client_secret*.json +client_secret*.json .env.localhost .idea # Python good practice ignore diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..20a34e5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: 'v0.11.7' + hooks: + - id: ruff + name: linting code with Ruff + args: [ "--fix" ] + - id: ruff-format + name: format code using Ruff formatter diff --git a/.vscode/settings.json b/.vscode/settings.json index a77d3ce..3a925c8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,4 +3,4 @@ "**/__pycache__": true, "**/*.egg-info": true } -} \ No newline at end of file +} diff --git a/Makefile b/Makefile index 99d6431..280410d 100644 --- a/Makefile +++ b/Makefile @@ -136,4 +136,3 @@ clean: .PHONY: clean-python-cache clear-cache: find . -type d -name '__pycache__' -exec rm -rf {} + && find . -type f -name '*.pyc' -delete - diff --git a/README.md b/README.md index 86eb034..93a9699 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,13 @@ The API is built with FastAPI and uses Celery for asynchronous task processing. - **Distributed queue processing** using [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) - **Caching** using Redis - the OCR results can be easily cached prior to LLM processing, - **Storage Strategies** switchable storage strategies (Google Drive, Local File System ...) -- **CLI tool** for sending tasks and processing results +- **CLI tool** for sending tasks and processing results ## Screenshots Converting MRI report to Markdown + JSON. -```bash +```bash python client/cli.py ocr_upload --file examples/example-mri.pdf --prompt_file examples/example-mri-2-json-prompt.txt ``` @@ -31,8 +31,8 @@ Before running the example see [getting started](#getting-started) Converting Invoice to JSON and remove PII -```bash -python client/cli.py ocr_upload --file examples/example-invoice.pdf --prompt_file examples/example-invoice-remove-pii.txt +```bash +python client/cli.py ocr_upload --file examples/example-invoice.pdf --prompt_file examples/example-invoice-remove-pii.txt ``` Before running the example see [getting started](#getting-started) @@ -52,19 +52,19 @@ To have it up and running please execute the following steps: > ### Setting Up Ollama on a Remote Host -> +> > To connect to an external Ollama instance, set the environment variable: `OLLAMA_HOST=http://address:port`, e.g.: > ```bash > OLLAMA_HOST=http(s)://127.0.0.1:5000 > ``` -> +> > If you want to disable the local Ollama model, use env `DISABLE_LOCAL_OLLAMA=1`, e.g. > ```bash > DISABLE_LOCAL_OLLAMA=1 make install > ``` -> **Note**: When local Ollama is disabled, ensure the required model is downloaded on the external instance. 
->
-> Currently, the `DISABLE_LOCAL_OLLAMA` variable cannot be used to disable Ollama in Docker. As a workaround, remove the `ollama` service from `docker-compose.yml` or `docker-compose.gpu.yml`. 
+> **Note**: When local Ollama is disabled, ensure the required model is downloaded on the external instance.
+>
+> Currently, the `DISABLE_LOCAL_OLLAMA` variable cannot be used to disable Ollama in Docker. As a workaround, remove the `ollama` service from `docker-compose.yml` or `docker-compose.gpu.yml`.
>
> Support for using the variable in Docker environments will be added in a future release.
@@ -83,11 +83,11 @@ cd text-extract-api
Be default application create [virtual python env](https://docs.python.org/3/library/venv.html): `.venv`. You can disable this functionality on local setup by adding `DISABLE_VENV=1` before running script:
```bash
-DISABLE_VENV=1 make install 
+DISABLE_VENV=1 make install
```
```bash
-DISABLE_VENV=1 make run 
+DISABLE_VENV=1 make run
```
### Manual setup
@@ -110,6 +110,20 @@ run.sh
This command will install all the dependencies - including Redis (via Docker, so it is not entirely docker free method of running `text-extract-api` anyways :)
+
+Run pre-commit checks
+
+This command runs the pre-commit checks and, where possible, automatically formats the code.
+The hooks are defined in the `.pre-commit-config.yaml` file.
+
+Be sure to run this command after installing the dependencies,
+including those from the `dev` section.
+
+```bash
+pre-commit run --all-files
+```
+
+
(MAC) - Dependencies
```
brew update && brew install libmagic poppler pkg-config ghostscript ffmpeg automake autoconf
@@ -172,7 +186,7 @@ Easy OCR is available on Apache based license. It's general purpose OCR with sup
Enabled by default. Please do use the `strategy=easyocr` CLI and URL parameters to use it.
-### `minicpm-v` 
+### `minicpm-v`
MiniCPM-V is an Apache based licensed OCR strategy.
@@ -190,7 +204,7 @@ Enabled by default. Please do use the `strategy=minicpm_v` CLI and URL parameter
-### `llama_vision` 
+### `llama_vision`
LLama 3.2 Vision Strategy is licensed on [Meta Community License Agreement](https://ollama.com/library/llama3.2-vision/blobs/0b4284c1f870). Works great for many languages, although due to the number of parameters (90b) this model is probably **the slowest** one.
@@ -201,7 +215,7 @@ Enabled by default. Please do use the `strategy=llama_vision` CLI and URL parame
Some OCR's - like [Marker, state of the art PDF OCR](https://github.com/VikParuchuri/marker) - works really great for more than 50 languages, including great accuracy for Polish and other languages - let's say that are "diffult" to read for standard OCR.
-The `marker-pdf` is however licensed on GPL3 license and **therefore it's not included** by default in this application (as we're bound to MIT). 
+The `marker-pdf` is however licensed on GPL3 license and **therefore it's not included** by default in this application (as we're bound to MIT).
The weights for the models are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the Datalab API. If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options here.
@@ -234,7 +248,7 @@ make run
Please do use the `strategy=remote` CLI and URL parameters to use it.
For example: ```bash -curl -X POST -H "Content-Type: multipart/form-data" -F "file=@examples/example-mri.pdf" -F "strategy=remote" -F "ocr_cache=true" -F "prompt=" -F "model=" "http://localhost:8000/ocr/upload" +curl -X POST -H "Content-Type: multipart/form-data" -F "file=@examples/example-mri.pdf" -F "strategy=remote" -F "ocr_cache=true" -F "prompt=" -F "model=" "http://localhost:8000/ocr/upload" ``` We are connecting to remote OCR via it's API to not share the same license (GPL3) by having it all linked on the source code level. @@ -266,7 +280,7 @@ Create `.env` file in the root directory and set the necessary environment varia cp .env.example .env ``` -or +or ```bash # defaults for local run @@ -412,7 +426,7 @@ python client/cli.py result --task_id {your_task_id_from_upload_step} ### List file results archived by `storage_profile` ```bash -python client/cli.py list_files +python client/cli.py list_files ``` to use specific (in this case `google drive`) storage profile run: @@ -436,7 +450,7 @@ python client/cli.py delete_file --file_name "invoices/2024/example-invoice-2024 or for default profile (local file system): ```bash -python client/cli.py delete_file --file_name "invoices/2024/example-invoice-2024-10-31-16-33.md" +python client/cli.py delete_file --file_name "invoices/2024/example-invoice-2024-10-31-16-33.md" ``` ### Clear OCR Cache @@ -498,7 +512,7 @@ apiClient.uploadFile(formData).then(response => { Example: ```bash -curl -X POST -H "Content-Type: multipart/form-data" -F "file=@examples/example-mri.pdf" -F "strategy=easyocr" -F "ocr_cache=true" -F "prompt=" -F "model=" "http://localhost:8000/ocr/upload" +curl -X POST -H "Content-Type: multipart/form-data" -F "file=@examples/example-mri.pdf" -F "strategy=easyocr" -F "ocr_cache=true" -F "prompt=" -F "model=" "http://localhost:8000/ocr/upload" ``` ### OCR Endpoint via JSON request @@ -576,14 +590,14 @@ curl -X POST "http://localhost:8000/llm/generate" -H "Content-Type: application/ ``` ### List storage files: - + - **URL:** /storage/list - **Method:** GET - **Parameters**: - **storage_profile**: Name of the storage profile to use for listing files (default: `default`). ### Download storage file: - + - **URL:** /storage/load - **Method:** GET - **Parameters**: @@ -591,7 +605,7 @@ curl -X POST "http://localhost:8000/llm/generate" -H "Content-Type: application/ - **storage_profile**: Name of the storage profile to use for listing files (default: `default`). ### Delete storage file: - + - **URL:** /storage/delete - **Method:** DELETE - **Parameters**: @@ -641,10 +655,10 @@ settings: #### Requirements for AWS S3 Access Key -1. **Access Key Ownership** +1. **Access Key Ownership** The access key must belong to an IAM user or role with permissions for S3 operations. -2. **IAM Policy Example** +2. **IAM Policy Example** The IAM policy attached to the user or role must allow the necessary actions. Below is an example of a policy granting access to an S3 bucket: ```json { diff --git a/client/Dockerfile b/client/Dockerfile index f46bebc..94b8845 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,4 +18,4 @@ RUN pip install --no-cache-dir -r requirements.txt COPY . . 
# Define the command to run the CLI script -CMD ["python", "cli.py"] \ No newline at end of file +CMD ["python", "cli.py"] diff --git a/client/cli.py b/client/cli.py index 4cc478c..818883e 100644 --- a/client/cli.py +++ b/client/cli.py @@ -6,286 +6,553 @@ import math from ollama import pull -def ocr_upload(file_path, ocr_cache, prompt, prompt_file=None, model='llama3.1', strategy='llama_vision', storage_profile='default', storage_filename=None, language='en'): - ocr_url = os.getenv('OCR_UPLOAD_URL', 'http://localhost:8000/ocr/upload') - files = {'file': open(file_path, 'rb')} + +def ocr_upload( + file_path, + ocr_cache, + prompt, + prompt_file=None, + model="llama3.1", + strategy="llama_vision", + storage_profile="default", + storage_filename=None, + language="en", +): + ocr_url = os.getenv("OCR_UPLOAD_URL", "http://localhost:8000/ocr/upload") + files = {"file": open(file_path, "rb")} if not ocr_cache: print("OCR cache disabled.") - data = {'ocr_cache': ocr_cache, 'model': model, 'strategy': strategy, 'storage_profile': storage_profile, 'language': language} + data = { + "ocr_cache": ocr_cache, + "model": model, + "strategy": strategy, + "storage_profile": storage_profile, + "language": language, + } if storage_filename: - data['storage_filename'] = storage_filename - - print(data) # @todo change to log debug in the future + data["storage_filename"] = storage_filename + + print(data) # @todo change to log debug in the future try: if prompt_file: - prompt = open(prompt_file, 'r').read() + prompt = open(prompt_file, "r").read() except FileNotFoundError: print(f"Prompt file not found: {prompt_file}") return None - + if prompt: - data['prompt'] = prompt + data["prompt"] = prompt response = requests.post(ocr_url, files=files, data=data) if response.status_code == 200: respObject = response.json() - if respObject.get('task_id'): - return { - "task_id": respObject.get('task_id') - } + if respObject.get("task_id"): + return {"task_id": respObject.get("task_id")} else: return { - "text": respObject.get('text') # sync mode support + "text": respObject.get("text") # sync mode support } else: print(f"Failed to upload file: {response.text}") return None -def ocr_request(file_path, ocr_cache, prompt, prompt_file=None, model='llama3.1', strategy='llama_vision', storage_profile='default', storage_filename=None, language='en'): - ocr_url = os.getenv('OCR_REQUEST_URL', 'http://localhost:8000/ocr/request') - with open(file_path, 'rb') as f: - file_content = base64.b64encode(f.read()).decode('utf-8') - + +def ocr_request( + file_path, + ocr_cache, + prompt, + prompt_file=None, + model="llama3.1", + strategy="llama_vision", + storage_profile="default", + storage_filename=None, + language="en", +): + ocr_url = os.getenv("OCR_REQUEST_URL", "http://localhost:8000/ocr/request") + with open(file_path, "rb") as f: + file_content = base64.b64encode(f.read()).decode("utf-8") + data = { - 'ocr_cache': ocr_cache, - 'model': model, - 'strategy': strategy, - 'storage_profile': storage_profile, - 'file': file_content, - 'language': language + "ocr_cache": ocr_cache, + "model": model, + "strategy": strategy, + "storage_profile": storage_profile, + "file": file_content, + "language": language, } if storage_filename: - data['storage_filename'] = storage_filename - + data["storage_filename"] = storage_filename + if prompt_file: try: - prompt = open(prompt_file, 'r').read() + prompt = open(prompt_file, "r").read() except FileNotFoundError: print(f"Prompt file not found: {prompt_file}") return None - + if prompt: - 
data['prompt'] = prompt - + data["prompt"] = prompt + response = requests.post(ocr_url, json=data) if response.status_code == 200: respObject = response.json() - if respObject.get('task_id'): - return { - "task_id": respObject.get('task_id') - } + if respObject.get("task_id"): + return {"task_id": respObject.get("task_id")} else: return { - "text": respObject.get('text') # sync mode support + "text": respObject.get("text") # sync mode support } else: print(f"Error: {response.status_code} - {response.text}") return None -def get_result(task_id, print_progress = False): + +def get_result(task_id, print_progress=False): extracted_text_printed_once = False - result_url = os.getenv('RESULT_URL', f'http://localhost:8000/ocr/result/') + result_url = os.getenv("RESULT_URL", "http://localhost:8000/ocr/result/") while True: response = requests.get(result_url + task_id) result = response.json() - if result['state'] != 'SUCCESS' and print_progress: - task_info = result.get('info') + if result["state"] != "SUCCESS" and print_progress: + task_info = result.get("info") if task_info is not None: - if task_info.get('extracted_text'): + if task_info.get("extracted_text"): if not extracted_text_printed_once: extracted_text_printed_once = True - print("Extracted text: " + task_info.get('extracted_text')) + print("Extracted text: " + task_info.get("extracted_text")) else: - del task_info['extracted_text'] - del task_info['start_time'] + del task_info["extracted_text"] + del task_info["start_time"] print(result) if response.status_code == 200: - if result['state'] == 'SUCCESS': - return result['result'] - elif result['state'] == 'FAILURE': + if result["state"] == "SUCCESS": + return result["result"] + elif result["state"] == "FAILURE": print("OCR task failed.") return None time.sleep(2) # Wait for 2 seconds before checking again + def clear_cache(): - clear_cache_url = os.getenv('CLEAR_CACHE_URL', 'http://localhost:8000/ocr/clear_cache') + clear_cache_url = os.getenv( + "CLEAR_CACHE_URL", "http://localhost:8000/ocr/clear_cache" + ) response = requests.post(clear_cache_url) if response.status_code == 200: print("OCR cache cleared successfully.") else: print(f"Failed to clear OCR cache: {response.text}") -def llm_pull(model = 'llama3.1'): + +def llm_pull(model="llama3.1"): response = pull(model, stream=True) for chunk in response: - if chunk.completed and chunk.total: - print(f'Please wait .... {model} - {chunk.status} - {math.floor((chunk.completed / chunk.total) * 100)}% completed') - else: - print(f'Pulling {model} - {chunk.status}') + if chunk.completed and chunk.total: + print( + f"Please wait .... 
{model} - {chunk.status} - {math.floor((chunk.completed / chunk.total) * 100)}% completed" + ) + else: + print(f"Pulling {model} - {chunk.status}") + -def llm_generate(prompt, model = 'llama3.1'): - ollama_url = os.getenv('LLM_GENERATE_API_URL', 'http://localhost:8000/llm/generate') +def llm_generate(prompt, model="llama3.1"): + ollama_url = os.getenv("LLM_GENERATE_API_URL", "http://localhost:8000/llm/generate") response = requests.post(ollama_url, json={"model": model, "prompt": prompt}) if response.status_code == 200: - print(response.json().get('generated_text')) + print(response.json().get("generated_text")) else: print(f"Failed to generate text: {response.text}") + def list_files(storage_profile): - list_files_url = os.getenv('LIST_FILES_URL', 'http://localhost:8000/storage/list') - response = requests.get(list_files_url, params={'storage_profile': storage_profile}) + list_files_url = os.getenv("LIST_FILES_URL", "http://localhost:8000/storage/list") + response = requests.get(list_files_url, params={"storage_profile": storage_profile}) if response.status_code == 200: - files = response.json().get('files', []) + files = response.json().get("files", []) for file in files: print(file) else: - print(f"Failed to list files: {response.text}") + print(f"Failed to list files: {response.text}") + def load_file(file_name, storage_profile): - load_file_url = os.getenv('LOAD_FILE_URL', 'http://localhost:8000/storage/load') - response = requests.get(load_file_url, params={'file_name': file_name, 'storage_profile': storage_profile}) + load_file_url = os.getenv("LOAD_FILE_URL", "http://localhost:8000/storage/load") + response = requests.get( + load_file_url, + params={"file_name": file_name, "storage_profile": storage_profile}, + ) if response.status_code == 200: - content = response.json().get('content', '') + content = response.json().get("content", "") print(content) else: print(f"Failed to load file: {response.text}") + def delete_file(file_name, storage_profile): - delete_file_url = os.getenv('DELETE_FILE_URL', 'http://localhost:8000/storage/delete') - response = requests.delete(delete_file_url, params={'file_name': file_name, 'storage_profile': storage_profile}) + delete_file_url = os.getenv( + "DELETE_FILE_URL", "http://localhost:8000/storage/delete" + ) + response = requests.delete( + delete_file_url, + params={"file_name": file_name, "storage_profile": storage_profile}, + ) if response.status_code == 200: print(f"File {file_name} deleted successfully.") else: print(f"Failed to delete file: {response.text}") + def main(): parser = argparse.ArgumentParser(description="CLI for OCR and Ollama operations.") - subparsers = parser.add_subparsers(dest='command', help='Sub-command help') + subparsers = parser.add_subparsers(dest="command", help="Sub-command help") # Sub-command for uploading a file via file upload - ocr_parser = subparsers.add_parser('ocr_upload', help='Upload a file to the OCR endpoint and get the result.') - ocr_parser.add_argument('--file', type=str, default='examples/rmi-example.pdf', help='Path to the file to upload') - ocr_parser.add_argument('--ocr_cache', default=True, action='store_true', help='Enable OCR result caching') - ocr_parser.add_argument('--disable_ocr_cache', default=False, action='store_true', help='Disable OCR result caching') - ocr_parser.add_argument('--prompt', type=str, default=None, help='Prompt used for the Ollama model to fix or transform the file') - ocr_parser.add_argument('--prompt_file', default=None, type=str, help='Prompt file name used for the 
Ollama model to fix or transform the file') - ocr_parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for the Ollama endpoint') - ocr_parser.add_argument('--strategy', type=str, default='llama_vision', help='OCR strategy to use for the file') - ocr_parser.add_argument('--print_progress', default=True, action='store_true', help='Print the progress of the OCR task') - ocr_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use for the file') - ocr_parser.add_argument('--storage_filename', type=str, default=None, help='Storage filename to use for the file. You may use some formatting - see the docs') - ocr_parser.add_argument('--language', type=str, default='en', help='Language to use for the OCR task') - #ocr_parser.add_argument('--async_mode', action='store_true', help='Enable async mode for the OCR task') + ocr_parser = subparsers.add_parser( + "ocr_upload", help="Upload a file to the OCR endpoint and get the result." + ) + ocr_parser.add_argument( + "--file", + type=str, + default="examples/rmi-example.pdf", + help="Path to the file to upload", + ) + ocr_parser.add_argument( + "--ocr_cache", + default=True, + action="store_true", + help="Enable OCR result caching", + ) + ocr_parser.add_argument( + "--disable_ocr_cache", + default=False, + action="store_true", + help="Disable OCR result caching", + ) + ocr_parser.add_argument( + "--prompt", + type=str, + default=None, + help="Prompt used for the Ollama model to fix or transform the file", + ) + ocr_parser.add_argument( + "--prompt_file", + default=None, + type=str, + help="Prompt file name used for the Ollama model to fix or transform the file", + ) + ocr_parser.add_argument( + "--model", + type=str, + default="llama3.1", + help="Model to use for the Ollama endpoint", + ) + ocr_parser.add_argument( + "--strategy", + type=str, + default="llama_vision", + help="OCR strategy to use for the file", + ) + ocr_parser.add_argument( + "--print_progress", + default=True, + action="store_true", + help="Print the progress of the OCR task", + ) + ocr_parser.add_argument( + "--storage_profile", + type=str, + default="default", + help="Storage profile to use for the file", + ) + ocr_parser.add_argument( + "--storage_filename", + type=str, + default=None, + help="Storage filename to use for the file. 
You may use some formatting - see the docs", + ) + ocr_parser.add_argument( + "--language", type=str, default="en", help="Language to use for the OCR task" + ) + # ocr_parser.add_argument('--async_mode', action='store_true', help='Enable async mode for the OCR task') # Sub-command for uploading a file via file upload - @deprecated - it's a backward compatibility gimmick - ocr_parser = subparsers.add_parser('ocr', help='Upload a file to the OCR endpoint and get the result.') - ocr_parser.add_argument('--file', type=str, default='examples/rmi-example.pdf', help='Path to the file to upload') - ocr_parser.add_argument('--ocr_cache', default=True, action='store_true', help='Enable OCR result caching') - ocr_parser.add_argument('--disable_ocr_cache', default=False, action='store_true', help='Disable OCR result caching') - ocr_parser.add_argument('--prompt', type=str, default=None, help='Prompt used for the Ollama model to fix or transform the file') - ocr_parser.add_argument('--prompt_file', default=None, type=str, help='Prompt file name used for the Ollama model to fix or transform the file') - ocr_parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for the Ollama endpoint') - ocr_parser.add_argument('--strategy', type=str, default='llama_vision', help='OCR strategy to use for the file') - ocr_parser.add_argument('--print_progress', default=True, action='store_true', help='Print the progress of the OCR task') - ocr_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use for the file') - ocr_parser.add_argument('--storage_filename', type=str, default=None, help='Storage filename to use for the file. You may use some formatting - see the docs') - ocr_parser.add_argument('--language', type=str, default='en', help='Language to use for the OCR task') - #ocr_parser.add_argument('--async_mode', action='store_true', help='Enable async mode for the OCR task') - + ocr_parser = subparsers.add_parser( + "ocr", help="Upload a file to the OCR endpoint and get the result." + ) + ocr_parser.add_argument( + "--file", + type=str, + default="examples/rmi-example.pdf", + help="Path to the file to upload", + ) + ocr_parser.add_argument( + "--ocr_cache", + default=True, + action="store_true", + help="Enable OCR result caching", + ) + ocr_parser.add_argument( + "--disable_ocr_cache", + default=False, + action="store_true", + help="Disable OCR result caching", + ) + ocr_parser.add_argument( + "--prompt", + type=str, + default=None, + help="Prompt used for the Ollama model to fix or transform the file", + ) + ocr_parser.add_argument( + "--prompt_file", + default=None, + type=str, + help="Prompt file name used for the Ollama model to fix or transform the file", + ) + ocr_parser.add_argument( + "--model", + type=str, + default="llama3.1", + help="Model to use for the Ollama endpoint", + ) + ocr_parser.add_argument( + "--strategy", + type=str, + default="llama_vision", + help="OCR strategy to use for the file", + ) + ocr_parser.add_argument( + "--print_progress", + default=True, + action="store_true", + help="Print the progress of the OCR task", + ) + ocr_parser.add_argument( + "--storage_profile", + type=str, + default="default", + help="Storage profile to use for the file", + ) + ocr_parser.add_argument( + "--storage_filename", + type=str, + default=None, + help="Storage filename to use for the file. 
You may use some formatting - see the docs", + ) + ocr_parser.add_argument( + "--language", type=str, default="en", help="Language to use for the OCR task" + ) + # ocr_parser.add_argument('--async_mode', action='store_true', help='Enable async mode for the OCR task') # Sub-command for uploading a file via request JSON - ocr_request_parser = subparsers.add_parser('ocr_request', help='Upload a file to the OCR endpoint via JSON and get the result.') - ocr_request_parser.add_argument('--file', type=str, default='examples/rmi-example.pdf', help='Path to the file to upload') - ocr_request_parser.add_argument('--ocr_cache', default=True, action='store_true', help='Enable OCR result caching') - ocr_request_parser.add_argument('--disable_ocr_cache', default=False, action='store_true', help='Disable OCR result caching') - ocr_request_parser.add_argument('--prompt', type=str, default=None, help='Prompt used for the Ollama model to fix or transform the file') - ocr_request_parser.add_argument('--prompt_file', default=None, type=str, help='Prompt file name used for the Ollama model to fix or transform the file') - ocr_request_parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for the Ollama endpoint') - ocr_request_parser.add_argument('--strategy', type=str, default='llama_vision', help='OCR strategy to use') - ocr_request_parser.add_argument('--print_progress', default=True, action='store_true', help='Print the progress of the OCR task') - ocr_request_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use. You may use some formatting - see the docs') - ocr_request_parser.add_argument('--storage_filename', type=str, default=None, help='Storage filename to use') - ocr_request_parser.add_argument('--language', type=str, default='en', help='Language to use for the OCR task') + ocr_request_parser = subparsers.add_parser( + "ocr_request", + help="Upload a file to the OCR endpoint via JSON and get the result.", + ) + ocr_request_parser.add_argument( + "--file", + type=str, + default="examples/rmi-example.pdf", + help="Path to the file to upload", + ) + ocr_request_parser.add_argument( + "--ocr_cache", + default=True, + action="store_true", + help="Enable OCR result caching", + ) + ocr_request_parser.add_argument( + "--disable_ocr_cache", + default=False, + action="store_true", + help="Disable OCR result caching", + ) + ocr_request_parser.add_argument( + "--prompt", + type=str, + default=None, + help="Prompt used for the Ollama model to fix or transform the file", + ) + ocr_request_parser.add_argument( + "--prompt_file", + default=None, + type=str, + help="Prompt file name used for the Ollama model to fix or transform the file", + ) + ocr_request_parser.add_argument( + "--model", + type=str, + default="llama3.1", + help="Model to use for the Ollama endpoint", + ) + ocr_request_parser.add_argument( + "--strategy", type=str, default="llama_vision", help="OCR strategy to use" + ) + ocr_request_parser.add_argument( + "--print_progress", + default=True, + action="store_true", + help="Print the progress of the OCR task", + ) + ocr_request_parser.add_argument( + "--storage_profile", + type=str, + default="default", + help="Storage profile to use. 
You may use some formatting - see the docs",
+    )
+    ocr_request_parser.add_argument(
+        "--storage_filename", type=str, default=None, help="Storage filename to use"
+    )
+    ocr_request_parser.add_argument(
+        "--language", type=str, default="en", help="Language to use for the OCR task"
+    )

    # Sub-command for getting the result
-    result_parser = subparsers.add_parser('result', help='Get the OCR result by specified task id.')
-    result_parser.add_argument('--task_id', type=str, help='Task Id returned by the upload command')
-    result_parser.add_argument('--print_progress', default=True, action='store_true', help='Print the progress of the OCR task')
-
-    # Sub-command for clearing the cache
-    clear_cache_parser = subparsers.add_parser('clear_cache', help='Clear the OCR result cache')
+    result_parser = subparsers.add_parser(
+        "result", help="Get the OCR result by specified task id."
+    )
+    result_parser.add_argument(
+        "--task_id", type=str, help="Task Id returned by the upload command"
+    )
+    result_parser.add_argument(
+        "--print_progress",
+        default=True,
+        action="store_true",
+        help="Print the progress of the OCR task",
+    )
+
+    # Sub-command for clearing the cache (still dispatched in main and documented in the README)
+    clear_cache_parser = subparsers.add_parser(
+        "clear_cache", help="Clear the OCR result cache"
+    )

    # Sub-command for running Ollama
-    ollama_parser = subparsers.add_parser('llm_generate', help='Run the Ollama endpoint')
-    ollama_parser.add_argument('--prompt', type=str, required=True, help='Prompt for the Ollama model')
-    ollama_parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for the Ollama endpoint')
+    ollama_parser = subparsers.add_parser(
+        "llm_generate", help="Run the Ollama endpoint"
+    )
+    ollama_parser.add_argument(
+        "--prompt", type=str, required=True, help="Prompt for the Ollama model"
+    )
+    ollama_parser.add_argument(
+        "--model",
+        type=str,
+        default="llama3.1",
+        help="Model to use for the Ollama endpoint",
+    )

-    ollama_pull_parser = subparsers.add_parser('llm_pull', help='Pull the latest Llama model from the Ollama API')
-    ollama_pull_parser.add_argument('--model', type=str, default='llama3.1', help='Model to pull from the Ollama API')
+    ollama_pull_parser = subparsers.add_parser(
+        "llm_pull", help="Pull the latest Llama model from the Ollama API"
+    )
+    ollama_pull_parser.add_argument(
+        "--model",
+        type=str,
+        default="llama3.1",
+        help="Model to pull from the Ollama API",
+    )

    # Sub-command for listing files
-    list_files_parser = subparsers.add_parser('list_files', help='List files using the selected storage profile')
-    list_files_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use')
+    list_files_parser = subparsers.add_parser(
+        "list_files", help="List files using the selected storage profile"
+    )
+    list_files_parser.add_argument(
+        "--storage_profile", type=str, default="default", help="Storage profile to use"
+    )

    # Sub-command for loading a file
-    load_file_parser = subparsers.add_parser('load_file', help='Load a file using the selected storage profile')
-    load_file_parser.add_argument('--file_name', type=str, required=True, help='Name of the file to load')
-    load_file_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use')
+    load_file_parser = subparsers.add_parser(
+        "load_file", help="Load a file using the selected storage profile"
+    )
+    load_file_parser.add_argument(
+        "--file_name", type=str, required=True, help="Name of the file to load"
+    )
+    load_file_parser.add_argument(
+        "--storage_profile", type=str, default="default", help="Storage profile to use"
+    )

    # Sub-command for deleting a file
-    delete_file_parser =
subparsers.add_parser('delete_file', help='Delete a file using the selected storage profile') - delete_file_parser.add_argument('--file_name', type=str, required=True, help='Name of the file to delete') - delete_file_parser.add_argument('--storage_profile', type=str, default='default', help='Storage profile to use') + delete_file_parser = subparsers.add_parser( + "delete_file", help="Delete a file using the selected storage profile" + ) + delete_file_parser.add_argument( + "--file_name", type=str, required=True, help="Name of the file to delete" + ) + delete_file_parser.add_argument( + "--storage_profile", type=str, default="default", help="Storage profile to use" + ) args = parser.parse_args() - if args.command == 'ocr' or args.command == 'ocr_upload': + if args.command == "ocr" or args.command == "ocr_upload": print(args) - result = ocr_upload(args.file, False if args.disable_ocr_cache else args.ocr_cache, args.prompt, args.prompt_file, args.model, args.strategy, args.storage_profile, args.storage_filename, args.language) + result = ocr_upload( + args.file, + False if args.disable_ocr_cache else args.ocr_cache, + args.prompt, + args.prompt_file, + args.model, + args.strategy, + args.storage_profile, + args.storage_filename, + args.language, + ) if result is None: print("Error uploading file.") return - if result.get('text'): - print(result.get('text')) + if result.get("text"): + print(result.get("text")) elif result: - print("File uploaded successfully. Task Id: " + result.get('task_id') + " Waiting for the result...") - text_result = get_result(result.get('task_id'), args.print_progress) + print( + "File uploaded successfully. Task Id: " + + result.get("task_id") + + " Waiting for the result..." + ) + text_result = get_result(result.get("task_id"), args.print_progress) if text_result: print(text_result) - elif args.command == 'ocr_request': - result = ocr_request(args.file, False if args.disable_ocr_cache else args.ocr_cache, args.prompt, args.prompt_file, args.model, args.strategy, args.storage_profile, args.storage_filename, args.language) + elif args.command == "ocr_request": + result = ocr_request( + args.file, + False if args.disable_ocr_cache else args.ocr_cache, + args.prompt, + args.prompt_file, + args.model, + args.strategy, + args.storage_profile, + args.storage_filename, + args.language, + ) if result is None: print("Error uploading file.") return - if result.get('text'): - print(result.get('text')) + if result.get("text"): + print(result.get("text")) elif result: - print("File uploaded successfully. Task Id: " + result.get('task_id') + " Waiting for the result...") - text_result = get_result(result.get('task_id'), args.print_progress) + print( + "File uploaded successfully. Task Id: " + + result.get("task_id") + + " Waiting for the result..." 
+ ) + text_result = get_result(result.get("task_id"), args.print_progress) if text_result: print(text_result) - elif args.command == 'result': + elif args.command == "result": text_result = get_result(args.task_id, args.print_progress) if text_result: print(text_result) - elif args.command == 'clear_cache': + elif args.command == "clear_cache": clear_cache() - elif args.command == 'llm_generate': + elif args.command == "llm_generate": llm_generate(args.prompt, args.model) - elif args.command == 'llm_pull': + elif args.command == "llm_pull": llm_pull(args.model) - elif args.command == 'list_files': - list_files(args.storage_profile) - elif args.command == 'load_file': + elif args.command == "list_files": + list_files(args.storage_profile) + elif args.command == "load_file": load_file(args.file_name, args.storage_profile) - elif args.command == 'delete_file': - delete_file(args.file_name, args.storage_profile) + elif args.command == "delete_file": + delete_file(args.file_name, args.storage_profile) else: parser.print_help() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/client/requirements.txt b/client/requirements.txt index cd6b55d..c4c3553 100644 --- a/client/requirements.txt +++ b/client/requirements.txt @@ -1,2 +1,2 @@ requests -argparse \ No newline at end of file +argparse diff --git a/docker-compose.yml b/docker-compose.yml index ee0fd57..3d8bc19 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -36,7 +36,7 @@ services: - CELERY_BROKER_URL=${CELERY_BROKER_URL-redis://redis:6379/0} - CELERY_RESULT_BACKEND=${CELERY_RESULT_BACKEND-redis://redis:6379/0} - STORAGE_PROFILE_PATH=${STORAGE_PROFILE_PATH-/app/storage_profiles} # Add the storage profile path - - LIST_FILES_URL=${LIST_FILES_URL-http://localhost:8000/storage/list} + - LIST_FILES_URL=${LIST_FILES_URL-http://localhost:8000/storage/list} - LOAD_FILE_URL=${LOAD_FILE_URL-http://localhost:8000/storage/load} - DELETE_FILE_URL=${DELETE_FILE_URL-http://localhost:8000/storage/delete} depends_on: @@ -61,4 +61,3 @@ services: interval: 30s timeout: 10s retries: 3 - diff --git a/examples/example-invoice-remove-pii.md b/examples/example-invoice-remove-pii.md index 17f1724..4ca966d 100644 --- a/examples/example-invoice-remove-pii.md +++ b/examples/example-invoice-remove-pii.md @@ -4,8 +4,8 @@ Invoice For John Doe 2048 Michigan Str Adress Line 2 601 Chicago, US ## Subject -From Acme Invoice Ltd Darrow Street 2 E1 7AW Portsoken London Invoice ID -INV/S/24/2024 17/09/2024 Issue Date PO Number 11/10/2024 Due Date +From Acme Invoice Ltd Darrow Street 2 E1 7AW Portsoken London Invoice ID +INV/S/24/2024 17/09/2024 Issue Date PO Number 11/10/2024 Due Date | Amount | | | | |-----------------------|------------|--------|-------| @@ -35,7 +35,7 @@ Invoice For John Doe ANONYMIZED ## Subject From Acme Invoice Ltd Darrow Street 2 E1 7AW Portsoken London -Invoice ID INV/S/24/2024 17/09/2024 Issue Date PO Number 11/10/2024 Due Date +Invoice ID INV/S/24/2024 17/09/2024 Issue Date PO Number 11/10/2024 Due Date | Amount | | | | |-----------------------|------------|--------|-------| diff --git a/examples/example-invoice-remove-pii.txt b/examples/example-invoice-remove-pii.txt index 3a1412e..6d29c44 100644 --- a/examples/example-invoice-remove-pii.txt +++ b/examples/example-invoice-remove-pii.txt @@ -1,5 +1,5 @@ Below is th text from PDF document after OCR. Fix the text and spelling issues. 
-Remove Any personal information (like name, first name, last name, address, street, phone number email) replacing it with "ANONYMIZED". +Remove Any personal information (like name, first name, last name, address, street, phone number email) replacing it with "ANONYMIZED". Convert to JSON using the following schema: @@ -86,4 +86,3 @@ Convert to JSON using the following schema: Input text: - diff --git a/examples/example-mri-2-json-prompt.txt b/examples/example-mri-2-json-prompt.txt index 591d049..8501815 100644 --- a/examples/example-mri-2-json-prompt.txt +++ b/examples/example-mri-2-json-prompt.txt @@ -44,4 +44,4 @@ Convert the text to JSON format according to this schema: } ``` -Return only JSON object. \ No newline at end of file +Return only JSON object. diff --git a/examples/example-mri-remove-pii.txt b/examples/example-mri-remove-pii.txt index b3648e9..1f82e04 100644 --- a/examples/example-mri-remove-pii.txt +++ b/examples/example-mri-remove-pii.txt @@ -1,3 +1,2 @@ Below is th text from PDF document after OCR. Fix the text and spelling issues. -Remove Any personal information (like name, first name, last name, address, street, phone number email) replacing it with "ANONYMIZED". - +Remove Any personal information (like name, first name, last name, address, street, phone number email) replacing it with "ANONYMIZED". diff --git a/logs/.gitignore b/logs/.gitignore index 8cee1e4..ddad3d9 100644 --- a/logs/.gitignore +++ b/logs/.gitignore @@ -1,2 +1,2 @@ .* -* \ No newline at end of file +* diff --git a/pyproject.toml b/pyproject.toml index 4883325..e9edf79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,26 +47,19 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest", - "black", - "isort", - "flake8", + "ruff", + "pre-commit", ] -[tool.black] +[tool.ruff] line-length = 88 - -[tool.isort] -profile = "black" -multi_line_output = 3 +fix = true +exclude = ["venv", "__pycache__", ".git", ".mypy_cache"] [tool.pytest] testpaths = ["tests"] python_files = ["test_*.py"] -[tool.flake8] -max-line-length = 88 -exclude = ["venv", "__pycache__", ".git", ".mypy_cache"] - [tool.setuptools.packages.find] where = ["text_extract_api"] -include = ["text_extract_api.*"] \ No newline at end of file +include = ["text_extract_api.*"] diff --git a/run.sh b/run.sh index 2f4af50..4fb60fb 100755 --- a/run.sh +++ b/run.sh @@ -86,4 +86,4 @@ else "$CELERY_BIN" -A text_extract_api.celery_app worker --loglevel=debug --pool=solo & uvicorn text_extract_api.main:app --host 0.0.0.0 --port 8000 --reload ) -fi \ No newline at end of file +fi diff --git a/storage_profiles/default.yaml b/storage_profiles/default.yaml index b3935ad..a87e530 100644 --- a/storage_profiles/default.yaml +++ b/storage_profiles/default.yaml @@ -2,4 +2,4 @@ strategy: local_filesystem settings: root_path: ./storage # The root path where the files will be stored - mount a proper folder in the docker file to match it subfolder_names_format: "" # eg: by_months/{Y}-{mm}/ - create_subfolders: true \ No newline at end of file + create_subfolders: true diff --git a/storage_profiles/gdrive.yaml b/storage_profiles/gdrive.yaml index 4cb761d..c0fcf2e 100644 --- a/storage_profiles/gdrive.yaml +++ b/storage_profiles/gdrive.yaml @@ -2,4 +2,4 @@ strategy: google_drive settings: ## how to enable GDrive API: https://developers.google.com/drive/api/quickstart/python?hl=pl service_account_file: /storage/gdrive_service_account.json - folder_id: + folder_id: diff --git a/tests/text_extract_api/files/converters/test_converter.py 
b/tests/text_extract_api/files/converters/test_converter.py index 8a1a7a6..430aae1 100644 --- a/tests/text_extract_api/files/converters/test_converter.py +++ b/tests/text_extract_api/files/converters/test_converter.py @@ -6,24 +6,30 @@ class TestConverter(unittest.TestCase): - - @patch("text_extract_api.files.converters.converter.Converter.convert", - return_value=iter(["page1", "page2", "page3"])) + @patch( + "text_extract_api.files.converters.converter.Converter.convert", + return_value=iter(["page1", "page2", "page3"]), + ) def test_convert_to_list(self, mock_convert): file_format = MagicMock(spec=FileFormat) result = Converter.convert_to_list(file_format) self.assertEqual(result, ["page1", "page2", "page3"]) mock_convert.assert_called_once_with(file_format) - @patch("text_extract_api.files.converters.converter.Converter.convert", - return_value=iter(["page1", "page2", "page3"])) + @patch( + "text_extract_api.files.converters.converter.Converter.convert", + return_value=iter(["page1", "page2", "page3"]), + ) def test_convert_force_single(self, mock_convert): file_format = MagicMock(spec=FileFormat) result = Converter.convert_force_single(file_format) self.assertEqual(result, "page1") mock_convert.assert_called_once_with(file_format) - @patch("text_extract_api.files.converters.converter.Converter.convert", side_effect=NotImplementedError) + @patch( + "text_extract_api.files.converters.converter.Converter.convert", + side_effect=NotImplementedError, + ) def test_convert_not_implemented(self): file_format = MagicMock(spec=FileFormat) with self.assertRaises(NotImplementedError): diff --git a/text_extract_api/celery_app.py b/text_extract_api/celery_app.py index 0f8be30..873925c 100644 --- a/text_extract_api/celery_app.py +++ b/text_extract_api/celery_app.py @@ -1,5 +1,6 @@ import pathlib import sys +import multiprocessing from celery import Celery from dotenv import load_dotenv @@ -8,17 +9,12 @@ load_dotenv(".env") -import multiprocessing multiprocessing.set_start_method("spawn", force=True) app = Celery( - "text_extract_api", - broker="redis://redis:6379/0", - backend="redis://redis:6379/0" + "text_extract_api", broker="redis://redis:6379/0", backend="redis://redis:6379/0" ) -app.config_from_object({ - "worker_max_memory_per_child": 8200000 -}) +app.config_from_object({"worker_max_memory_per_child": 8200000}) -app.autodiscover_tasks(["text_extract_api.extract"], 'tasks', True) +app.autodiscover_tasks(["text_extract_api.extract"], "tasks", True) diff --git a/text_extract_api/extract/extract_result.py b/text_extract_api/extract/extract_result.py index 8d73d2d..ad0cb11 100644 --- a/text_extract_api/extract/extract_result.py +++ b/text_extract_api/extract/extract_result.py @@ -3,24 +3,22 @@ """ IMPORTANT INFORMATION ABOUT THIS CLASS: -This is not the final version of the object, namespace, or intended use. +This is not the final version of the object, namespace, or intended use. -For this reason, I am not creating an interface, etc. Add code here as soon as possible -along with further integrations, and once we have gained sufficient experience, we will +For this reason, I am not creating an interface, etc. Add code here as soon as possible +along with further integrations, and once we have gained sufficient experience, we will undertake a refactor. -Currently, the object's purpose is to replace the use of a primitive type, a string, for -extract returns. 
The limitation of this approach became evident when returning only the -resulting string caused us to lose valuable metadata about the document. Thanks to this -class, we retain DoclingDocument and foresee that other converters/OCRs may have similar +Currently, the object's purpose is to replace the use of a primitive type, a string, for +extract returns. The limitation of this approach became evident when returning only the +resulting string caused us to lose valuable metadata about the document. Thanks to this +class, we retain DoclingDocument and foresee that other converters/OCRs may have similar metadata. """ + + class ExtractResult: - def __init__( - self, - value: Any, - text_gatherer: Callable[[Any], str] = None - ): + def __init__(self, value: Any, text_gatherer: Callable[[Any], str] = None): """ Initializes a UnifiedText instance. @@ -48,16 +46,20 @@ def __init__( """ if text_gatherer is not None and not callable(text_gatherer): - raise ValueError("The `text_gatherer` provided to UnifiedText must be a callable.") + raise ValueError( + "The `text_gatherer` provided to UnifiedText must be a callable." + ) if not isinstance(value, str) and not callable(text_gatherer): - raise ValueError("If `value` is not a string, `text_gatherer` must be provided.") + raise ValueError( + "If `value` is not a string, `text_gatherer` must be provided." + ) self.value = value self.text_gatherer = text_gatherer or self._default_text_gatherer @staticmethod - def from_text(value: str) -> 'ExtractResult': + def from_text(value: str) -> "ExtractResult": return ExtractResult(value) @property @@ -87,4 +89,4 @@ def _default_text_gatherer(value: Any) -> str: """ if isinstance(value, str): return value - raise TypeError("Default text gatherer only supports strings.") \ No newline at end of file + raise TypeError("Default text gatherer only supports strings.") diff --git a/text_extract_api/extract/strategies/docling.py b/text_extract_api/extract/strategies/docling.py index 03faf99..2ab5e67 100644 --- a/text_extract_api/extract/strategies/docling.py +++ b/text_extract_api/extract/strategies/docling.py @@ -7,7 +7,8 @@ from text_extract_api.extract.extract_result import ExtractResult from text_extract_api.extract.strategies.strategy import Strategy -from text_extract_api.files.file_formats import FileFormat, PdfFileFormat +from text_extract_api.files.file_formats import FileFormat + class DoclingStrategy(Strategy): """ diff --git a/text_extract_api/extract/strategies/easyocr.py b/text_extract_api/extract/strategies/easyocr.py index 90dff36..35cd296 100644 --- a/text_extract_api/extract/strategies/easyocr.py +++ b/text_extract_api/extract/strategies/easyocr.py @@ -14,17 +14,18 @@ class EasyOCRStrategy(Strategy): def name(cls) -> str: return "easyOCR" - def extract_text(self, file_format: FileFormat, language: str = 'en') -> ExtractResult: + def extract_text( + self, file_format: FileFormat, language: str = "en" + ) -> ExtractResult: """ Extract text using EasyOCR after converting the input file to images - (if not already an ImageFileFormat). + (if not already an ImageFileFormat). 
""" # Ensure we can actually convert the input file to ImageFileFormat - if ( - not isinstance(file_format, ImageFileFormat) - and not file_format.can_convert_to(ImageFileFormat) - ): + if not isinstance( + file_format, ImageFileFormat + ) and not file_format.can_convert_to(ImageFileFormat): raise TypeError( f"EasyOCR - format {file_format.mime_type} is not supported (yet?)" ) @@ -34,19 +35,21 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> Extract # Initialize the EasyOCR Reader # Add or change languages to your needs, e.g., ['en', 'fr'] - reader = easyocr.Reader(language.split(',')) + reader = easyocr.Reader(language.split(",")) # Process each image, extracting text all_extracted_text = [] for image_format in images: # Convert the in-memory bytes to a PIL Image pil_image = Image.open(io.BytesIO(image_format.binary)) - + # Convert PIL image to numpy array for EasyOCR np_image = np.array(pil_image) # Perform OCR; with `detail=0`, we get just text, no bounding boxes - ocr_result = reader.readtext(np_image, detail=0) # TODO: addd bounding boxes support as described in #37 + ocr_result = reader.readtext( + np_image, detail=0 + ) # TODO: addd bounding boxes support as described in #37 # Combine all lines into a single string for that image/page extracted_text = "\n".join(ocr_result) @@ -55,5 +58,4 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> Extract # Join text from all images/pages full_text = "\n\n".join(all_extracted_text) - return ExtractResult.from_text(full_text) diff --git a/text_extract_api/extract/strategies/ollama.py b/text_extract_api/extract/strategies/ollama.py index 8bd68d6..734a2ce 100644 --- a/text_extract_api/extract/strategies/ollama.py +++ b/text_extract_api/extract/strategies/ollama.py @@ -18,12 +18,12 @@ class OllamaStrategy(Strategy): def name(cls) -> str: return "llama_vision" - def extract_text(self, file_format: FileFormat, language: str = 'en') -> ExtractResult: - - if ( - not isinstance(file_format, ImageFileFormat) - and not file_format.can_convert_to(ImageFileFormat) - ): + def extract_text( + self, file_format: FileFormat, language: str = "en" + ) -> ExtractResult: + if not isinstance( + file_format, ImageFileFormat + ) and not file_format.can_convert_to(ImageFileFormat): raise TypeError( f"Ollama OCR - format {file_format.mime_type} is not supported (yet?)" ) @@ -33,39 +33,56 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> Extract ocr_percent_done = 0 num_pages = len(images) for i, image in enumerate(images): - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file: temp_file.write(image.binary) temp_filename = temp_file.name # Generate text using the specified model try: - timeout = httpx.Timeout(connect=180.0, read=180.0, write=180.0, pool=180.0) # @todo move those values to .env + timeout = httpx.Timeout( + connect=180.0, read=180.0, write=180.0, pool=180.0 + ) # @todo move those values to .env ollama = Client(timeout=timeout) - response = ollama.chat(self._strategy_config.get('model'), [{ - 'role': 'user', - 'content': self._strategy_config.get('prompt'), - 'images': [temp_filename] - }], stream=True) + response = ollama.chat( + self._strategy_config.get("model"), + [ + { + "role": "user", + "content": self._strategy_config.get("prompt"), + "images": [temp_filename], + } + ], + stream=True, + ) os.remove(temp_filename) num_chunk = 1 for chunk in response: meta = { - 'progress': str(30 + ocr_percent_done), - 'status': 'OCR Processing' - + '(page ' + str(i 
+ 1) + ' of ' + str(num_pages) + ')' - + ' chunk no: ' + str(num_chunk), - 'start_time': start_time, - 'elapsed_time': time.time() - start_time} - self.update_state_callback(state='PROGRESS', meta=meta) + "progress": str(30 + ocr_percent_done), + "status": "OCR Processing" + + "(page " + + str(i + 1) + + " of " + + str(num_pages) + + ")" + + " chunk no: " + + str(num_chunk), + "start_time": start_time, + "elapsed_time": time.time() - start_time, + } + self.update_state_callback(state="PROGRESS", meta=meta) num_chunk += 1 - extracted_text += chunk['message']['content'] + extracted_text += chunk["message"]["content"] ocr_percent_done += int( - 20 / num_pages) # 20% of work is for OCR - just a stupid assumption from tasks.py + 20 / num_pages + ) # 20% of work is for OCR - just a stupid assumption from tasks.py except ollama.ResponseError as e: - print('Error:', e.error) - raise Exception("Failed to generate text with Ollama model " + self._strategy_config.get('model')) + print("Error:", e.error) + raise Exception( + "Failed to generate text with Ollama model " + + self._strategy_config.get("model") + ) print(response) diff --git a/text_extract_api/extract/strategies/remote.py b/text_extract_api/extract/strategies/remote.py index f692554..bb3b06d 100644 --- a/text_extract_api/extract/strategies/remote.py +++ b/text_extract_api/extract/strategies/remote.py @@ -1,12 +1,10 @@ import os -import tempfile import time from extract.extract_result import ExtractResult from text_extract_api.extract.strategies.strategy import Strategy from text_extract_api.files.file_formats.file_format import FileFormat -from text_extract_api.files.file_formats.image import ImageFileFormat from text_extract_api.files.file_formats.pdf import PdfFileFormat import requests @@ -18,12 +16,12 @@ class RemoteStrategy(Strategy): def name(cls) -> str: return "remote" - def extract_text(self, file_format: FileFormat, language: str = 'en') -> ExtractResult: - - if ( - not isinstance(file_format, PdfFileFormat) - and not file_format.can_convert_to(PdfFileFormat) - ): + def extract_text( + self, file_format: FileFormat, language: str = "en" + ) -> ExtractResult: + if not isinstance( + file_format, PdfFileFormat + ) and not file_format.can_convert_to(PdfFileFormat): raise TypeError( f"Marker PDF - format {file_format.mime_type} is not supported (yet?)" ) @@ -32,40 +30,45 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> Extract extracted_text = "" start_time = time.time() ocr_percent_done = 0 - + if len(pdf_files) > 1: raise ValueError("Only one PDF file is supported.") - + if len(pdf_files) == 0: raise ValueError("No PDF file found - conversion error.") - try: + try: url = os.getenv("REMOTE_API_URL", self._strategy_config.get("url")) if not url: - raise Exception('Please do set the REMOTE_API_URL environment variable: export REMOTE_API_URL=http://...') - files = {'file': ('document.pdf', pdf_files[0].binary, 'application/pdf')} + raise Exception( + "Please do set the REMOTE_API_URL environment variable: export REMOTE_API_URL=http://..." 
+ ) + files = {"file": ("document.pdf", pdf_files[0].binary, "application/pdf")} data = { - 'page_range': None, - 'languages': language, - 'force_ocr': False, - 'paginate_output': False, - 'output_format': 'markdown' # TODO: support JSON output format + "page_range": None, + "languages": language, + "force_ocr": False, + "paginate_output": False, + "output_format": "markdown", # TODO: support JSON output format } meta = { - 'progress': str(30 + ocr_percent_done), - 'status': 'OCR Processing', - 'start_time': start_time, - 'elapsed_time': time.time() - start_time} - self.update_state_callback(state='PROGRESS', meta=meta) + "progress": str(30 + ocr_percent_done), + "status": "OCR Processing", + "start_time": start_time, + "elapsed_time": time.time() - start_time, + } + self.update_state_callback(state="PROGRESS", meta=meta) response = requests.post(url, files=files, data=data) if response.status_code != 200: raise Exception(f"Failed to upload PDF file: {response.content}") - extracted_text = response.json().get('output', '') + extracted_text = response.json().get("output", "") except Exception as e: - print('Error:', e) - raise Exception("Failed to generate text with Remote API. Make sure the remote server is up and running") - + print("Error:", e) + raise Exception( + "Failed to generate text with Remote API. Make sure the remote server is up and running" + ) + return ExtractResult.from_text(extracted_text) diff --git a/text_extract_api/extract/strategies/strategy.py b/text_extract_api/extract/strategies/strategy.py index 2a613a2..0c54e81 100644 --- a/text_extract_api/extract/strategies/strategy.py +++ b/text_extract_api/extract/strategies/strategy.py @@ -5,11 +5,11 @@ import pkgutil from typing import Type, Dict -from pydantic.v1.typing import get_class from extract.extract_result import ExtractResult from text_extract_api.files.file_formats.file_format import FileFormat + class Strategy: _strategies: Dict[str, Strategy] = {} _strategy_config: Dict[str, Dict] = {} @@ -33,8 +33,12 @@ def name(cls) -> str: raise NotImplementedError("Strategy subclasses must implement name") @classmethod - def extract_text(cls, file_format: Type["FileFormat"], language: str = 'en') -> ExtractResult: - raise NotImplementedError("Strategy subclasses must implement extract_text method") + def extract_text( + cls, file_format: Type["FileFormat"], language: str = "en" + ) -> ExtractResult: + raise NotImplementedError( + "Strategy subclasses must implement extract_text method" + ) @classmethod def get_strategy(cls, name: str) -> Type["Strategy"]: @@ -58,46 +62,58 @@ def get_strategy(cls, name: str) -> Type["Strategy"]: cls.autodiscover_strategies() if name not in cls._strategies: - available = ', '.join(cls._strategies.keys()) + available = ", ".join(cls._strategies.keys()) raise ValueError(f"Unknown strategy '{name}'. 
Available: {available}") return cls._strategies[name] @classmethod - def register_strategy(cls, strategy: Type["Strategy"], name: str = None, override: bool = False): + def register_strategy( + cls, strategy: Type["Strategy"], name: str = None, override: bool = False + ): name = name or strategy.name() if override or name not in cls._strategies: cls._strategies[name] = strategy @classmethod - def load_strategies_from_config(cls, path: str = os.getenv('OCR_CONFIG_PATH', 'config/strategies.yaml')): + def load_strategies_from_config( + cls, path: str = os.getenv("OCR_CONFIG_PATH", "config/strategies.yaml") + ): strategies = cls._strategies project_root = os.path.dirname(os.path.dirname(os.path.abspath(path))) config_file_path = os.path.join(project_root, path) if not os.path.isfile(config_file_path): - raise FileNotFoundError(f"Config file not found at path: {config_file_path}") + raise FileNotFoundError( + f"Config file not found at path: {config_file_path}" + ) - with open(config_file_path, 'r') as f: + with open(config_file_path, "r") as f: config = yaml.safe_load(f) - if 'strategies' not in config or not isinstance(config['strategies'], dict): - raise ValueError(f"Missing or invalid 'strategies' section in the {config_file_path} file") + if "strategies" not in config or not isinstance(config["strategies"], dict): + raise ValueError( + f"Missing or invalid 'strategies' section in the {config_file_path} file" + ) - for strategy_name, strategy_config in config['strategies'].items(): - if 'class' not in strategy_config: - raise ValueError(f"Missing 'class' attribute for OCR strategy: {strategy_name}") + for strategy_name, strategy_config in config["strategies"].items(): + if "class" not in strategy_config: + raise ValueError( + f"Missing 'class' attribute for OCR strategy: {strategy_name}" + ) - strategy_class_path = strategy_config['class'] - module_path, class_name = strategy_class_path.rsplit('.', 1) + strategy_class_path = strategy_config["class"] + module_path, class_name = strategy_class_path.rsplit(".", 1) module = importlib.import_module(module_path) strategy = getattr(module, class_name) strategy_instance = strategy() strategy_instance.set_strategy_config(strategy_config) - + cls.register_strategy(strategy_instance, strategy_name) - print(f"Loaded strategy from {config_file_path} {strategy_name} [{strategy_class_path}]") + print( + f"Loaded strategy from {config_file_path} {strategy_name} [{strategy_class_path}]" + ) return strategies @@ -116,27 +132,30 @@ def autodiscover_strategies(cls) -> Dict[str, Type]: if not hasattr(module, "__path__"): continue - for submodule_info in pkgutil.walk_packages(module.__path__, module_info.name + "."): + for submodule_info in pkgutil.walk_packages( + module.__path__, module_info.name + "." + ): if ".strategies." 
not in submodule_info.name: continue try: ocr_module = importlib.import_module(submodule_info.name) except ImportError as e: - print('Error loading strategy ' + submodule_info.name + ': ' + str(e)) + print( + "Error loading strategy " + submodule_info.name + ": " + str(e) + ) continue for attr_name in dir(ocr_module): attr = getattr(ocr_module, attr_name) - if (isinstance(attr, type) - and issubclass(attr, Strategy) - and attr is not Strategy - and attr.name() not in strategies + if ( + isinstance(attr, type) + and issubclass(attr, Strategy) + and attr is not Strategy + and attr.name() not in strategies ): strategies[attr.name()] = attr() - print(f"Discovered strategy {attr.name()} from {submodule_info.name} [{module_info.name}]") - + print( + f"Discovered strategy {attr.name()} from {submodule_info.name} [{module_info.name}]" + ) cls._strategies = strategies - - - diff --git a/text_extract_api/extract/tasks.py b/text_extract_api/extract/tasks.py index 1d1a77a..ba18e3f 100644 --- a/text_extract_api/extract/tasks.py +++ b/text_extract_api/extract/tasks.py @@ -11,23 +11,23 @@ from text_extract_api.files.storage_manager import StorageManager # Connect to Redis -redis_url = os.getenv('REDIS_CACHE_URL', 'redis://redis:6379/1') +redis_url = os.getenv("REDIS_CACHE_URL", "redis://redis:6379/1") redis_client = redis.StrictRedis.from_url(redis_url) @celery_app.task(bind=True) def ocr_task( - self, - binary_content: bytes, - strategy_name: str, - filename: str, - file_hash: str, - ocr_cache: bool, - prompt: Optional[str] = None, - model: Optional[str] = None, - language: Optional[str] = None, - storage_profile: Optional[str] = None, - storage_filename: Optional[str] = None, + self, + binary_content: bytes, + strategy_name: str, + filename: str, + file_hash: str, + ocr_cache: bool, + prompt: Optional[str] = None, + model: Optional[str] = None, + language: Optional[str] = None, + storage_profile: Optional[str] = None, + storage_filename: Optional[str] = None, ): """ Celery task to perform OCR processing on a PDF/Office/image file. 
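For orientation while reading the `ocr_task` hunks around this point: a minimal sketch, assuming a configured Celery broker and an existing sample PDF, of how this task is enqueued and how its `PROGRESS` meta can be read back. The file path, strategy name and hash scheme below are illustrative placeholders, not part of this change.

```python
# Illustrative sketch only - mirrors the positional-args call pattern used in
# main.py later in this diff; values here are placeholders.
import hashlib

from text_extract_api.extract.tasks import ocr_task

with open("examples/example-invoice.pdf", "rb") as f:  # placeholder path
    binary_content = f.read()

task = ocr_task.apply_async(
    args=[
        binary_content,                           # raw file bytes
        "easyocr",                                # strategy_name (any registered strategy)
        "example-invoice.pdf",                    # filename
        hashlib.md5(binary_content).hexdigest(),  # file_hash (hash scheme assumed)
        True,                                     # ocr_cache
        None,                                     # prompt (skip LLM post-processing)
        None,                                     # model
        "en",                                     # language
        "default",                                # storage_profile
        None,                                     # storage_filename
    ]
)

# apply_async returns an AsyncResult; while the task runs it reports the
# progress/status/start_time/elapsed_time meta shown in the hunks below.
if task.state == "PROGRESS":
    print(task.info.get("progress"), task.info.get("status"))
```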
@@ -37,32 +37,47 @@ def ocr_task( strategy = Strategy.get_strategy(strategy_name) strategy.set_update_state_callback(self.update_state) - self.update_state(state='PROGRESS', status="File uploaded successfully", - meta={'progress': 10}) # Example progress update + self.update_state( + state="PROGRESS", status="File uploaded successfully", meta={"progress": 10} + ) # Example progress update extracted_text = None if ocr_cache: cached_result = redis_client.get(file_hash) if cached_result: # Return cached result if available - extracted_text = cached_result.decode('utf-8') + extracted_text = cached_result.decode("utf-8") if extracted_text is None: print(f"Extracting text from file using strategy: {strategy.name()}") - self.update_state(state='PROGRESS', - meta={'progress': 30, 'status': 'Extracting text from file', 'start_time': start_time, - 'elapsed_time': time.time() - start_time}) # Example progress update - extract_result = strategy.extract_text(FileFormat.from_binary(binary_content), language) + self.update_state( + state="PROGRESS", + meta={ + "progress": 30, + "status": "Extracting text from file", + "start_time": start_time, + "elapsed_time": time.time() - start_time, + }, + ) # Example progress update + extract_result = strategy.extract_text( + FileFormat.from_binary(binary_content), language + ) extracted_text = extract_result.text else: print("Using cached result...") print("After extracted text") - self.update_state(state='PROGRESS', - meta={'progress': 50, 'status': 'Text extracted', 'extracted_text': extracted_text, - 'start_time': start_time, - 'elapsed_time': time.time() - start_time}) # Example progress update + self.update_state( + state="PROGRESS", + meta={ + "progress": 50, + "status": "Text extracted", + "extracted_text": extracted_text, + "start_time": start_time, + "elapsed_time": time.time() - start_time, + }, + ) # Example progress update # @todo Universal Text Object - is cache available if ocr_cache: @@ -70,27 +85,46 @@ def ocr_task( if prompt: print(f"Transforming text using LLM (prompt={prompt}, model={model}) ...") - self.update_state(state='PROGRESS', meta={'progress': 75, 'status': 'Processing LLM', 'start_time': start_time, - 'elapsed_time': time.time() - start_time}) # Example progress update + self.update_state( + state="PROGRESS", + meta={ + "progress": 75, + "status": "Processing LLM", + "start_time": start_time, + "elapsed_time": time.time() - start_time, + }, + ) # Example progress update llm_resp = ollama.generate(model, prompt + extracted_text, stream=True) num_chunk = 1 - extracted_text = '' # will be filled with chunks from llm + extracted_text = "" # will be filled with chunks from llm for chunk in llm_resp: - self.update_state(state='PROGRESS', - meta={'progress': num_chunk, 'status': 'LLM Processing chunk no: ' + str(num_chunk), - 'start_time': start_time, - 'elapsed_time': time.time() - start_time}) # Example progress update + self.update_state( + state="PROGRESS", + meta={ + "progress": num_chunk, + "status": "LLM Processing chunk no: " + str(num_chunk), + "start_time": start_time, + "elapsed_time": time.time() - start_time, + }, + ) # Example progress update num_chunk += 1 - extracted_text += chunk['response'] + extracted_text += chunk["response"] if storage_profile: if not storage_filename: - storage_filename = filename.replace('.', '_') + '.pdf' + storage_filename = filename.replace(".", "_") + ".pdf" storage_manager = StorageManager(storage_profile) storage_manager.save(filename, storage_filename, extracted_text) - 
self.update_state(state='DONE', meta={'progress': 100, 'status': 'Processing done!', 'start_time': start_time, - 'elapsed_time': time.time() - start_time}) + self.update_state( + state="DONE", + meta={ + "progress": 100, + "status": "Processing done!", + "start_time": start_time, + "elapsed_time": time.time() - start_time, + }, + ) return extracted_text diff --git a/text_extract_api/files/converters/converter.py b/text_extract_api/files/converters/converter.py index 7fd54a2..5c90552 100644 --- a/text_extract_api/files/converters/converter.py +++ b/text_extract_api/files/converters/converter.py @@ -2,6 +2,7 @@ from text_extract_api.files.file_formats.file_format import FileFormat + class Converter: @staticmethod def convert(file_format: Type["FileFormat"]) -> Iterator["FileFormat"]: @@ -12,6 +13,8 @@ def convert_to_list(cls, file_format: Type["FileFormat"]) -> List["FileFormat"]: return list(cls.convert(file_format)) @classmethod - def convert_force_single(cls, file_format: Type["FileFormat"]) -> Type["FileFormat"]: - """ Warning - this will return only first page """ + def convert_force_single( + cls, file_format: Type["FileFormat"] + ) -> Type["FileFormat"]: + """Warning - this will return only first page""" return next(cls.convert(file_format), None) diff --git a/text_extract_api/files/converters/image_to_pdf.py b/text_extract_api/files/converters/image_to_pdf.py index e555750..ea64d95 100644 --- a/text_extract_api/files/converters/image_to_pdf.py +++ b/text_extract_api/files/converters/image_to_pdf.py @@ -7,21 +7,18 @@ class ImageToPdfConverter(Converter): - @staticmethod def convert(file_format: ImageFileFormat) -> Iterator[Type["PdfFileFormat"]]: - image = Image.open(BytesIO(file_format.binary)) pdf_bytes = ImageToPdfConverter._image_to_pdf_bytes(image) yield PdfFileFormat.from_binary( binary=pdf_bytes, filename=f"{file_format.filename}.pdf", - mime_type="application/pdf" + mime_type="application/pdf", ) @staticmethod def _image_to_pdf_bytes(image: Image) -> bytes: - buffer = BytesIO() image.save(buffer, format="PDF") - return buffer.getvalue() \ No newline at end of file + return buffer.getvalue() diff --git a/text_extract_api/files/converters/pdf_to_jpeg.py b/text_extract_api/files/converters/pdf_to_jpeg.py index e11b47f..5cc9a51 100644 --- a/text_extract_api/files/converters/pdf_to_jpeg.py +++ b/text_extract_api/files/converters/pdf_to_jpeg.py @@ -6,8 +6,8 @@ from text_extract_api.files.file_formats.image import ImageFileFormat from text_extract_api.files.file_formats.pdf import PdfFileFormat -class PdfToJpegConverter(Converter): +class PdfToJpegConverter(Converter): @staticmethod def convert(file_format: PdfFileFormat) -> Iterator[Type["ImageFileFormat"]]: pages = convert_from_bytes(file_format.binary) @@ -17,7 +17,7 @@ def convert(file_format: PdfFileFormat) -> Iterator[Type["ImageFileFormat"]]: yield ImageFileFormat.from_binary( binary=PdfToJpegConverter._image_to_bytes(page), filename=f"{file_format.filename}_page_{i}.jpg", - mime_type="image/jpeg" + mime_type="image/jpeg", ) @staticmethod diff --git a/text_extract_api/files/file_formats/__init__.py b/text_extract_api/files/file_formats/__init__.py index c6a9fcd..0ce05ed 100644 --- a/text_extract_api/files/file_formats/__init__.py +++ b/text_extract_api/files/file_formats/__init__.py @@ -1,4 +1,3 @@ - ### WARNING ### This file is generated dynamically before git commit. ### Run ./scripts/dev/gen-file-format-init.sh from repository root. 
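The converter hunks a little above (`Converter`, `ImageToPdfConverter`, `PdfToJpegConverter`) are easier to read with a usage sketch in mind. The sample path below is a placeholder, and the snippet assumes Poppler is installed for pdf2image's `convert_from_bytes`.

```python
# Illustrative: split a PDF into per-page JPEG ImageFileFormat objects using the
# converters reformatted above. Requires Poppler for pdf2image.
from text_extract_api.files.converters.pdf_to_jpeg import PdfToJpegConverter
from text_extract_api.files.file_formats.file_format import FileFormat

with open("examples/example-mri.pdf", "rb") as f:        # placeholder path
    pdf = FileFormat.from_binary(f.read(), filename="example-mri.pdf")

pages = PdfToJpegConverter.convert_to_list(pdf)           # one ImageFileFormat per page
first = PdfToJpegConverter.convert_force_single(pdf)      # first page only, or None
print(len(pages), first.filename if first else None)
```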
@@ -7,3 +6,10 @@ from .docling import DoclingFileFormat from .pdf import PdfFileFormat from .image import ImageFileFormat + +__all__ = [ + "FileFormat", + "DoclingFileFormat", + "PdfFileFormat", + "ImageFileFormat", +] diff --git a/text_extract_api/files/file_formats/docling.py b/text_extract_api/files/file_formats/docling.py index a0672da..0233a11 100644 --- a/text_extract_api/files/file_formats/docling.py +++ b/text_extract_api/files/file_formats/docling.py @@ -35,11 +35,13 @@ def default_iterator_file_format(cls) -> Type[FileFormat]: return cls @staticmethod - def convertible_to() -> Dict[Type["FileFormat"], Callable[[], Iterator["FileFormat"]]]: + def convertible_to() -> Dict[ + Type["FileFormat"], Callable[[], Iterator["FileFormat"]] + ]: # No specific converters needed as the strategy will handle conversion return {} @staticmethod def validate(binary_file_content: bytes): if not binary_file_content or len(binary_file_content) == 0: - raise ValueError("Empty file content") \ No newline at end of file + raise ValueError("Empty file content") diff --git a/text_extract_api/files/file_formats/file_format.py b/text_extract_api/files/file_formats/file_format.py index 8382702..da6bc6e 100644 --- a/text_extract_api/files/file_formats/file_format.py +++ b/text_extract_api/files/file_formats/file_format.py @@ -20,8 +20,12 @@ class FileFormat: # Construction - def __init__(self, binary_file_content: bytes, filename: Optional[str] = None, - mime_type: Optional[str] = None) -> None: + def __init__( + self, + binary_file_content: bytes, + filename: Optional[str] = None, + mime_type: Optional[str] = None, + ) -> None: """ Attributes: binary_file_content (bytes): The binary content of the file. @@ -40,19 +44,27 @@ def __init__(self, binary_file_content: bytes, filename: Optional[str] = None, is provided or defaulted to. """ if not binary_file_content: - raise ValueError(f"{self.__class__.__name__} missing content file - corrupted base64 or binary data.") + raise ValueError( + f"{self.__class__.__name__} missing content file - corrupted base64 or binary data." + ) resolved_mime_type = mime_type or self.DEFAULT_MIME_TYPE if not resolved_mime_type: - raise ValueError(f"{self.__class__.__name__} requires a mime type to be provided or defaulted.") + raise ValueError( + f"{self.__class__.__name__} requires a mime type to be provided or defaulted." 
+ ) self.binary_file_content: bytes = binary_file_content self.filename: str = filename or self.DEFAULT_FILENAME self.mime_type: str = resolved_mime_type @classmethod - def from_base64(cls, base64_string: str, filename: Optional[str] = None, mime_type: Optional[str] = None) -> Type[ - "FileFormat"]: + def from_base64( + cls, + base64_string: str, + filename: Optional[str] = None, + mime_type: Optional[str] = None, + ) -> Type["FileFormat"]: binary = base64.b64decode(base64_string) instance = cls.from_binary(binary, filename=filename, mime_type=mime_type) instance._base64_cache = base64_string @@ -60,26 +72,27 @@ def from_base64(cls, base64_string: str, filename: Optional[str] = None, mime_ty @classmethod def from_binary( - cls, - binary: bytes, - filename: Optional[str] = None, - mime_type: Optional[str] = None + cls, + binary: bytes, + filename: Optional[str] = None, + mime_type: Optional[str] = None, ) -> Type["FileFormat"]: if mime_type == "application/octet-stream": mime_type = None - mime_type = mime_type or FileFormat._guess_mime_type(binary_data=binary, filename=filename) - from text_extract_api.files.file_formats.pdf import PdfFileFormat # type: ignore + mime_type = mime_type or FileFormat._guess_mime_type( + binary_data=binary, filename=filename + ) file_format_class = cls._get_file_format_class(mime_type) - return file_format_class(binary_file_content=binary, filename=filename, mime_type=mime_type) + return file_format_class( + binary_file_content=binary, filename=filename, mime_type=mime_type + ) def __repr__(self) -> str: """ Returns a string representation of the FileFormat instance. """ size = len(self.binary_file_content) - return ( - f"" - ) + return f"" def to_dict(self, encode_base64: bool = False) -> FileFormatDict: """ @@ -104,7 +117,9 @@ def to_dict(self, encode_base64: bool = False) -> FileFormatDict: @property def base64_(self) -> str: if self._base64_cache is None: - self._base64_cache = base64.b64encode(self.binary_file_content).decode('utf-8') + self._base64_cache = base64.b64encode(self.binary_file_content).decode( + "utf-8" + ) return self._base64_cache @property @@ -136,7 +151,9 @@ def iterator(self, target_format: Optional["FileFormat"]) -> Iterator["FileForma if self.is_pageable(): if final_format.is_pageable(): - raise ValueError("Target format and current format are both pageable. Cannot iterate.") + raise ValueError( + "Target format and current format are both pageable. Cannot iterate." + ) else: yield self.convert_to(final_format) else: @@ -164,12 +181,16 @@ def convert_to(self, target_format: Type["FileFormat"]) -> List["FileFormat"]: converters = self.convertible_to() if target_format not in converters: - raise ValueError(f"Cannot convert to {target_format}. Conversion not supported.") + raise ValueError( + f"Cannot convert to {target_format}. Conversion not supported." + ) return list(converters[target_format](self)) @staticmethod - def convertible_to() -> Dict[Type["FileFormat"], Callable[[Type["FileFormat"]], Iterator[Type["Converter"]]]]: + def convertible_to() -> Dict[ + Type["FileFormat"], Callable[[Type["FileFormat"]], Iterator[Type["Converter"]]] # NOQA: F821 + ]: """ Defines what formats this file type can be converted to. Returns a dictionary where keys are target formats and values are functions @@ -177,13 +198,15 @@ def convertible_to() -> Dict[Type["FileFormat"], Callable[[Type["FileFormat"]], :return: A dictionary of convertible formats and their converters. 
""" - return {} + return {} # To check, Converter is not defined in this file @classmethod def default_iterator_file_format(cls) -> Type["FileFormat"]: if not cls.is_pageable(): return cls - raise NotImplementedError("Pageable formats must implement default_iterator_file_format.") + raise NotImplementedError( + "Pageable formats must implement default_iterator_file_format." + ) def unify(self) -> "FileFormat": """ @@ -199,19 +222,24 @@ def _get_file_format_class(mime_type: str) -> Type["FileFormat"]: import text_extract_api.files.file_formats.pdf # noqa - its not unused import @todo autodiscover import text_extract_api.files.file_formats.image # noqa - its not unused import @todo autodiscover import text_extract_api.files.file_formats.docling # noqa - its not unused import @todo autodiscover + for subclass in FileFormat.__subclasses__(): if mime_type in subclass.accepted_mime_types(): return subclass raise ValueError(f"No matching FileFormat class for mime type: {mime_type}") @staticmethod - def _guess_mime_type(binary_data: Optional[bytes] = None, filename: Optional[str] = None) -> str: + def _guess_mime_type( + binary_data: Optional[bytes] = None, filename: Optional[str] = None + ) -> str: mime = magic.Magic(mime=True) if binary_data: return mime.from_buffer(binary_data) if filename: return mime.from_file(filename) - raise ValueError("Either binary_data or filename must be provided to guess the MIME type.") + raise ValueError( + "Either binary_data or filename must be provided to guess the MIME type." + ) class FileField: diff --git a/text_extract_api/files/file_formats/image.py b/text_extract_api/files/file_formats/image.py index baf636b..e4fef49 100644 --- a/text_extract_api/files/file_formats/image.py +++ b/text_extract_api/files/file_formats/image.py @@ -5,27 +5,29 @@ from text_extract_api.files.file_formats.file_format import FileFormat + class ImageSupportedExportFormats(Enum): JPEG = "JPEG" PNG = "PNG" BMP = "BMP" TIFF = "TIFF" + class ImageFileFormat(FileFormat): DEFAULT_FILENAME: str = "image.jpeg" @staticmethod def accepted_mime_types() -> list[str]: return ["image/jpeg", "image/png", "image/bmp", "image/gif", "image/tiff"] - + @staticmethod - def convertible_to() -> Dict[Type["FileFormat"], Callable[[], Iterator["FileFormat"]]]: + def convertible_to() -> Dict[ + Type["FileFormat"], Callable[[], Iterator["FileFormat"]] + ]: from text_extract_api.files.file_formats.pdf import PdfFileFormat from text_extract_api.files.converters.image_to_pdf import ImageToPdfConverter - return { - PdfFileFormat: ImageToPdfConverter.convert - } + return {PdfFileFormat: ImageToPdfConverter.convert} @staticmethod def is_pageable() -> bool: @@ -36,7 +38,9 @@ def default_iterator_file_format(cls) -> Type["ImageFileFormat"]: return cls def unify(self) -> "FileFormat": - unified_image = ImageProcessor.unify_image(self.binary, ImageSupportedExportFormats.JPEG) + unified_image = ImageProcessor.unify_image( + self.binary, ImageSupportedExportFormats.JPEG + ) return ImageFileFormat.from_binary(unified_image, self.filename, self.mime_type) @staticmethod @@ -47,10 +51,14 @@ def validate(binary_file_content: bytes): except OSError as e: raise ValueError("Corrupted image file content") from e + class ImageProcessor: @staticmethod - def unify_image(image_bytes: bytes, target_format: ImageSupportedExportFormats = "JPEG", - convert_to_rgb: bool = True) -> bytes: + def unify_image( + image_bytes: bytes, + target_format: ImageSupportedExportFormats = "JPEG", + convert_to_rgb: bool = True, + ) -> bytes: """ Prepares 
an image for OCR by unifying its format and color mode. - Converts image to the desired format (e.g., JPEG). diff --git a/text_extract_api/files/file_formats/pdf.py b/text_extract_api/files/file_formats/pdf.py index 50fb8a3..227c303 100644 --- a/text_extract_api/files/file_formats/pdf.py +++ b/text_extract_api/files/file_formats/pdf.py @@ -17,18 +17,19 @@ def is_pageable() -> bool: @classmethod def default_iterator_file_format(cls) -> Type[FileFormat]: from text_extract_api.files.file_formats.image import ImageFileFormat + return ImageFileFormat @staticmethod - def convertible_to() -> Dict[Type["FileFormat"], Callable[[], Iterator["FileFormat"]]]: + def convertible_to() -> Dict[ + Type["FileFormat"], Callable[[], Iterator["FileFormat"]] + ]: from text_extract_api.files.file_formats.image import ImageFileFormat from text_extract_api.files.converters.pdf_to_jpeg import PdfToJpegConverter - return { - ImageFileFormat: PdfToJpegConverter.convert - } + return {ImageFileFormat: PdfToJpegConverter.convert} @staticmethod def validate(binary_file_content: bytes): - if not binary_file_content.startswith(b'%PDF'): - raise ValueError("Corrupted PDF file") \ No newline at end of file + if not binary_file_content.startswith(b"%PDF"): + raise ValueError("Corrupted PDF file") diff --git a/text_extract_api/files/storage_manager.py b/text_extract_api/files/storage_manager.py index cf86513..a716bbd 100644 --- a/text_extract_api/files/storage_manager.py +++ b/text_extract_api/files/storage_manager.py @@ -4,9 +4,12 @@ import yaml from text_extract_api.files.storage_strategies.aws_s3 import AWSS3StorageStrategy -from text_extract_api.files.storage_strategies.google_drive import GoogleDriveStorageStrategy -from text_extract_api.files.storage_strategies.local_filesystem import LocalFilesystemStorageStrategy -from text_extract_api.files.storage_strategies.storage_strategy import StorageStrategy +from text_extract_api.files.storage_strategies.google_drive import ( + GoogleDriveStorageStrategy, +) +from text_extract_api.files.storage_strategies.local_filesystem import ( + LocalFilesystemStorageStrategy, +) class StorageStrategy(Enum): @@ -17,11 +20,14 @@ class StorageStrategy(Enum): class StorageManager: def __init__(self, profile_name): - profile_path = os.path.join(os.getenv('STORAGE_PROFILE_PATH', '/storage_profiles'), f'{profile_name}.yaml') - with open(profile_path, 'r') as file: + profile_path = os.path.join( + os.getenv("STORAGE_PROFILE_PATH", "/storage_profiles"), + f"{profile_name}.yaml", + ) + with open(profile_path, "r") as file: self.profile = yaml.safe_load(file) - strategy = StorageStrategy(self.profile['strategy']) + strategy = StorageStrategy(self.profile["strategy"]) if strategy == StorageStrategy.LOCAL_FILESYSTEM: self.strategy = LocalFilesystemStorageStrategy(self.profile) elif strategy == StorageStrategy.GOOGLE_DRIVE: diff --git a/text_extract_api/files/storage_strategies/aws_s3.py b/text_extract_api/files/storage_strategies/aws_s3.py index adde33a..e359c6a 100644 --- a/text_extract_api/files/storage_strategies/aws_s3.py +++ b/text_extract_api/files/storage_strategies/aws_s3.py @@ -1,24 +1,32 @@ import boto3 from botocore.exceptions import EndpointConnectionError, ClientError -from text_extract_api.files.storage_strategies.storage_strategy import StorageStrategy +from text_extract_api.files.storage_strategies.storage_strategy import ( + BaseStorageStrategy, +) -class AWSS3StorageStrategy(StorageStrategy): +class AWSS3StorageStrategy(BaseStorageStrategy): def __init__(self, context): 
super().__init__(context) - self.bucket_name = self.resolve_placeholder(context['settings'].get('bucket_name')) - self.region = self.resolve_placeholder(context['settings'].get('region')) - self.access_key = self.resolve_placeholder(context['settings'].get('access_key')) - self.secret_access_key = self.resolve_placeholder(context['settings'].get('secret_access_key')) + self.bucket_name = self.resolve_placeholder( + context["settings"].get("bucket_name") + ) + self.region = self.resolve_placeholder(context["settings"].get("region")) + self.access_key = self.resolve_placeholder( + context["settings"].get("access_key") + ) + self.secret_access_key = self.resolve_placeholder( + context["settings"].get("secret_access_key") + ) try: self.s3_client = boto3.client( - 's3', + "s3", aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_access_key, - region_name=self.region + region_name=self.region, ) self.s3_client.head_bucket(Bucket=self.bucket_name) except EndpointConnectionError as e: @@ -27,8 +35,8 @@ def __init__(self, context): "Check your AWS_REGION and AWS_S3_BUCKET_NAME environment variables." ) from e except ClientError as e: - error_code = e.response.get('Error', {}).get('Code', 'Unknown') - if error_code in ('400', '403'): + error_code = e.response.get("Error", {}).get("Code", "Unknown") + if error_code in ("400", "403"): raise RuntimeError( f"{str(e)}\n" "Error: Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY." @@ -42,7 +50,7 @@ def save(self, file_name, dest_file_name, content): self.s3_client.put_object( Bucket=self.bucket_name, Key=formatted_file_name, - Body=content.encode('utf-8') + Body=content.encode("utf-8"), ) except ClientError as e: raise RuntimeError( @@ -53,10 +61,10 @@ def save(self, file_name, dest_file_name, content): def load(self, file_name): try: response = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_name) - return response['Body'].read().decode('utf-8') + return response["Body"].read().decode("utf-8") except ClientError as e: - error_code = e.response['Error']['Code'] - if error_code == 'NoSuchKey': + error_code = e.response["Error"]["Code"] + if error_code == "NoSuchKey": return None raise RuntimeError( f"{str(e)}\n" @@ -66,11 +74,10 @@ def load(self, file_name): def list(self): try: response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) - return [item['Key'] for item in response.get('Contents', [])] + return [item["Key"] for item in response.get("Contents", [])] except ClientError as e: raise RuntimeError( - f"{str(e)}\n" - f"Error listing objects in bucket '{self.bucket_name}'." + f"{str(e)}\nError listing objects in bucket '{self.bucket_name}'." ) from e def delete(self, file_name): diff --git a/text_extract_api/files/storage_strategies/google_drive.py b/text_extract_api/files/storage_strategies/google_drive.py index 86f2daa..ef11735 100644 --- a/text_extract_api/files/storage_strategies/google_drive.py +++ b/text_extract_api/files/storage_strategies/google_drive.py @@ -9,32 +9,39 @@ ## Note - this code is using Service Accounts for authentication which are separate accounts other than ## your Google account. 
You can create a service account and download the JSON key file to use it for ## how to enable GDrive API: https://developers.google.com/drive/api/quickstart/python?hl=pl -from text_extract_api.files.storage_strategies.storage_strategy import StorageStrategy +from text_extract_api.files.storage_strategies.storage_strategy import ( + BaseStorageStrategy, +) -class GoogleDriveStorageStrategy(StorageStrategy): + +class GoogleDriveStorageStrategy(BaseStorageStrategy): def __init__(self, context): super().__init__(context) self.credentials = Credentials.from_service_account_file( - context['settings']['service_account_file'], - scopes=['https://www.googleapis.com/auth/drive'] + context["settings"]["service_account_file"], + scopes=["https://www.googleapis.com/auth/drive"], ) - self.service = build('drive', 'v3', credentials=self.credentials) - self.folder_id = context['settings']['folder_id'] + self.service = build("drive", "v3", credentials=self.credentials) + self.folder_id = context["settings"]["folder_id"] def save(self, file_name, dest_file_name, content): # Save content to a temporary file - with open(file_name, 'wb') as temp_file: - temp_file.write(content.encode('utf-8')) # Encode the string to bytes + with open(file_name, "wb") as temp_file: + temp_file.write(content.encode("utf-8")) # Encode the string to bytes file_metadata = { - 'name': self.format_file_name(file_name, dest_file_name), + "name": self.format_file_name(file_name, dest_file_name), } if self.folder_id: - file_metadata['parents'] = [self.folder_id] + file_metadata["parents"] = [self.folder_id] print(file_metadata) media = MediaFileUpload(file_name, resumable=True) - file = self.service.files().create(body=file_metadata, media_body=media, fields='id').execute() + file = ( + self.service.files() + .create(body=file_metadata, media_body=media, fields="id") + .execute() + ) print(f"File ID: {file.get('id')}") # Remove the temporary file @@ -44,12 +51,16 @@ def load(self, file_name): query = f"name = '{file_name}'" if self.folder_id: query += f" and '{self.folder_id}' in parents" - results = self.service.files().list(q=query, spaces='drive', fields='files(id, name)').execute() - items = results.get('files', []) + results = ( + self.service.files() + .list(q=query, spaces="drive", fields="files(id, name)") + .execute() + ) + items = results.get("files", []) if not items: - print('No files found.') + print("No files found.") return None - file_id = items[0]['id'] + file_id = items[0]["id"] request = self.service.files().get_media(fileId=file_id) fh = io.BytesIO() downloader = MediaIoBaseDownload(fh, request) @@ -64,19 +75,27 @@ def list(self): query = "" # "mimeType='application/vnd.google-apps.file'" if self.folder_id: query = f"'{self.folder_id}' in parents" - results = self.service.files().list(q=query, spaces='drive', fields='files(id, name)').execute() - items = results.get('files', []) - return [item['name'] for item in items] + results = ( + self.service.files() + .list(q=query, spaces="drive", fields="files(id, name)") + .execute() + ) + items = results.get("files", []) + return [item["name"] for item in items] def delete(self, file_name): query = f"name = '{file_name}'" if self.folder_id: query += f" and '{self.folder_id}' in parents" - results = self.service.files().list(q=query, spaces='drive', fields='files(id, name)').execute() - items = results.get('files', []) + results = ( + self.service.files() + .list(q=query, spaces="drive", fields="files(id, name)") + .execute() + ) + items = results.get("files", []) if not 
items: - print('No files found.') + print("No files found.") return - file_id = items[0]['id'] + file_id = items[0]["id"] self.service.files().delete(fileId=file_id).execute() print(f"File {file_name} deleted.") diff --git a/text_extract_api/files/storage_strategies/local_filesystem.py b/text_extract_api/files/storage_strategies/local_filesystem.py index 12683ef..ed0d627 100644 --- a/text_extract_api/files/storage_strategies/local_filesystem.py +++ b/text_extract_api/files/storage_strategies/local_filesystem.py @@ -1,9 +1,10 @@ import os -from datetime import datetime +from text_extract_api.files.storage_strategies.storage_strategy import ( + BaseStorageStrategy, +) -from text_extract_api.files.storage_strategies.storage_strategy import StorageStrategy def resolve_path(path): # Expand `~` to the home directory @@ -13,20 +14,23 @@ def resolve_path(path): return absolute_path -class LocalFilesystemStorageStrategy(StorageStrategy): +class LocalFilesystemStorageStrategy(BaseStorageStrategy): def __init__(self, context): super().__init__(context) - self.base_directory = resolve_path(self.context['settings']['root_path']) + self.base_directory = resolve_path(self.context["settings"]["root_path"]) print("Storage base directory: ", self.base_directory) - self.create_subfolders = self.context['settings'].get('create_subfolders', False) - self.subfolder_names_format = self.context['settings'].get('subfolder_names_format', '') + self.create_subfolders = self.context["settings"].get( + "create_subfolders", False + ) + self.subfolder_names_format = self.context["settings"].get( + "subfolder_names_format", "" + ) os.makedirs(self.base_directory, exist_ok=True) def _get_subfolder_path(self, file_name): if not self.subfolder_names_format: return self.base_directory - now = datetime.now() subfolder_path = self.format_file_name(file_name, self.subfolder_names_format) return os.path.join(self.base_directory, subfolder_path) @@ -36,13 +40,13 @@ def save(self, file_name, dest_file_name, content): full_path = os.path.join(subfolder_path, file_name) full_directory = os.path.dirname(full_path) os.makedirs(full_directory, exist_ok=True) - with open(full_path, 'w') as file: + with open(full_path, "w") as file: file.write(content) def load(self, file_name): subfolder_path = self._get_subfolder_path(file_name) file_path = os.path.join(subfolder_path, file_name) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: return file.read() def list(self): diff --git a/text_extract_api/files/storage_strategies/storage_strategy.py b/text_extract_api/files/storage_strategies/storage_strategy.py index 6a70460..b317b9f 100644 --- a/text_extract_api/files/storage_strategies/storage_strategy.py +++ b/text_extract_api/files/storage_strategies/storage_strategy.py @@ -3,7 +3,8 @@ from pathlib import Path from string import Template -class StorageStrategy: + +class BaseStorageStrategy: def __init__(self, context): self.context = context @@ -20,15 +21,17 @@ def delete(self, file_name): raise NotImplementedError("Subclasses must implement this method") def format_file_name(self, file_name, format_string): - return format_string.format(file_fullname=file_name, # file_name with path - file_name=Path(file_name).stem, # file_name without path - file_extension=Path(file_name).suffix, # file extension - Y=datetime.now().strftime('%Y'), - mm=datetime.now().strftime('%m'), - dd=datetime.now().strftime('%d'), - HH=datetime.now().strftime('%H'), - MM=datetime.now().strftime('%M'), - SS=datetime.now().strftime('%S')) + return 
format_string.format( + file_fullname=file_name, # file_name with path + file_name=Path(file_name).stem, # file_name without path + file_extension=Path(file_name).suffix, # file extension + Y=datetime.now().strftime("%Y"), + mm=datetime.now().strftime("%m"), + dd=datetime.now().strftime("%d"), + HH=datetime.now().strftime("%H"), + MM=datetime.now().strftime("%M"), + SS=datetime.now().strftime("%S"), + ) def resolve_placeholder(self, value, default=None): if not value: @@ -39,4 +42,6 @@ def resolve_placeholder(self, value, default=None): if default: return default else: - raise ValueError(f"Environment variable '{e.args[0]}' is missing, and no default value is provided.") + raise ValueError( + f"Environment variable '{e.args[0]}' is missing, and no default value is provided." + ) diff --git a/text_extract_api/main.py b/text_extract_api/main.py index 57233a5..203e80f 100644 --- a/text_extract_api/main.py +++ b/text_extract_api/main.py @@ -19,30 +19,37 @@ # Define base path as text_extract_api - required for keeping absolute namespaces sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve())) + def storage_profile_exists(profile_name: str) -> bool: profile_path = os.path.abspath( - os.path.join(os.getenv('STORAGE_PROFILE_PATH', './storage_profiles'), f'{profile_name}.yaml')) - if not os.path.isfile(profile_path) and profile_path.startswith('..'): + os.path.join( + os.getenv("STORAGE_PROFILE_PATH", "./storage_profiles"), + f"{profile_name}.yaml", + ) + ) + if not os.path.isfile(profile_path) and profile_path.startswith(".."): # backward compability for ../storage_manager in .env - sub_profile_path = os.path.normpath(os.path.join('.', profile_path)) + sub_profile_path = os.path.normpath(os.path.join(".", profile_path)) return os.path.isfile(sub_profile_path) return True + app = FastAPI() # Connect to Redis -redis_url = os.getenv('REDIS_CACHE_URL', 'redis://redis:6379/1') +redis_url = os.getenv("REDIS_CACHE_URL", "redis://redis:6379/1") redis_client = redis.StrictRedis.from_url(redis_url) + @app.post("/ocr") async def ocr_endpoint( - strategy: str = Form(...), - prompt: str = Form(None), - model: str = Form(...), - file: UploadFile = File(...), - ocr_cache: bool = Form(...), - storage_profile: str = Form('default'), - storage_filename: str = Form(None), - language: str = Form('en') + strategy: str = Form(...), + prompt: str = Form(None), + model: str = Form(...), + file: UploadFile = File(...), + ocr_cache: bool = Form(...), + storage_profile: str = Form("default"), + storage_filename: str = Form(None), + language: str = Form("en"), ): """ Endpoint to extract text from an uploaded PDF, Image or Office file using different OCR strategies. 
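As a reading aid for the `/ocr` endpoint being reformatted here, a hedged example of a client call: the base URL and model name are assumptions for a local dev setup, while the form field names mirror the endpoint signature above.

```python
# Illustrative client call for the /ocr endpoint above; base URL, file path and
# model name are placeholders, not values taken from this change.
import requests

with open("examples/example-invoice.pdf", "rb") as f:    # placeholder path
    resp = requests.post(
        "http://localhost:8000/ocr",                      # assumed local server
        files={"file": ("example-invoice.pdf", f, "application/pdf")},
        data={
            "strategy": "easyocr",        # any registered strategy name
            "model": "llama3.1",          # placeholder Ollama model
            "ocr_cache": "true",
            "storage_profile": "default",
            "language": "en",
        },
        timeout=60,
    )
resp.raise_for_status()
print(resp.json()["task_id"])             # the endpoint returns {"task_id": ...}
```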
@@ -50,8 +57,15 @@ async def ocr_endpoint( """ # Validate input try: - OcrFormRequest(strategy=strategy, prompt=prompt, model=model, ocr_cache=ocr_cache, - storage_profile=storage_profile, storage_filename=storage_filename, language=language) + OcrFormRequest( + strategy=strategy, + prompt=prompt, + model=model, + ocr_cache=ocr_cache, + storage_profile=storage_profile, + storage_filename=storage_filename, + language=language, + ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -60,33 +74,53 @@ async def ocr_endpoint( file_format = FileFormat.from_binary(file_binary, filename, file.content_type) print( - f"Processing Document {file_format.filename} with strategy: {strategy}, ocr_cache: {ocr_cache}, model: {model}, storage_profile: {storage_profile}, storage_filename: {storage_filename}, language: {language}, will be saved as: {filename}") + f"Processing Document {file_format.filename} with strategy: {strategy}, ocr_cache: {ocr_cache}, model: {model}, storage_profile: {storage_profile}, storage_filename: {storage_filename}, language: {language}, will be saved as: {filename}" + ) # Asynchronous processing using Celery task = ocr_task.apply_async( - args=[file_format.binary, strategy, file_format.filename, file_format.hash, ocr_cache, prompt, model, language, - storage_profile, - storage_filename]) + args=[ + file_format.binary, + strategy, + file_format.filename, + file_format.hash, + ocr_cache, + prompt, + model, + language, + storage_profile, + storage_filename, + ] + ) return {"task_id": task.id} # this is an alias for /ocr - to keep the backward compatibility @app.post("/ocr/upload") async def ocr_upload_endpoint( - strategy: str = Form(...), - prompt: str = Form(None), - model: str = Form(None), - file: UploadFile = File(...), - ocr_cache: bool = Form(...), - storage_profile: str = Form('default'), - storage_filename: str = Form(None), - language: str = Form('en') + strategy: str = Form(...), + prompt: str = Form(None), + model: str = Form(None), + file: UploadFile = File(...), + ocr_cache: bool = Form(...), + storage_profile: str = Form("default"), + storage_filename: str = Form(None), + language: str = Form("en"), ): """ Alias endpoint to extract text from an uploaded PDF/Office/Image file using different OCR strategies. Supports both synchronous and asynchronous processing. 
""" - return await ocr_endpoint(strategy, prompt, model, file, ocr_cache, storage_profile, storage_filename, language) + return await ocr_endpoint( + strategy, + prompt, + model, + file, + ocr_cache, + storage_profile, + storage_filename, + language, + ) class OllamaGenerateRequest(BaseModel): @@ -101,19 +135,23 @@ class OllamaPullRequest(BaseModel): class OcrRequest(BaseModel): strategy: str = Field(..., description="OCR strategy to use") prompt: Optional[str] = Field(None, description="Prompt for the Ollama model") - model: Optional[str] = Field(None, description="Model to use for the Ollama endpoint") + model: Optional[str] = Field( + None, description="Model to use for the Ollama endpoint" + ) file: FileField = Field(..., description="Base64 encoded document file") ocr_cache: bool = Field(..., description="Enable OCR result caching") - storage_profile: Optional[str] = Field('default', description="Storage profile to use") + storage_profile: Optional[str] = Field( + "default", description="Storage profile to use" + ) storage_filename: Optional[str] = Field(None, description="Storage filename to use") - language: Optional[str] = Field('en', description="Language to use for OCR") + language: Optional[str] = Field("en", description="Language to use for OCR") - @field_validator('strategy') + @field_validator("strategy") def validate_strategy(cls, v): Strategy.get_strategy(v) return v - @field_validator('storage_profile') + @field_validator("storage_profile") def validate_storage_profile(cls, v): if not storage_profile_exists(v): raise ValueError(f"Storage profile '{v}' does not exist.") @@ -123,18 +161,22 @@ def validate_storage_profile(cls, v): class OcrFormRequest(BaseModel): strategy: str = Field(..., description="OCR strategy to use") prompt: Optional[str] = Field(None, description="Prompt for the Ollama model") - model: Optional[str] = Field(None, description="Model to use for the Ollama endpoint") + model: Optional[str] = Field( + None, description="Model to use for the Ollama endpoint" + ) ocr_cache: bool = Field(..., description="Enable OCR result caching") - storage_profile: Optional[str] = Field('default', description="Storage profile to use") + storage_profile: Optional[str] = Field( + "default", description="Storage profile to use" + ) storage_filename: Optional[str] = Field(None, description="Storage filename to use") - language: Optional[str] = Field('en', description="Language to use for OCR") + language: Optional[str] = Field("en", description="Language to use for OCR") - @field_validator('strategy') + @field_validator("strategy") def validate_strategy(cls, v): Strategy.get_strategy(v) return v - @field_validator('storage_profile') + @field_validator("storage_profile") def validate_storage_profile(cls, v): if not storage_profile_exists(v): raise ValueError(f"Storage profile '{v}' does not exist.") @@ -156,12 +198,24 @@ async def ocr_request_endpoint(request: OcrRequest): raise HTTPException(status_code=400, detail=str(e)) print( - f"Processing {file.mime_type} with strategy: {request.strategy}, ocr_cache: {request.ocr_cache}, model: {request.model}, storage_profile: {request.storage_profile}, storage_filename: {request.storage_filename}, language: {request.language}") + f"Processing {file.mime_type} with strategy: {request.strategy}, ocr_cache: {request.ocr_cache}, model: {request.model}, storage_profile: {request.storage_profile}, storage_filename: {request.storage_filename}, language: {request.language}" + ) # Asynchronous processing using Celery task = 
ocr_task.apply_async( - args=[file.binary, request.strategy, file.filename, file.hash, request.ocr_cache, request.prompt, - request.model, request.language, request.storage_profile, request.storage_filename]) + args=[ + file.binary, + request.strategy, + file.filename, + file.hash, + request.ocr_cache, + request.prompt, + request.model, + request.language, + request.storage_profile, + request.storage_filename, + ] + ) return {"task_id": task.id} @@ -172,15 +226,23 @@ async def ocr_status(task_id: str): """ task = AsyncResult(task_id, app=celery_app) - if task.state == 'PENDING': + if task.state == "PENDING": return {"state": task.state, "status": "Task is pending..."} - elif task.state == 'PROGRESS': + elif task.state == "PROGRESS": task_info = task.info - if task_info.get('start_time'): - task_info['elapsed_time'] = time.time() - int(task_info.get('start_time')) - return {"state": task.state, "status": task.info.get("status"), "info": task_info} - elif task.state == 'SUCCESS': - return {"state": task.state, "status": "Task completed successfully.", "result": task.result} + if task_info.get("start_time"): + task_info["elapsed_time"] = time.time() - int(task_info.get("start_time")) + return { + "state": task.state, + "status": task.info.get("status"), + "info": task_info, + } + elif task.state == "SUCCESS": + return { + "state": task.state, + "status": "Task completed successfully.", + "result": task.result, + } else: return {"state": task.state, "status": str(task.info)} @@ -195,7 +257,7 @@ async def clear_ocr_cache(): @app.get("/storage/list") -async def list_files(storage_profile: str = 'default'): +async def list_files(storage_profile: str = "default"): """ Endpoint to list files using the selected storage profile. """ @@ -205,7 +267,7 @@ async def list_files(storage_profile: str = 'default'): @app.get("/storage/load") -async def load_file(file_name: str, storage_profile: str = 'default'): +async def load_file(file_name: str, storage_profile: str = "default"): """ Endpoint to load a file using the selected storage profile. """ @@ -215,7 +277,7 @@ async def load_file(file_name: str, storage_profile: str = 'default'): @app.delete("/storage/delete") -async def delete_file(file_name: str, storage_profile: str = 'default'): +async def delete_file(file_name: str, storage_profile: str = "default"): """ Endpoint to delete a file using the selected storage profile. """ @@ -233,8 +295,10 @@ async def pull_llama(request: OllamaPullRequest): try: response = ollama.pull(request.model) except ollama.ResponseError as e: - print('Error:', e.error) - raise HTTPException(status_code=500, detail="Failed to pull Llama model from Ollama API") + print("Error:", e.error) + raise HTTPException( + status_code=500, detail="Failed to pull Llama model from Ollama API" + ) return {"status": response.get("status", "Model pulled successfully")} @@ -251,12 +315,14 @@ async def generate_llama(request: OllamaGenerateRequest): try: response = ollama.generate(request.model, request.prompt) except ollama.ResponseError as e: - print('Error:', e.error) + print("Error:", e.error) if e.status_code == 404: print("Error: ", e.error) ollama.pull(request.model) - raise HTTPException(status_code=500, detail="Failed to generate text with Ollama API") + raise HTTPException( + status_code=500, detail="Failed to generate text with Ollama API" + ) generated_text = response.get("response", "") return {"generated_text": generated_text}
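One functional change inside this formatting pass is the rename of the storage base class to `BaseStorageStrategy`, which avoids the name clash with the `StorageStrategy` enum in `storage_manager.py`. A minimal sketch of a custom strategy against the renamed base follows; it is illustrative only and is not wired into `StorageManager`, which still dispatches through its enum, so a real integration would also need an entry there.

```python
# Minimal sketch of a strategy built on the renamed BaseStorageStrategy. It keeps
# content in memory, which is only useful for tests; format_file_name() is the
# helper shown above that expands {file_name}/{Y}/{mm}/... style templates.
from text_extract_api.files.storage_strategies.storage_strategy import (
    BaseStorageStrategy,
)


class InMemoryStorageStrategy(BaseStorageStrategy):
    def __init__(self, context):
        super().__init__(context)
        self._files: dict[str, str] = {}

    def save(self, file_name, dest_file_name, content):
        self._files[self.format_file_name(file_name, dest_file_name)] = content

    def load(self, file_name):
        return self._files.get(file_name)

    def list(self):
        return list(self._files.keys())

    def delete(self, file_name):
        self._files.pop(file_name, None)
```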