diff --git a/.flake8 b/.flake8 index 93f63e5..0df06ab 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ max-line-length = 88 extend-ignore = E501 exclude = .venv, frontend -ignore = E203, W503, G004, G200 \ No newline at end of file +ignore = E203, W503, G004, G200,B008,ANN,D100,D101,D102,D103,D104,D105,D106,D107 \ No newline at end of file diff --git a/src/backend/api/api_routes.py b/src/backend/api/api_routes.py index 8a3d5a8..d234b6c 100644 --- a/src/backend/api/api_routes.py +++ b/src/backend/api/api_routes.py @@ -1,4 +1,4 @@ -"""FastAPI API routes for file processing and conversion""" +"""FastAPI API routes for file processing and conversion.""" import asyncio import io @@ -6,8 +6,10 @@ from api.auth.auth_utils import get_authenticated_user from api.status_updates import app_connection_manager, close_connection + from common.logger.app_logger import AppLogger from common.services.batch_service import BatchService + from fastapi import ( APIRouter, File, @@ -24,13 +26,14 @@ logger = AppLogger("APIRoutes") # start processing the batch -from sql_agents_start import process_batch_async +from sql_agents_start import process_batch_async # noqa: E402 @router.post("/start-processing") async def start_processing(request: Request): """ - Start processing files for a given batch + Start processing files for a given batch. + --- tags: - File Processing @@ -50,6 +53,7 @@ async def start_processing(request: Request): responses: 200: description: Processing initiated successfully + content: application/json: schema: @@ -61,6 +65,7 @@ async def start_processing(request: Request): type: string 400: description: Invalid processing request + 500: description: Internal server error """ @@ -89,7 +94,7 @@ async def start_processing(request: Request): ) async def download_files(batch_id: str): """ - Download files as ZIP + Download files as ZIP. 
--- tags: @@ -118,7 +123,6 @@ async def download_files(batch_id: str): type: string example: Batch not found """ - # call batch_service get_batch_for_zip to get all files for batch_id batch_service = BatchService() await batch_service.initialize_database() @@ -172,7 +176,7 @@ async def batch_status_updates( websocket: WebSocket, batch_id: str ): # , request: Request): """ - WebSocket endpoint for real-time batch status updates + WebSocket endpoint for real-time batch status updates. --- tags: @@ -248,7 +252,7 @@ async def batch_status_updates( @router.get("/batch-story/{batch_id}") async def get_batch_status(request: Request, batch_id: str): """ - Retrieve batch history and file statuses + Retrieve batch history and file statuses. --- tags: @@ -371,9 +375,7 @@ async def get_batch_status(request: Request, batch_id: str): @router.get("/batch-summary/{batch_id}") async def get_batch_summary(request: Request, batch_id: str): - """ - Retrieve batch summary for a given batch ID. - """ + """Retrieve batch summary for a given batch ID.""" try: batch_service = BatchService() await batch_service.initialize_database() @@ -404,7 +406,7 @@ async def upload_file( request: Request, file: UploadFile = File(...), batch_id: str = Form(...) ): """ - Upload file for conversion + Upload file for conversion. --- tags: @@ -634,7 +636,7 @@ async def get_file_details(request: Request, file_id: str): @router.delete("/delete-batch/{batch_id}") async def delete_batch_details(request: Request, batch_id: str): """ - delete batch history using batch_id + Delete batch history using batch_id. --- tags: @@ -689,7 +691,7 @@ async def delete_batch_details(request: Request, batch_id: str): @router.delete("/delete-file/{file_id}") async def delete_file_details(request: Request, file_id: str): """ - delete file history using batch_id + Delete file history using file_id. 
--- tags: @@ -747,7 +749,7 @@ async def delete_file_details(request: Request, file_id: str): @router.delete("/delete_all") async def delete_all_details(request: Request): """ - delete all the history of batches, files and logs + Delete all the history of batches, files and logs. --- tags: diff --git a/src/backend/api/auth/auth_utils.py b/src/backend/api/auth/auth_utils.py index c186b2c..da6a6b2 100644 --- a/src/backend/api/auth/auth_utils.py +++ b/src/backend/api/auth/auth_utils.py @@ -1,10 +1,12 @@ -from fastapi import Request, HTTPException -import logging import base64 import json +import logging from typing import Dict + from api.auth.sample_user import sample_user +from fastapi import HTTPException, Request + logger = logging.getLogger(__name__) @@ -26,19 +28,19 @@ def __init__(self, user_details: Dict): def get_tenant_id(client_principal_b64: str) -> str: - """Extract tenant ID from base64 encoded client principal""" + """Extract tenant ID from base64 encoded client principal.""" try: decoded_bytes = base64.b64decode(client_principal_b64) decoded_string = decoded_bytes.decode("utf-8") user_info = json.loads(decoded_string) return user_info.get("tid", "") - except Exception as ex: + except Exception: logger.exception("Error decoding client principal") return "" def get_authenticated_user(request: Request) -> UserDetails: - """Get authenticated user details from request headers""" + """Get authenticated user details from request headers.""" user_object = {} headers = dict(request.headers) # Check if we're in production with real headers diff --git a/src/backend/api/auth/sample_user.py b/src/backend/api/auth/sample_user.py index e15ef56..64bb2be 100644 --- a/src/backend/api/auth/sample_user.py +++ b/src/backend/api/auth/sample_user.py @@ -5,4 +5,4 @@ "x-ms-client-principal-idp": "aad", "x-ms-token-aad-id-token": "dev.token", "x-ms-client-principal": "your_base_64_encoded_token" -} \ No newline at end of file +} diff --git a/src/backend/api/status_updates.py 
b/src/backend/api/status_updates.py index 67f932b..7bf9f09 100644 --- a/src/backend/api/status_updates.py +++ b/src/backend/api/status_updates.py @@ -1,5 +1,6 @@ """ Holds collection of websocket connections. + from clients registering for status updates. These socket references are used to send updates to registered clients from the backend processing code. @@ -11,6 +12,7 @@ from typing import Dict from common.models.api import FileProcessUpdate, FileProcessUpdateJSONEncoder + from fastapi import WebSocket logger = logging.getLogger(__name__) diff --git a/src/backend/app.py b/src/backend/app.py index b7b2173..95d0830 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -1,12 +1,14 @@ -import uvicorn - -# Import our route modules +"""Create and configure the FastAPI application.""" from api.api_routes import router as backend_router + from common.logger.app_logger import AppLogger + from dotenv import load_dotenv + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +import uvicorn # from agent_services.agents_routes import router as agents_router # Load environment variables @@ -17,9 +19,7 @@ def create_app() -> FastAPI: - """ - Factory function to create and configure the FastAPI application - """ + """Create and return the FastAPI application instance.""" app = FastAPI(title="Code Gen Accelerator", version="1.0.0") # Configure CORS @@ -37,7 +37,7 @@ def create_app() -> FastAPI: @app.get("/health") async def health_check(): - """Health check endpoint""" + """Health check endpoint.""" return {"status": "healthy"} return app diff --git a/src/backend/common/config/config.py b/src/backend/common/config/config.py index 9d5d1ad..3b774d6 100644 --- a/src/backend/common/config/config.py +++ b/src/backend/common/config/config.py @@ -1,6 +1,7 @@ import os from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential + from dotenv import load_dotenv load_dotenv() diff --git a/src/backend/common/database/cosmosdb.py 
b/src/backend/common/database/cosmosdb.py index 8444a81..a9e17e2 100644 --- a/src/backend/common/database/cosmosdb.py +++ b/src/backend/common/database/cosmosdb.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from enum import Enum from typing import Dict, List, Optional from uuid import UUID, uuid4 @@ -7,9 +6,9 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import ( - CosmosResourceExistsError, - CosmosResourceNotFoundError, + CosmosResourceExistsError ) + from common.database.database_base import DatabaseBase from common.logger.app_logger import AppLogger from common.models.api import ( @@ -20,6 +19,7 @@ LogType, ProcessStatus, ) + from semantic_kernel.contents import AuthorRole @@ -208,7 +208,7 @@ async def get_batch_files(self, batch_id: str) -> List[Dict]: raise async def get_batch_from_id(self, batch_id: str) -> Dict: - """Retrieve a batch from the database using the batch ID""" + """Retrieve a batch from the database using the batch ID.""" try: query = "SELECT * FROM c WHERE c.batch_id = @batch_id" params = [{"name": "@batch_id", "value": batch_id}] @@ -225,7 +225,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict: raise async def get_user_batches(self, user_id: str) -> Dict: - """Retrieve all batches for a given user""" + """Retrieve all batches for a given user.""" try: query = "SELECT * FROM c WHERE c.user_id = @user_id" params = [{"name": "@user_id", "value": user_id}] @@ -242,7 +242,7 @@ async def get_user_batches(self, user_id: str) -> Dict: raise async def get_file_logs(self, file_id: str) -> List[Dict]: - """Retrieve all logs for a given file""" + """Retrieve all logs for a given file.""" try: query = ( "SELECT * FROM c WHERE c.file_id = @file_id ORDER BY c.timestamp DESC" @@ -322,7 +322,7 @@ async def add_file_log( agent_type: AgentType, author_role: AuthorRole, ) -> None: - """Log a file status update""" + """Log a file status update.""" try: 
log_id = uuid4() log_entry = FileLog( @@ -343,7 +343,7 @@ async def add_file_log( async def update_batch_entry( self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int ): - """Update batch status""" + """Update batch status.""" try: batch = await self.get_batch(user_id, batch_id) if not batch: diff --git a/src/backend/common/database/database_base.py b/src/backend/common/database/database_base.py index a54f3c3..961426b 100644 --- a/src/backend/common/database/database_base.py +++ b/src/backend/common/database/database_base.py @@ -1,67 +1,65 @@ import uuid from abc import ABC, abstractmethod -from datetime import datetime -from enum import Enum from typing import Dict, List, Optional -from common.logger.app_logger import AppLogger -from common.models.api import AgentType, BatchRecord, FileRecord, LogType, ProcessStatus +from common.models.api import AgentType, BatchRecord, FileRecord, LogType + from semantic_kernel.contents import AuthorRole class DatabaseBase(ABC): - """Abstract base class for database operations""" + """Abstract base class for database operations.""" @abstractmethod async def initialize_cosmos(self) -> None: - """Initialize the cosmosdb client and create container if needed""" + """Initialize the cosmosdb client and create container if needed.""" pass @abstractmethod async def create_batch(self, user_id: str, batch_id: uuid.UUID) -> BatchRecord: - """Create a new conversion batch""" + """Create a new conversion batch.""" pass @abstractmethod async def get_file_logs(self, file_id: str) -> Dict: - """Retrieve all logs for a file""" + """Retrieve all logs for a file.""" pass @abstractmethod async def get_batch_from_id(self, batch_id: str) -> Dict: - """Retrieve all logs for a file""" + """Retrieve all logs for a file.""" pass @abstractmethod async def get_batch_files(self, batch_id: str) -> List[Dict]: - """Retrieve all files for a batch""" + """Retrieve all files for a batch.""" pass @abstractmethod async def delete_file_logs(self, 
file_id: str) -> None: - """Delete all logs for a file""" + """Delete all logs for a file.""" pass @abstractmethod async def get_user_batches(self, user_id: str) -> Dict: - """Retrieve all batches for a user""" + """Retrieve all batches for a user.""" pass @abstractmethod async def add_file( self, batch_id: uuid.UUID, file_id: uuid.UUID, file_name: str, storage_path: str ) -> FileRecord: - """Add a file entry to the database""" + """Add a file entry to the database.""" pass @abstractmethod async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]: - """Retrieve a batch and its associated files""" + """Retrieve a batch and its associated files.""" pass @abstractmethod async def get_file(self, file_id: str) -> Optional[Dict]: - """Retrieve a file entry along with its logs""" + """Retrieve a file entry along with its logs.""" pass @abstractmethod @@ -74,12 +72,12 @@ async def add_file_log( agent_type: AgentType, author_role: AuthorRole, ) -> None: - """Log a file status update""" + """Log a file status update.""" pass @abstractmethod async def update_file(self, file_record: FileRecord) -> None: - """update file record""" + """Update file record.""" pass @abstractmethod @@ -88,25 +86,25 @@ async def update_batch(self, batch_record: BatchRecord) -> BatchRecord: @abstractmethod async def delete_all(self, user_id: str) -> None: - """Delete all batches, files, and logs for a user""" + """Delete all batches, files, and logs for a user.""" pass @abstractmethod async def delete_batch(self, user_id: str, batch_id: str) -> None: - """Delete a batch along with its files and logs""" + """Delete a batch along with its files and logs.""" pass @abstractmethod async def delete_file(self, user_id: str, batch_id: str, file_id: str) -> None: - """Delete a file and its logs, and update batch file count""" + """Delete a file and its logs, and update batch file count.""" pass @abstractmethod async def get_batch_history(self, user_id: str, batch_id: str) -> List[Dict]: - 
"""Retrieve all logs for a batch""" + """Retrieve all logs for a batch.""" pass @abstractmethod async def close(self) -> None: - """Close database connection""" + """Close database connection.""" pass diff --git a/src/backend/common/database/database_factory.py b/src/backend/common/database/database_factory.py index 1306a52..ee92677 100644 --- a/src/backend/common/database/database_factory.py +++ b/src/backend/common/database/database_factory.py @@ -1,6 +1,5 @@ from typing import Optional -from azure.cosmos.aio import CosmosClient from common.config.config import Config from common.database.cosmosdb import CosmosDBClient from common.database.database_base import DatabaseBase diff --git a/src/backend/common/logger/app_logger.py b/src/backend/common/logger/app_logger.py index 5642ea7..b9aed46 100644 --- a/src/backend/common/logger/app_logger.py +++ b/src/backend/common/logger/app_logger.py @@ -1,7 +1,6 @@ +import json import logging -from datetime import datetime from typing import Any -import json class LogLevel: diff --git a/src/backend/common/models/api.py b/src/backend/common/models/api.py index 15c9525..7bf280a 100644 --- a/src/backend/common/models/api.py +++ b/src/backend/common/models/api.py @@ -1,9 +1,9 @@ from __future__ import annotations import json +import logging from datetime import datetime from enum import Enum -import logging from typing import Dict, List from uuid import UUID @@ -125,7 +125,7 @@ def __init__( @staticmethod def fromdb(data: Dict) -> FileLog: - """Convert str to UUID after fetching from the database""" + """Convert str to UUID after fetching from the database.""" return FileLog( log_id=UUID(data["log_id"]), # Convert str → UUID file_id=UUID(data["file_id"]), # Convert str → UUID @@ -142,7 +142,7 @@ def fromdb(data: Dict) -> FileLog: ) def dict(self) -> Dict: - """Convert UUID to str before inserting into the database""" + """Convert UUID to str before inserting into the database.""" return { "id": str(self.log_id), # Convert UUID → 
str "log_id": str(self.log_id), # Convert UUID → str @@ -185,7 +185,7 @@ def __init__( @staticmethod def fromdb(data: Dict) -> FileRecord: - """Convert str to UUID after fetching from the database""" + """Convert str to UUID after fetching from the database.""" return FileRecord( file_id=UUID(data["file_id"]), # Convert str → UUID batch_id=UUID(data["batch_id"]), # Convert str → UUID @@ -203,7 +203,7 @@ def fromdb(data: Dict) -> FileRecord: ) def dict(self) -> Dict: - """Convert UUID to str before inserting into the database""" + """Convert UUID to str before inserting into the database.""" return { "id": str(self.file_id), "file_id": str(self.file_id), # Convert UUID → str @@ -221,7 +221,7 @@ def dict(self) -> Dict: class FileProcessUpdate: - "websocket payload for file process updates" + """websocket payload for file process updates.""" def __init__( self, @@ -259,9 +259,7 @@ def dict(self) -> Dict: class FileProcessUpdateJSONEncoder(json.JSONEncoder): - """ - Custom JSON encoder for serializing FileProcessUpdate, ProcessStatus, and FileResult objects. 
- """ + """Custom JSON encoder for serializing FileProcessUpdate, ProcessStatus, and FileResult objects.""" def default(self, obj): # Check if the object is an instance of FileProcessUpdate, ProcessStatus, or FileResult @@ -294,7 +292,7 @@ def __init__( self.status = status def dict(self) -> Dict: - """Convert UUID to str before inserting into the database""" + """Convert UUID to str before inserting into the database.""" return { "batch_id": str(self.batch_id), # Convert UUID → str for DB "user_id": self.user_id, @@ -355,7 +353,7 @@ def fromdb(data: Dict) -> BatchRecord: ) def dict(self) -> Dict: - """Convert UUID to str before inserting into the database""" + """Convert UUID to str before inserting into the database.""" return { "id": str(self.batch_id), "batch_id": str(self.batch_id), # Convert UUID → str for DB diff --git a/src/backend/common/services/batch_service.py b/src/backend/common/services/batch_service.py index bbfecc1..0d5a609 100644 --- a/src/backend/common/services/batch_service.py +++ b/src/backend/common/services/batch_service.py @@ -14,7 +14,9 @@ ProcessStatus, ) from common.storage.blob_factory import BlobStorageFactory + from fastapi import HTTPException, UploadFile + from semantic_kernel.contents import AuthorRole @@ -29,7 +31,7 @@ async def initialize_database(self): self.database = await DatabaseFactory.get_database() async def get_batch(self, batch_id: UUID, user_id: str) -> Optional[Dict]: - """Retrieve batch details including files""" + """Retrieve batch details including files.""" batch = await self.database.get_batch(user_id, batch_id) if not batch: return None @@ -38,7 +40,7 @@ async def get_batch(self, batch_id: UUID, user_id: str) -> Optional[Dict]: return {"batch": batch, "files": files} async def get_file(self, file_id: str) -> Optional[Dict]: - """Retrieve file details""" + """Retrieve file details.""" file = await self.database.get_file(file_id) if not file: return None @@ -46,7 +48,7 @@ async def get_file(self, file_id: str) -> 
Optional[Dict]: return {"file": file} async def get_file_report(self, file_id: str) -> Optional[Dict]: - """Retrieve file logs""" + """Retrieve file logs.""" file = await self.database.get_file(file_id) file_record = FileRecord.fromdb(file) batch = await self.database.get_batch_from_id(str(file_record.batch_id)) @@ -59,7 +61,7 @@ async def get_file_report(self, file_id: str) -> Optional[Dict]: storage = await BlobStorageFactory.get_storage() if file_record.translated_path not in ["", None]: translated_content = await storage.get_file(file_record.translated_path) - except (FileNotFoundError, IOError) as e: + except IOError as e: self.logger.error(f"Error downloading file content: {str(e)}") return { @@ -71,20 +73,19 @@ async def get_file_report(self, file_id: str) -> Optional[Dict]: } async def get_file_translated(self, file: dict): - """Retrieve file logs""" - + """Retrieve file logs.""" translated_content = "" try: storage = await BlobStorageFactory.get_storage() if file["translated_path"] not in ["", None]: translated_content = await storage.get_file(file["translated_path"]) - except (FileNotFoundError, IOError) as e: + except IOError as e: self.logger.error(f"Error downloading file content: {str(e)}") return translated_content async def get_batch_for_zip(self, batch_id: str) -> List[Tuple[str, str]]: - """Retrieve batch details including files in a single zip archive""" + """Retrieve batch details including files in a single zip archive.""" files = [] try: files_meta = await self.database.get_batch_files(batch_id) @@ -108,7 +109,7 @@ async def get_batch_for_zip(self, batch_id: str) -> List[Tuple[str, str]]: raise # Re-raise for caller handling async def get_batch_summary(self, batch_id: str, user_id: str) -> Optional[Dict]: - """Retrieve file logs""" + """Retrieve file logs.""" try: try: batch = await self.database.get_batch(user_id, batch_id) @@ -148,7 +149,7 @@ async def get_batch_summary(self, batch_id: str, user_id: str) -> Optional[Dict] raise # Re-raise 
for caller handling async def delete_batch(self, batch_id: UUID, user_id: str): - """Delete a batch along with its files and logs""" + """Delete a batch along with its files and logs.""" batch = await self.database.get_batch(user_id, batch_id) if batch: await self.database.delete_batch(user_id, batch_id) @@ -157,7 +158,7 @@ async def delete_batch(self, batch_id: UUID, user_id: str): return {"message": "Batch deleted successfully", "batch_id": str(batch_id)} async def delete_file(self, file_id: UUID, user_id: str): - """Delete a file and its logs, and update batch file count""" + """Delete a file and its logs, and update batch file count.""" try: # Ensure storage is available storage = await BlobStorageFactory.get_storage() @@ -208,11 +209,11 @@ async def delete_file(self, file_id: UUID, user_id: str): raise RuntimeError("File deletion failed") from e async def delete_all(self, user_id: str): - """Delete all batches, files, and logs for a user""" + """Delete all batches, files, and logs for a user.""" return await self.database.delete_all(user_id) async def get_all_batches(self, user_id: str): - """Retrieve all batches for a user""" + """Retrieve all batches for a user.""" return await self.database.get_user_batches(user_id) def is_valid_uuid(self, value: str) -> bool: @@ -235,7 +236,7 @@ def generate_file_path( return file_path async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFile): - """Upload a file, create entries in the database, and log the process""" + """Upload a file, create entries in the database, and log the process.""" try: # Ensure storage is available storage = await BlobStorageFactory.get_storage() @@ -362,7 +363,7 @@ async def update_file( error_count: int, syntax_count: int, ): - """Update file entry in the database""" + """Update file entry in the database.""" file = await self.database.get_file(file_id) if not file: raise HTTPException(status_code=404, detail="File not found") @@ -376,7 +377,7 @@ async def 
update_file( return file_record async def update_file_record(self, file_record: FileRecord): - """Update file entry in the database""" + """Update file entry in the database.""" await self.database.update_file(file_record) async def create_file_log( @@ -388,7 +389,7 @@ async def create_file_log( agent_type: AgentType, author_role: AuthorRole, ): - """Create a new file log entry in the database""" + """Create a new file log entry in the database.""" await self.database.add_file_log( UUID(file_id), description, @@ -399,7 +400,7 @@ async def create_file_log( ) async def update_batch(self, batch_id: str, status: ProcessStatus): - """Update batch status to completed""" + """Update batch status to completed.""" batch = await self.database.get_batch_from_id(batch_id) if not batch: raise HTTPException(status_code=404, detail="Batch not found") @@ -409,7 +410,7 @@ async def update_batch(self, batch_id: str, status: ProcessStatus): await self.database.update_batch(batch_record) async def create_candidate(self, file_id: str, candidate: str): - """Create a new candidate entry in the database and upload the candita file to storage""" + """Create a new candidate entry in the database and upload the candidate file to storage.""" # Ensure storage is available storage = await BlobStorageFactory.get_storage() if not storage: @@ -462,7 +463,7 @@ async def batch_files_final_update(self, batch_id: str): # file didn't completed successfully file_record.status = ProcessStatus.COMPLETED - if(file_record.translated_path == None or file_record.translated_path == ""): + if (file_record.translated_path is None or file_record.translated_path == ""): file_record.file_result = FileResult.ERROR error_count, syntax_count = await self.get_file_counts( @@ -519,11 +520,11 @@ async def get_file_counts(self, file_id: str): return error_count, syntax_count async def get_batch_from_id(self, batch_id: str): - """Retrieve a batch record from the database""" + """Retrieve a batch record from the database.""" 
return await self.database.get_batch_from_id(batch_id) async def delete_all_from_storage_cosmos(self, user_id: str): - """Delete a all files from storage, remove its database entry, logs""" + """Delete all files from storage, remove their database entries and logs.""" try: # Ensure storage is available storage = await BlobStorageFactory.get_storage() diff --git a/src/backend/common/storage/blob_azure.py b/src/backend/common/storage/blob_azure.py index 839c07c..097cfd7 100644 --- a/src/backend/common/storage/blob_azure.py +++ b/src/backend/common/storage/blob_azure.py @@ -1,9 +1,8 @@ from typing import Any, BinaryIO, Dict, Optional -from azure.core.exceptions import ResourceExistsError from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient -from common.config.config import Config + from common.logger.app_logger import AppLogger from common.storage.blob_base import BlobStorageBase @@ -42,7+41,7 @@ async def upload_file( content_type: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: - """Upload a file to Azure Blob Storage""" + """Upload a file to Azure Blob Storage.""" try: blob_client = self.container_client.get_blob_client(blob_path) @@ -51,7 +50,7 @@ async def upload_file( raise try: # Upload the file - upload_results = blob_client.upload_blob( + upload_results = blob_client.upload_blob( # noqa: F841 file_content, content_type=content_type, metadata=metadata, @@ -78,7 +77,7 @@ async def upload_file( raise async def get_file(self, blob_path: str) -> BinaryIO: - """Download a file from Azure Blob Storage""" + """Download a file from Azure Blob Storage.""" try: blob_client = self.container_client.get_blob_client(blob_path) download_stream = blob_client.download_blob() @@ -95,7 +94,7 @@ async def get_file(self, blob_path: str) -> BinaryIO: raise async def delete_file(self, blob_path: str) -> bool: - """Delete a file from Azure Blob Storage""" + """Delete a file from Azure Blob Storage.""" 
try: blob_client = self.container_client.get_blob_client(blob_path) blob_client.delete_blob() @@ -108,7 +107,7 @@ async def delete_file(self, blob_path: str) -> bool: return False async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]]: - """List files in Azure Blob Storage""" + """List files in Azure Blob Storage.""" try: blobs = [] async for blob in self.container_client.list_blobs(name_starts_with=prefix): @@ -128,7 +127,7 @@ async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]] raise async def close(self) -> None: - """Close blob storage connections""" + """Close blob storage connections.""" if self.service_client: self.service_client.close() self.logger.info("Closed blob storage connection") diff --git a/src/backend/common/storage/blob_base.py b/src/backend/common/storage/blob_base.py index af7b0c9..4495584 100644 --- a/src/backend/common/storage/blob_base.py +++ b/src/backend/common/storage/blob_base.py @@ -1,27 +1,27 @@ from abc import ABC, abstractmethod -from typing import BinaryIO, Optional, Dict, Any +from typing import Any, BinaryIO, Dict, Optional -class BlobStorageBase(ABC): - """Abstract base class for blob storage operations""" +class BlobStorageBase(ABC): + """Abstract base class for blob storage operations.""" @abstractmethod async def upload_file( - self, + self, file_content: BinaryIO, blob_path: str, content_type: Optional[str] = None, metadata: Optional[Dict[str, str]] = None ) -> Dict[str, Any]: """ - Upload a file to blob storage - + Upload a file to blob storage. + Args: file_content: The file content to upload blob_path: The path where to store the blob content_type: Optional content type of the file metadata: Optional metadata to store with the blob - + Returns: Dict containing upload details (url, size, etc.) 
""" @@ -30,11 +30,11 @@ async def upload_file( @abstractmethod async def get_file(self, blob_path: str) -> BinaryIO: """ - Retrieve a file from blob storage - + Retrieve a file from blob storage. + Args: blob_path: Path to the blob - + Returns: File content as a binary stream """ @@ -43,11 +43,11 @@ async def get_file(self, blob_path: str) -> BinaryIO: @abstractmethod async def delete_file(self, blob_path: str) -> bool: """ - Delete a file from blob storage - + Delete a file from blob storage. + Args: blob_path: Path to the blob to delete - + Returns: True if deletion was successful """ @@ -56,12 +56,12 @@ async def delete_file(self, blob_path: str) -> bool: @abstractmethod async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]]: """ - List files in blob storage - + List files in blob storage. + Args: prefix: Optional prefix to filter blobs - + Returns: List of blob details """ - pass \ No newline at end of file + pass diff --git a/src/backend/common/storage/blob_factory.py b/src/backend/common/storage/blob_factory.py index 9e47fd8..fc85563 100644 --- a/src/backend/common/storage/blob_factory.py +++ b/src/backend/common/storage/blob_factory.py @@ -38,7 +38,6 @@ async def close_storage() -> None: async def main(): storage = await BlobStorageFactory.get_storage() # Use the storage instance... 
- files = await storage.list_files() blob = await storage.get_file("q1_informix.sql") print(blob) await BlobStorageFactory.close_storage() diff --git a/src/backend/sql_agents/__init__.py b/src/backend/sql_agents/__init__.py index 0648062..4251f94 100644 --- a/src/backend/sql_agents/__init__.py +++ b/src/backend/sql_agents/__init__.py @@ -1,6 +1,7 @@ -"""This module initializes the agents and helpers for the""" +"""This module initializes the agents and helpers for the.""" from common.models.api import AgentType + from sql_agents.fixer.agent import setup_fixer_agent from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion from sql_agents.helpers.utils import get_prompt @@ -10,7 +11,7 @@ from sql_agents.syntax_checker.agent import setup_syntax_checker_agent # Import the configuration function -from .agent_config import AgentsConfigDialect, create_config +from .agent_config import create_config __all__ = [ "setup_migrator_agent", diff --git a/src/backend/sql_agents/agent_config.py b/src/backend/sql_agents/agent_config.py index d815235..8f46372 100644 --- a/src/backend/sql_agents/agent_config.py +++ b/src/backend/sql_agents/agent_config.py @@ -1,6 +1,5 @@ """Configuration for the agents module.""" -import json import os from enum import Enum diff --git a/src/backend/sql_agents/fixer/agent.py b/src/backend/sql_agents/fixer/agent.py index 2ace3bc..033b51c 100644 --- a/src/backend/sql_agents/fixer/agent.py +++ b/src/backend/sql_agents/fixer/agent.py @@ -3,18 +3,18 @@ import logging from common.models.api import AgentType -from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion -from sql_agents.helpers.utils import get_prompt + from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.kernel import KernelArguments -from semantic_kernel.prompt_template import PromptTemplateConfig + from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect from sql_agents.fixer.response import FixerResponse +from 
sql_agents.helpers.sk_utils import create_kernel_with_chat_completion +from sql_agents.helpers.utils import get_prompt logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -22,7 +22,7 @@ def setup_fixer_agent( name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment ) -> ChatCompletionAgent: - """Setup the fixer agent.""" + """Set up the fixer agent.""" _deployment_name = deployment_name.value _name = name.value kernel = create_kernel_with_chat_completion(_name, _deployment_name) diff --git a/src/backend/sql_agents/fixer/response.py b/src/backend/sql_agents/fixer/response.py index 39bf521..4cb7b2a 100644 --- a/src/backend/sql_agents/fixer/response.py +++ b/src/backend/sql_agents/fixer/response.py @@ -2,9 +2,7 @@ class FixerResponse(BaseModel): - """ - Model for the response of the fixer - """ + """Model for the response of the fixer.""" thought: str fixed_query: str diff --git a/src/backend/sql_agents/helpers/selection_function.py b/src/backend/sql_agents/helpers/selection_function.py index 4e3c045..35481a4 100644 --- a/src/backend/sql_agents/helpers/selection_function.py +++ b/src/backend/sql_agents/helpers/selection_function.py @@ -1,4 +1,4 @@ -"""selection_function.py""" +"""selection_function.py.""" from semantic_kernel.functions import KernelFunctionFromPrompt @@ -6,7 +6,7 @@ def setup_selection_function( name, migrator_name, picker_name, syntax_checker_name, fixer_name ): - """Setup the selection function.""" + """Set up the selection function.""" selection_function = KernelFunctionFromPrompt( function_name=name, prompt=f""" @@ -19,12 +19,12 @@ def setup_selection_function( - {picker_name.value} - {syntax_checker_name.value} - {fixer_name.value} - + Follow these instructions to determine the next participant: 1. After user input, it is always {migrator_name.value}'s turn. 2. After {migrator_name.value}, it is always {picker_name.value}'s turn. 3. 
After {picker_name.value}, it is always {syntax_checker_name.value}'s turn. - + The next two steps are repeated until the migration is complete: 4. After {syntax_checker_name.value}, it is {fixer_name.value}'s turn. 5. After {fixer_name.value}, it is {syntax_checker_name.value}'s turn. diff --git a/src/backend/sql_agents/helpers/termination_function.py b/src/backend/sql_agents/helpers/termination_function.py index 443fd2d..5b97ae2 100644 --- a/src/backend/sql_agents/helpers/termination_function.py +++ b/src/backend/sql_agents/helpers/termination_function.py @@ -1,19 +1,19 @@ -""" Helper function to set up the termination function for the semantic kernel. """ +"""Helper function to set up the termination function for the semantic kernel.""" from semantic_kernel.functions import KernelFunctionFromPrompt def setup_termination_function(name, termination_keyword): - """Setup the termination function for the semantic kernel.""" + """Set up the termination function for the semantic kernel.""" termination_function = KernelFunctionFromPrompt( function_name=name, prompt=f""" Examine the response and determine whether the query migration is complete. If so, respond with a single word without explanation: {termination_keyword}. - + INPUT: - Your input will be a JSON structure that contains a "syntax_errors" key. - + RULES: - If "syntax_errors" is an empty list, migration is complete. - If "syntax_errors" is not empty, migration is not complete. 
diff --git a/src/backend/sql_agents/helpers/utils.py b/src/backend/sql_agents/helpers/utils.py index 28e1a74..d2000ab 100644 --- a/src/backend/sql_agents/helpers/utils.py +++ b/src/backend/sql_agents/helpers/utils.py @@ -14,7 +14,7 @@ def get_prompt(agent_type: str) -> str: def is_text(content): - """Check if the content is text and not empty""" + """Check if the content is text and not empty.""" if isinstance(content, str): if len(content) == 0: return False diff --git a/src/backend/sql_agents/migrator/agent.py b/src/backend/sql_agents/migrator/agent.py index b881006..390dfa1 100644 --- a/src/backend/sql_agents/migrator/agent.py +++ b/src/backend/sql_agents/migrator/agent.py @@ -3,11 +3,13 @@ import logging from common.models.api import AgentType -from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion -from sql_agents.helpers.utils import get_prompt + from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.functions import KernelArguments + from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect +from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion +from sql_agents.helpers.utils import get_prompt from sql_agents.migrator.response import MigratorResponse logger = logging.getLogger(__name__) @@ -17,7 +19,7 @@ def setup_migrator_agent( name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment ) -> ChatCompletionAgent: - """Setup the migrator agent.""" + """Set up the migrator agent.""" _deployment_name = deployment_name.value _name = name.value NUM_CANDIDATES = 3 diff --git a/src/backend/sql_agents/migrator/response.py b/src/backend/sql_agents/migrator/response.py index da8124d..fa74f82 100644 --- a/src/backend/sql_agents/migrator/response.py +++ b/src/backend/sql_agents/migrator/response.py @@ -2,21 +2,17 @@ class MigratorCandidate(BaseModel): - """ - Model for a single candidate for migration - """ + """Model for a single candidate for migration.""" plan: 
str candidate_query: str class MigratorResponse(BaseModel): - """ - Model for the response of the migrator - """ + """Model for the response of the migrator.""" input_summary: str candidates: list[MigratorCandidate] input_error: str | None = None summary: str | None = None - rai_error: str | None = None \ No newline at end of file + rai_error: str | None = None diff --git a/src/backend/sql_agents/picker/agent.py b/src/backend/sql_agents/picker/agent.py index c724c13..867c790 100644 --- a/src/backend/sql_agents/picker/agent.py +++ b/src/backend/sql_agents/picker/agent.py @@ -1,15 +1,18 @@ -"""Picker agent setup.""" +"""Set up the Picker agent.""" import logging from common.models.api import AgentType -from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion -from sql_agents.helpers.utils import get_prompt + from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.kernel import KernelArguments + from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect +from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion +from sql_agents.helpers.utils import get_prompt from sql_agents.picker.response import PickerResponse + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -19,7 +22,7 @@ def setup_picker_agent( name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment ) -> ChatCompletionAgent: - """Setup the picker agent.""" + """Set up the picker agent.""" _deployment_name = deployment_name.value _name = name.value kernel = create_kernel_with_chat_completion(_name, _deployment_name) diff --git a/src/backend/sql_agents/picker/response.py b/src/backend/sql_agents/picker/response.py index eaad7c8..33a3804 100644 --- a/src/backend/sql_agents/picker/response.py +++ b/src/backend/sql_agents/picker/response.py @@ -7,9 +7,7 @@ class PickerCandidateSummary(BaseModel): class PickerResponse(BaseModel): - """ - The response of the picker agent. 
- """ + """The response of the picker agent.""" source_summary: str candidate_summaries: list[PickerCandidateSummary] diff --git a/src/backend/sql_agents/semantic_verifier/agent.py b/src/backend/sql_agents/semantic_verifier/agent.py index ab60ada..c20dc3d 100644 --- a/src/backend/sql_agents/semantic_verifier/agent.py +++ b/src/backend/sql_agents/semantic_verifier/agent.py @@ -1,15 +1,18 @@ -"""This module contains the setup for the semantic verifier agent.""" +"""Set up the semantic verifier agent.""" import logging from common.models.api import AgentType -from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion -from sql_agents.helpers.utils import get_prompt + from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.kernel import KernelArguments + from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect +from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion +from sql_agents.helpers.utils import get_prompt from sql_agents.semantic_verifier.response import SemanticVerifierResponse + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -21,7 +24,7 @@ def setup_semantic_verifier_agent( source_query: str, target_query: str, ) -> ChatCompletionAgent: - """Setup the semantic verifier agent.""" + """Set up the semantic verifier agent.""" _deployment_name = deployment_name.value _name = name.value kernel = create_kernel_with_chat_completion(_name, _deployment_name) diff --git a/src/backend/sql_agents/semantic_verifier/response.py b/src/backend/sql_agents/semantic_verifier/response.py index 0c3f5dd..ab771a4 100644 --- a/src/backend/sql_agents/semantic_verifier/response.py +++ b/src/backend/sql_agents/semantic_verifier/response.py @@ -2,9 +2,7 @@ class SemanticVerifierResponse(BaseModel): - """ - Response model for the semantic verifier agent - """ + """Response model for the semantic verifier agent.""" analysis: str judgement: str diff --git 
a/src/backend/sql_agents/syntax_checker/agent.py b/src/backend/sql_agents/syntax_checker/agent.py index 0c70993..9ee89ec 100644 --- a/src/backend/sql_agents/syntax_checker/agent.py +++ b/src/backend/sql_agents/syntax_checker/agent.py @@ -1,17 +1,20 @@ -"""This module contains the syntax checker agent.""" +"""Set up the syntax checker agent.""" import logging from common.models.api import AgentType -from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion -from sql_agents.helpers.utils import get_prompt + from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.connectors.ai import FunctionChoiceBehavior from semantic_kernel.kernel import KernelArguments + from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect +from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion +from sql_agents.helpers.utils import get_prompt from sql_agents.syntax_checker.plug_ins import SyntaxCheckerPlugin from sql_agents.syntax_checker.response import SyntaxCheckerResponse + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -19,7 +22,7 @@ def setup_syntax_checker_agent( name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment ) -> ChatCompletionAgent: - """Setup the syntax checker agent.""" + """Set up the syntax checker agent.""" _deployment_name = deployment_name.value _name = name.value kernel = create_kernel_with_chat_completion(_name, _deployment_name) diff --git a/src/backend/sql_agents/syntax_checker/plug_ins.py b/src/backend/sql_agents/syntax_checker/plug_ins.py index f5f2703..ca68998 100644 --- a/src/backend/sql_agents/syntax_checker/plug_ins.py +++ b/src/backend/sql_agents/syntax_checker/plug_ins.py @@ -27,7 +27,7 @@ def check_syntax( ) -> Annotated[ str, """ - Returns a json list of errors in the format of + Returns a json list of errors in the format of. 
[ { "Line": , @@ -39,13 +39,11 @@ def check_syntax( """, ]: """Check the TSQL syntax using tsqlParser.""" - print(f"Called syntaxCheckerPlugin with: {candidate_sql}") return self._call_tsqlparser(candidate_sql) def _call_tsqlparser(self, param): - """Select the executable based on the operating system""" - + """Select the executable based on the operating system.""" print("cwd =" + os.getcwd()) print(f"Calling tsqlParser with: {param}") if platform.system() == "Windows": diff --git a/src/backend/sql_agents/syntax_checker/response.py b/src/backend/sql_agents/syntax_checker/response.py index 14fd3a4..7907018 100644 --- a/src/backend/sql_agents/syntax_checker/response.py +++ b/src/backend/sql_agents/syntax_checker/response.py @@ -10,9 +10,7 @@ class SyntaxErrorInt(BaseModel): class SyntaxCheckerResponse(BaseModel): - """ - Response model for the syntax checker agent - """ + """Response model for the syntax checker agent.""" thought: str syntax_errors: List[SyntaxErrorInt] diff --git a/src/backend/sql_agents_start.py b/src/backend/sql_agents_start.py index a9d3796..5636756 100644 --- a/src/backend/sql_agents_start.py +++ b/src/backend/sql_agents_start.py @@ -1,15 +1,10 @@ -""" -This script demonstrates how to use the backend agents to migrate a query from one SQL dialect to another. 
-""" +"""This script demonstrates how to use the backend agents to migrate a query from one SQL dialect to another.""" -import asyncio import json import logging -import os -import sys -from pathlib import Path -from api.status_updates import close_connection, send_status_update +from api.status_updates import send_status_update + from common.models.api import ( AgentType, FileProcessUpdate, @@ -20,10 +15,9 @@ ) from common.services.batch_service import BatchService from common.storage.blob_factory import BlobStorageFactory + from fastapi import HTTPException -from sql_agents.helpers.selection_function import setup_selection_function -from sql_agents.helpers.termination_function import setup_termination_function -from sql_agents.helpers.utils import is_text + from semantic_kernel.agents import AgentGroupChat from semantic_kernel.agents.strategies import ( KernelFunctionSelectionStrategy, @@ -36,6 +30,7 @@ ChatMessageContent, ) from semantic_kernel.exceptions.service_exceptions import ServiceResponseException + from sql_agents import ( create_kernel_with_chat_completion, setup_fixer_agent, @@ -46,6 +41,9 @@ ) from sql_agents.agent_config import AgentModelDeployment, create_config from sql_agents.fixer.response import FixerResponse +from sql_agents.helpers.selection_function import setup_selection_function +from sql_agents.helpers.termination_function import setup_termination_function +from sql_agents.helpers.utils import is_text from sql_agents.migrator.response import MigratorResponse from sql_agents.picker.response import PickerResponse from sql_agents.semantic_verifier.response import SemanticVerifierResponse @@ -78,8 +76,11 @@ def extract_query(content): - """Extract the query from a chat that contains the following template: - # "migrated_query": 'SELECT TOP 10 * FROM mytable'""" + """ + Extract the query from a chat that contains the following template:. 
+ + # "migrated_query": 'SELECT TOP 10 * FROM mytable' + """ if "migrated_query" in content: sub_str = content.split("migrated_query")[1] return sub_str.split(":")[1].strip().strip('"') @@ -136,7 +137,7 @@ async def configure_agents(): async def convert( source_script, file: FileRecord, batch_service: BatchService, agent_config ) -> str: - """setup agents, selection and termination.""" + """Set up agents, selection and termination.""" logger.info("Migrating query: %s\n", source_script) history_reducer = ChatHistoryTruncationReducer( @@ -432,7 +433,7 @@ async def invoke_semantic_verifier( async def process_batch_async(batch_id: str): - """Run main script with dummy Cosmos data""" + """Run main script with dummy Cosmos data.""" logger.info("Processing batch: %s", batch_id) storage = await BlobStorageFactory.get_storage() batch_service = BatchService() @@ -542,7 +543,7 @@ async def process_batch_async(batch_id: str): async def process_error( ex: Exception, file_record: FileRecord, batch_service: BatchService ): - """insert data base write to file record stating invalid file and send ws notification""" + """Insert data base write to file record stating invalid file and send ws notification.""" await batch_service.create_file_log( str(file_record.file_id), "Error processing file {}".format(ex), diff --git a/src/tests/backend/common/config/config_test.py b/src/tests/backend/common/config/config_test.py index 16f52ea..87531bb 100644 --- a/src/tests/backend/common/config/config_test.py +++ b/src/tests/backend/common/config/config_test.py @@ -22,7 +22,7 @@ class TestConfigInitialization(unittest.TestCase): clear=True, ) def test_config_initialization(self): - """Test if all attributes are correctly assigned from environment variables""" + """Test if all attributes are correctly assigned from environment variables.""" config = Config() # Ensure every attribute is accessed diff --git a/src/tests/backend/common/database/cosmosdb_test.py 
b/src/tests/backend/common/database/cosmosdb_test.py index 44521e1..7ef364a 100644 --- a/src/tests/backend/common/database/cosmosdb_test.py +++ b/src/tests/backend/common/database/cosmosdb_test.py @@ -1,19 +1,15 @@ import asyncio +import enum import uuid from datetime import datetime -import enum -import pytest + from azure.cosmos import PartitionKey, exceptions from common.database.cosmosdb import CosmosDBClient -from common.models.api import ( - BatchRecord, - FileRecord, - ProcessStatus, - FileLog, - LogType, -) from common.logger.app_logger import AppLogger +from common.models.api import ProcessStatus + +import pytest # --- Enums for Testing --- diff --git a/src/tests/backend/common/database/database_base_test.py b/src/tests/backend/common/database/database_base_test.py index 6000d86..0e9d1fe 100644 --- a/src/tests/backend/common/database/database_base_test.py +++ b/src/tests/backend/common/database/database_base_test.py @@ -1,12 +1,11 @@ -import asyncio import uuid -import pytest -from datetime import datetime from enum import Enum # Import the abstract base class and related models/enums. from common.database.database_base import DatabaseBase -from common.models.api import BatchRecord, FileRecord, ProcessStatus +from common.models.api import ProcessStatus + +import pytest DatabaseBase.__abstractmethods__ = set() @@ -63,6 +62,7 @@ def close(self): def get_dummy_status(): """ Try to use a specific ProcessStatus value (e.g. PROCESSING). + If that member is not available, just return the first member in the enum. 
""" try: diff --git a/src/tests/backend/common/database/database_factory_test.py b/src/tests/backend/common/database/database_factory_test.py index b597e56..bdf99d3 100644 --- a/src/tests/backend/common/database/database_factory_test.py +++ b/src/tests/backend/common/database/database_factory_test.py @@ -1,7 +1,8 @@ -import pytest from common.config.config import Config from common.database.database_factory import DatabaseFactory +import pytest + class DummyConfig: cosmosdb_endpoint = "dummy_endpoint" @@ -20,6 +21,7 @@ def __init__(self, endpoint, credential, database_name, batch_container, file_co self.file_container = file_container self.log_container = log_container + def dummy_config_init(self): self.cosmosdb_endpoint = DummyConfig.cosmosdb_endpoint self.cosmosdb_database = DummyConfig.cosmosdb_database @@ -29,19 +31,23 @@ def dummy_config_init(self): # Provide a dummy method for credentials. self.get_azure_credentials = lambda: "dummy_credential" + @pytest.fixture(autouse=True) def patch_config(monkeypatch): # Patch the __init__ of Config so that an instance will have the required attributes. monkeypatch.setattr(Config, "__init__", dummy_config_init) + @pytest.fixture(autouse=True) def patch_cosmosdb_client(monkeypatch): # Patch CosmosDBClient in the module under test to use our dummy client. monkeypatch.setattr("common.database.database_factory.CosmosDBClient", DummyCosmosDBClient) + def test_get_database(): """ - Test that DatabaseFactory.get_database() correctly returns an instance of the + Test that DatabaseFactory.get_database() correctly returns an instance of the. + dummy CosmosDB client with the expected configuration values. """ # When get_database() is called, it creates a new Config() instance. 
diff --git a/src/tests/backend/common/storage/blob_azure_test.py b/src/tests/backend/common/storage/blob_azure_test.py index 2abb8c8..2f74302 100644 --- a/src/tests/backend/common/storage/blob_azure_test.py +++ b/src/tests/backend/common/storage/blob_azure_test.py @@ -1,17 +1,20 @@ # blob_azure_test.py -import asyncio from datetime import datetime -import pytest from unittest.mock import AsyncMock, MagicMock, patch # Import the class under test -from common.storage.blob_azure import AzureBlobStorage from azure.core.exceptions import ResourceExistsError +from common.storage.blob_azure import AzureBlobStorage + + +import pytest + class DummyBlob: """A dummy blob item returned by list_blobs.""" + def __init__(self, name, size, creation_time, content_type, metadata): self.name = name self.size = size @@ -19,8 +22,10 @@ def __init__(self, name, size, creation_time, content_type, metadata): self.content_settings = MagicMock(content_type=content_type) self.metadata = metadata + class DummyAsyncIterator: """A dummy async iterator that yields the given items.""" + def __init__(self, items): self.items = items self.index = 0 @@ -35,18 +40,22 @@ async def __anext__(self): self.index += 1 return item + class DummyDownloadStream: """A dummy download stream whose content_as_bytes method returns a fixed byte string.""" + async def content_as_bytes(self): return b"file content" # --- Fixtures --- + @pytest.fixture def dummy_storage(): # Create an instance with dummy connection string and container name. 
return AzureBlobStorage("dummy_connection_string", "dummy_container") + @pytest.fixture def dummy_container_client(): container = MagicMock() @@ -55,12 +64,14 @@ def dummy_container_client(): container.get_blob_client = MagicMock() return container + @pytest.fixture def dummy_service_client(dummy_container_client): service = MagicMock() service.get_container_client.return_value = dummy_container_client return service + @pytest.fixture def dummy_blob_client(): blob_client = MagicMock() @@ -73,6 +84,7 @@ def dummy_blob_client(): # --- Tests for AzureBlobStorage methods --- + @pytest.mark.asyncio async def test_initialize_creates_container(dummy_storage, dummy_service_client, dummy_container_client): with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client) as mock_from_conn: @@ -83,6 +95,7 @@ async def test_initialize_creates_container(dummy_storage, dummy_service_client, dummy_service_client.get_container_client.assert_called_once_with("dummy_container") dummy_container_client.create_container.assert_awaited_once() + @pytest.mark.asyncio async def test_initialize_container_already_exists(dummy_storage, dummy_service_client, dummy_container_client): with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client): @@ -93,6 +106,7 @@ async def test_initialize_container_already_exists(dummy_storage, dummy_service_ dummy_container_client.create_container.assert_awaited_once() mock_debug.assert_called_with("Container dummy_container already exists") + @pytest.mark.asyncio async def test_initialize_failure(dummy_storage): # Simulate failure during initialization. @@ -102,6 +116,7 @@ async def test_initialize_failure(dummy_storage): await dummy_storage.initialize() mock_error.assert_called() + @pytest.mark.asyncio async def test_upload_file_success(dummy_storage, dummy_blob_client): # Patch get_blob_client to return our dummy blob client. 
@@ -127,6 +142,7 @@ async def test_upload_file_success(dummy_storage, dummy_blob_client): assert result["url"] == dummy_blob_client.url assert result["etag"] == "dummy_etag" + @pytest.mark.asyncio async def test_upload_file_error(dummy_storage, dummy_blob_client): dummy_storage.container_client = MagicMock() @@ -135,6 +151,7 @@ async def test_upload_file_error(dummy_storage, dummy_blob_client): with pytest.raises(Exception, match="Upload failed"): await dummy_storage.upload_file(b"data", "blob.txt", "text/plain", {}) + @pytest.mark.asyncio async def test_get_file_success(dummy_storage, dummy_blob_client): dummy_storage.container_client = MagicMock() @@ -146,6 +163,7 @@ async def test_get_file_success(dummy_storage, dummy_blob_client): dummy_blob_client.download_blob.assert_awaited() assert result == b"file content" + @pytest.mark.asyncio async def test_get_file_error(dummy_storage, dummy_blob_client): dummy_storage.container_client = MagicMock() @@ -154,6 +172,7 @@ async def test_get_file_error(dummy_storage, dummy_blob_client): with pytest.raises(Exception, match="Download error"): await dummy_storage.get_file("nonexistent.txt") + @pytest.mark.asyncio async def test_delete_file_success(dummy_storage, dummy_blob_client): dummy_storage.container_client = MagicMock() @@ -164,6 +183,7 @@ async def test_delete_file_success(dummy_storage, dummy_blob_client): dummy_blob_client.delete_blob.assert_awaited() assert result is True + @pytest.mark.asyncio async def test_delete_file_error(dummy_storage, dummy_blob_client): dummy_storage.container_client = MagicMock() @@ -172,6 +192,7 @@ async def test_delete_file_error(dummy_storage, dummy_blob_client): result = await dummy_storage.delete_file("blob.txt") assert result is False + @pytest.mark.asyncio async def test_list_files_success(dummy_storage): dummy_storage.container_client = MagicMock() @@ -185,17 +206,20 @@ async def test_list_files_success(dummy_storage): names = {item["name"] for item in result} assert names == 
{"file1.txt", "file2.txt"} + @pytest.mark.asyncio async def test_list_files_failure(dummy_storage): dummy_storage.container_client = MagicMock() # Define list_blobs to return an invalid object (simulate error) + async def invalid_list_blobs(*args, **kwargs): # Return a plain string (which does not implement __aiter__) return "invalid" dummy_storage.container_client.list_blobs = invalid_list_blobs - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa B017 await dummy_storage.list_files("") + @pytest.mark.asyncio async def test_close(dummy_storage): dummy_storage.service_client = MagicMock() diff --git a/src/tests/backend/common/storage/blob_base_test.py b/src/tests/backend/common/storage/blob_base_test.py index b4b0361..561007e 100644 --- a/src/tests/backend/common/storage/blob_base_test.py +++ b/src/tests/backend/common/storage/blob_base_test.py @@ -1,14 +1,13 @@ -import pytest -import asyncio -import uuid from datetime import datetime -from typing import BinaryIO, Dict, Any +from typing import Any, BinaryIO, Dict # Import the abstract base class from the production code. from common.storage.blob_base import BlobStorageBase - +import pytest # Create a dummy concrete subclass of BlobStorageBase that calls the parent's abstract methods. + + class DummyBlobStorage(BlobStorageBase): async def initialize(self) -> None: # Call the parent (which is just a pass) diff --git a/src/tests/backend/common/storage/blob_factory_test.py b/src/tests/backend/common/storage/blob_factory_test.py index e19af49..47e344f 100644 --- a/src/tests/backend/common/storage/blob_factory_test.py +++ b/src/tests/backend/common/storage/blob_factory_test.py @@ -1,10 +1,7 @@ -# blob_factory_test.py import asyncio -import json import os import sys -import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock # Adjust sys.path so that the project root is found. 
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) @@ -22,21 +19,26 @@ sys.modules["azure.monitor.events.extension"] = MagicMock() # --- Import the module under test --- -from common.storage.blob_factory import BlobStorageFactory -from common.storage.blob_base import BlobStorageBase -from common.storage.blob_azure import AzureBlobStorage +from common.storage.blob_base import BlobStorageBase # noqa: E402 +from common.storage.blob_factory import BlobStorageFactory # noqa: E402 + +import pytest # noqa: E402 # --- Dummy configuration for testing --- + + class DummyConfig: azure_blob_connection_string = "dummy_connection_string" azure_blob_container_name = "dummy_container" # --- Fixture to patch Config in our tests --- + + @pytest.fixture(autouse=True) def patch_config(monkeypatch): # Import the real Config from your project. from common.config.config import Config - + def dummy_init(self): self.azure_blob_connection_string = DummyConfig.azure_blob_connection_string self.azure_blob_container_name = DummyConfig.azure_blob_container_name @@ -81,6 +83,8 @@ async def close(self): self.initialized = False # --- Fixture to patch AzureBlobStorage --- + + @pytest.fixture(autouse=True) def patch_azure_blob_storage(monkeypatch): monkeypatch.setattr("common.storage.blob_factory.AzureBlobStorage", DummyAzureBlobStorage) @@ -88,6 +92,7 @@ def patch_azure_blob_storage(monkeypatch): # -------------------- Tests for BlobStorageFactory -------------------- + @pytest.mark.asyncio async def test_get_storage_success(): """Test that get_storage returns an initialized DummyAzureBlobStorage instance and is a singleton.""" @@ -99,13 +104,16 @@ async def test_get_storage_success(): storage2 = await BlobStorageFactory.get_storage() assert storage is storage2 + @pytest.mark.asyncio async def test_get_storage_missing_config(monkeypatch): """ Test that get_storage raises a ValueError when configuration is missing. 
+ We simulate missing connection string and container name. """ from common.config.config import Config + def dummy_init_missing(self): self.azure_blob_connection_string = "" self.azure_blob_container_name = "" @@ -113,6 +121,7 @@ def dummy_init_missing(self): with pytest.raises(ValueError, match="Azure Blob Storage configuration is missing"): await BlobStorageFactory.get_storage() + @pytest.mark.asyncio async def test_close_storage_success(): """Test that close_storage calls close() on the storage instance and resets the singleton.""" @@ -125,6 +134,7 @@ async def test_close_storage_success(): # -------------------- File Upload Tests -------------------- + @pytest.mark.asyncio async def test_upload_file_success(): """Test that upload_file successfully uploads a file and returns metadata.""" @@ -139,6 +149,7 @@ async def test_upload_file_success(): assert result["size"] == len(file_content) assert blob_path in storage.files + @pytest.mark.asyncio async def test_upload_file_error(monkeypatch): """Test that an exception during file upload is propagated.""" @@ -150,6 +161,7 @@ async def test_upload_file_error(monkeypatch): # -------------------- File Retrieval Tests -------------------- + @pytest.mark.asyncio async def test_get_file_success(): """Test that get_file retrieves the correct file content.""" @@ -161,6 +173,7 @@ async def test_get_file_success(): result = await storage.get_file(blob_path) assert result == file_content + @pytest.mark.asyncio async def test_get_file_not_found(): """Test that get_file raises FileNotFoundError when file does not exist.""" @@ -171,6 +184,7 @@ async def test_get_file_not_found(): # -------------------- File Deletion Tests -------------------- + @pytest.mark.asyncio async def test_delete_file_success(): """Test that delete_file removes an existing file.""" @@ -181,6 +195,7 @@ async def test_delete_file_success(): await storage.delete_file(blob_path) assert blob_path not in storage.files + @pytest.mark.asyncio async def 
test_delete_file_nonexistent(): """Test that deleting a non-existent file does not raise an error.""" @@ -192,6 +207,7 @@ async def test_delete_file_nonexistent(): # -------------------- File Listing Tests -------------------- + @pytest.mark.asyncio async def test_list_files_with_prefix(): """Test that list_files returns files that match the given prefix.""" @@ -205,6 +221,7 @@ async def test_list_files_with_prefix(): result = await storage.list_files("folder/") assert set(result) == {"folder/a.txt", "folder/b.txt"} + @pytest.mark.asyncio async def test_list_files_no_files(): """Test that list_files returns an empty list when no files match the prefix.""" @@ -216,6 +233,7 @@ async def test_list_files_no_files(): # -------------------- Additional Basic Tests -------------------- + @pytest.mark.asyncio async def test_dummy_azure_blob_storage_initialize(): """Test that initializing DummyAzureBlobStorage sets the initialized flag.""" @@ -224,6 +242,7 @@ async def test_dummy_azure_blob_storage_initialize(): await storage.initialize() assert storage.initialized is True + @pytest.mark.asyncio async def test_dummy_azure_blob_storage_upload_and_retrieve(): """Test that a file uploaded to DummyAzureBlobStorage can be retrieved.""" @@ -238,6 +257,7 @@ async def test_dummy_azure_blob_storage_upload_and_retrieve(): retrieved = await storage.get_file(blob_path) assert retrieved == content + @pytest.mark.asyncio async def test_dummy_azure_blob_storage_close(): """Test that close() sets initialized to False.""" @@ -248,6 +268,7 @@ async def test_dummy_azure_blob_storage_close(): # -------------------- Test for BlobStorageFactory Singleton Usage -------------------- + def test_common_usage_of_blob_factory(): """Test that manually setting the singleton in BlobStorageFactory works as expected.""" # Create a dummy storage instance. 
@@ -257,6 +278,7 @@ def test_common_usage_of_blob_factory(): storage = asyncio.run(BlobStorageFactory.get_storage()) assert storage is dummy_storage + if __name__ == "__main__": # Run tests when this file is executed directly. asyncio.run(pytest.main())