@@ -1323,6 +1327,10 @@ useEffect(() => {
// Don't allow selecting queued files
if (file.status === "ready_to_process") return;
setSelectedFileId(file.id);
+ handleClick(file.id);
+ }}
+ style={{
+ backgroundColor: selectedFilebg === file.id ? "#EBEBEB" : "var(--NeutralBackground1-Rest)",
}}
>
{isSummary ? (
diff --git a/src/tests/backend/app_test.py b/src/tests/backend/app_test.py
new file mode 100644
index 0000000..610e36c
--- /dev/null
+++ b/src/tests/backend/app_test.py
@@ -0,0 +1,33 @@
+from backend.app import create_app
+
+from fastapi import FastAPI
+
+from httpx import ASGITransport
+from httpx import AsyncClient
+
+import pytest
+
+
+@pytest.fixture
+def app() -> FastAPI:
+ """Fixture to create a test app instance."""
+ return create_app()
+
+
+@pytest.mark.asyncio
+async def test_health_check(app: FastAPI):
+ """Test the /health endpoint returns a healthy status."""
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ response = await ac.get("/health")
+ assert response.status_code == 200
+ assert response.json() == {"status": "healthy"}
+
+
+@pytest.mark.asyncio
+async def test_backend_routes_exist(app: FastAPI):
+ """Ensure /api routes are available (smoke test)."""
+    # Check that the available routes include the /api prefix from backend_router
+ routes = [route.path for route in app.router.routes]
+ backend_routes = [r for r in routes if r.startswith("/api")]
+ assert backend_routes, "No backend routes found under /api prefix"
diff --git a/src/tests/backend/common/config/config_test.py b/src/tests/backend/common/config/config_test.py
index 16f52ea..6984ae8 100644
--- a/src/tests/backend/common/config/config_test.py
+++ b/src/tests/backend/common/config/config_test.py
@@ -1,62 +1,67 @@
-import unittest
-from unittest.mock import patch
-
-# from config import Config
-from common.config.config import Config
-
-
-class TestConfigInitialization(unittest.TestCase):
- @patch.dict(
- "os.environ",
- {
- "AZURE_TENANT_ID": "test-tenant-id",
- "AZURE_CLIENT_ID": "test-client-id",
- "AZURE_CLIENT_SECRET": "test-client-secret",
- "COSMOSDB_DATABASE": "test-database",
- "COSMOSDB_BATCH_CONTAINER": "test-batch-container",
- "COSMOSDB_FILE_CONTAINER": "test-file-container",
- "COSMOSDB_LOG_CONTAINER": "test-log-container",
- "AZURE_BLOB_CONTAINER_NAME": "test-blob-container-name",
- "AZURE_BLOB_ACCOUNT_NAME": "test-blob-account-name",
- },
- clear=True,
- )
- def test_config_initialization(self):
- """Test if all attributes are correctly assigned from environment variables"""
- config = Config()
-
- # Ensure every attribute is accessed
- self.assertEqual(config.azure_tenant_id, "test-tenant-id")
- self.assertEqual(config.azure_client_id, "test-client-id")
- self.assertEqual(config.azure_client_secret, "test-client-secret")
-
- self.assertEqual(config.cosmosdb_endpoint, "test-cosmosdb-endpoint")
- self.assertEqual(config.cosmosdb_database, "test-database")
- self.assertEqual(config.cosmosdb_batch_container, "test-batch-container")
- self.assertEqual(config.cosmosdb_file_container, "test-file-container")
- self.assertEqual(config.cosmosdb_log_container, "test-log-container")
-
- self.assertEqual(config.azure_blob_container_name, "test-blob-container-name")
- self.assertEqual(config.azure_blob_account_name, "test-blob-account-name")
-
- @patch.dict(
- "os.environ",
- {
- "COSMOSDB_ENDPOINT": "test-cosmosdb-endpoint",
- "COSMOSDB_DATABASE": "test-database",
- "COSMOSDB_BATCH_CONTAINER": "test-batch-container",
- "COSMOSDB_FILE_CONTAINER": "test-file-container",
- "COSMOSDB_LOG_CONTAINER": "test-log-container",
- },
- )
- def test_cosmosdb_config_initialization(self):
- config = Config()
- self.assertEqual(config.cosmosdb_endpoint, "test-cosmosdb-endpoint")
- self.assertEqual(config.cosmosdb_database, "test-database")
- self.assertEqual(config.cosmosdb_batch_container, "test-batch-container")
- self.assertEqual(config.cosmosdb_file_container, "test-file-container")
- self.assertEqual(config.cosmosdb_log_container, "test-log-container")
-
-
-if __name__ == "__main__":
- unittest.main()
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def clear_env(monkeypatch):
+ # Clear environment variables that might affect tests.
+ keys = [
+ "AZURE_TENANT_ID",
+ "AZURE_CLIENT_ID",
+ "AZURE_CLIENT_SECRET",
+ "COSMOSDB_ENDPOINT",
+ "COSMOSDB_DATABASE",
+ "COSMOSDB_BATCH_CONTAINER",
+ "COSMOSDB_FILE_CONTAINER",
+ "COSMOSDB_LOG_CONTAINER",
+ "AZURE_BLOB_CONTAINER_NAME",
+ "AZURE_BLOB_ACCOUNT_NAME",
+ ]
+ for key in keys:
+ monkeypatch.delenv(key, raising=False)
+
+
+def test_config_initialization(monkeypatch):
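+    """Test that all attributes are correctly assigned from environment variables."""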
+ # Set the full configuration environment variables.
+ monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id")
+ monkeypatch.setenv("AZURE_CLIENT_ID", "test-client-id")
+ monkeypatch.setenv("AZURE_CLIENT_SECRET", "test-client-secret")
+ monkeypatch.setenv("COSMOSDB_ENDPOINT", "test-cosmosdb-endpoint")
+ monkeypatch.setenv("COSMOSDB_DATABASE", "test-database")
+ monkeypatch.setenv("COSMOSDB_BATCH_CONTAINER", "test-batch-container")
+ monkeypatch.setenv("COSMOSDB_FILE_CONTAINER", "test-file-container")
+ monkeypatch.setenv("COSMOSDB_LOG_CONTAINER", "test-log-container")
+ monkeypatch.setenv("AZURE_BLOB_CONTAINER_NAME", "test-blob-container-name")
+ monkeypatch.setenv("AZURE_BLOB_ACCOUNT_NAME", "test-blob-account-name")
+
+ # Local import to avoid triggering circular imports during module collection.
+ from common.config.config import Config
+ config = Config()
+
+ assert config.azure_tenant_id == "test-tenant-id"
+ assert config.azure_client_id == "test-client-id"
+ assert config.azure_client_secret == "test-client-secret"
+ assert config.cosmosdb_endpoint == "test-cosmosdb-endpoint"
+ assert config.cosmosdb_database == "test-database"
+ assert config.cosmosdb_batch_container == "test-batch-container"
+ assert config.cosmosdb_file_container == "test-file-container"
+ assert config.cosmosdb_log_container == "test-log-container"
+ assert config.azure_blob_container_name == "test-blob-container-name"
+ assert config.azure_blob_account_name == "test-blob-account-name"
+
+
+def test_cosmosdb_config_initialization(monkeypatch):
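+    """Test that Cosmos DB attributes are assigned when only those variables are set."""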
+ # Set only cosmosdb-related environment variables.
+ monkeypatch.setenv("COSMOSDB_ENDPOINT", "test-cosmosdb-endpoint")
+ monkeypatch.setenv("COSMOSDB_DATABASE", "test-database")
+ monkeypatch.setenv("COSMOSDB_BATCH_CONTAINER", "test-batch-container")
+ monkeypatch.setenv("COSMOSDB_FILE_CONTAINER", "test-file-container")
+ monkeypatch.setenv("COSMOSDB_LOG_CONTAINER", "test-log-container")
+
+ from common.config.config import Config
+ config = Config()
+
+ assert config.cosmosdb_endpoint == "test-cosmosdb-endpoint"
+ assert config.cosmosdb_database == "test-database"
+ assert config.cosmosdb_batch_container == "test-batch-container"
+ assert config.cosmosdb_file_container == "test-file-container"
+ assert config.cosmosdb_log_container == "test-log-container"
diff --git a/src/tests/backend/common/database/cosmosdb_test.py b/src/tests/backend/common/database/cosmosdb_test.py
index 44521e1..df53fde 100644
--- a/src/tests/backend/common/database/cosmosdb_test.py
+++ b/src/tests/backend/common/database/cosmosdb_test.py
@@ -1,622 +1,1117 @@
-import asyncio
-import uuid
-from datetime import datetime
-import enum
-import pytest
-from azure.cosmos import PartitionKey, exceptions
-
-from common.database.cosmosdb import CosmosDBClient
-from common.models.api import (
- BatchRecord,
- FileRecord,
- ProcessStatus,
- FileLog,
- LogType,
+import os
+import sys
+# Add backend directory to sys.path
+sys.path.insert(
+ 0,
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..", "backend")),
)
-from common.logger.app_logger import AppLogger
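+# E402 is suppressed on the imports below because they must follow the sys.path change above.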
+from datetime import datetime, timezone # noqa: E402
+from unittest import mock # noqa: E402
+from unittest.mock import AsyncMock # noqa: E402
+from uuid import uuid4 # noqa: E402
+from azure.cosmos.aio import CosmosClient # noqa: E402
+from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402
-# --- Enums for Testing ---
-class DummyProcessStatus(enum.Enum):
- READY_TO_PROCESS = "READY"
- PROCESSING = "PROCESSING"
+from common.database.cosmosdb import ( # noqa: E402
+ CosmosDBClient,
+)
+from common.models.api import ( # noqa: E402
+ AgentType,
+ AuthorRole,
+ BatchRecord,
+ FileRecord,
+ LogType,
+ ProcessStatus,
+) # noqa: E402
+import pytest # noqa: E402
-class DummyLogType(enum.Enum):
- INFO = "INFO"
- ERROR = "ERROR"
+# Mocked data for the test
+endpoint = "https://fake.cosmosdb.azure.com"
+credential = "fake_credential"
+database_name = "test_database"
+batch_container = "batch_container"
+file_container = "file_container"
+log_container = "log_container"
-@pytest.fixture(autouse=True)
-def patch_enums(monkeypatch):
- monkeypatch.setattr("common.models.api.ProcessStatus", DummyProcessStatus)
- monkeypatch.setattr("common.models.api.LogType", DummyLogType)
-
-
-# --- implementations to simulate Cosmos DB behavior ---
-async def async_query_generator(items):
- for item in items:
- yield item
-
-
-async def async_query_error_generator(*args, **kwargs):
- raise Exception("Error in query")
- if False:
- yield
-
-
-class DummyContainerClient:
- def __init__(self, container_name):
- self.container_name = container_name
- self.created_items = []
- self.deleted_items = []
- self._query_items_func = None
-
- async def create_item(self, body):
- self.created_items.append(body)
-
- async def replace_item(self, item, body):
- return body
-
- async def delete_item(self, item, partition_key=None):
- self.deleted_items.append((item, partition_key))
-
- async def delete_items(self, key):
- self.deleted_items.append(key)
-
- async def query_items(self, query, parameters):
- if self._query_items_func:
- async for item in self._query_items_func(query, parameters):
- yield item
- else:
- if False:
- yield
-
- def set_query_items(self, func):
- self._query_items_func = func
-
-
-class DummyDatabase:
- def __init__(self, database_name):
- self.database_name = database_name
- self.containers = {}
-
- async def create_container(self, id, partition_key):
- if id in self.containers:
- raise exceptions.CosmosResourceExistsError(404, "Container exists")
- container = DummyContainerClient(id)
- self.containers[id] = container
- return container
-
- def get_container_client(self, container_name):
- return self.containers.get(container_name, DummyContainerClient(container_name))
-
-
-class DummyCosmosClient:
- def __init__(self, url, credential):
- self.url = url
- self.credential = credential
- self._database = DummyDatabase("dummy_db")
- self.closed = False
-
- def get_database_client(self, database_name):
- return self._database
-
- def close(self):
- self.closed = True
-
-
-class FakeCosmosDBClient(CosmosDBClient):
- async def _async_init(
- self,
- endpoint: str,
- credential: any,
- database_name: str,
- batch_container: str,
- file_container: str,
- log_container: str,
- ):
- self.endpoint = endpoint
- self.credential = credential
- self.database_name = database_name
- self.batch_container_name = batch_container
- self.file_container_name = file_container
- self.log_container_name = log_container
- self.logger = AppLogger("CosmosDB")
- self.client = DummyCosmosClient(endpoint, credential)
- db = self.client.get_database_client(database_name)
- self.batch_container = await db.create_container(
- batch_container, PartitionKey(path="/batch_id")
- )
- self.file_container = await db.create_container(
- file_container, PartitionKey(path="/file_id")
- )
- self.log_container = await db.create_container(
- log_container, PartitionKey(path="/log_id")
- )
-
- @classmethod
- async def create(
- cls,
- endpoint,
- credential,
- database_name,
- batch_container,
- file_container,
- log_container,
- ):
- instance = cls.__new__(cls)
- await instance._async_init(
- endpoint,
- credential,
- database_name,
- batch_container,
- file_container,
- log_container,
- )
- return instance
-
- # Minimal implementations for abstract methods not under test.
- async def delete_file_logs(self, file_id: str) -> None:
- await self.log_container.delete_items(file_id)
-
- async def log_batch_status(
- self, batch_id: str, status: ProcessStatus, processed_files: int
- ) -> None:
- return
-
-
-# --- Fixture ---
@pytest.fixture
-def cosmosdb_client(event_loop):
- client = event_loop.run_until_complete(
- FakeCosmosDBClient.create(
- endpoint="dummy_endpoint",
- credential="dummy_credential",
- database_name="dummy_db",
- batch_container="batch",
- file_container="file",
- log_container="log",
- )
+def cosmos_db_client():
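+    """Fixture returning a CosmosDBClient built with fake connection settings."""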
+ return CosmosDBClient(
+ endpoint=endpoint,
+ credential=credential,
+ database_name=database_name,
+ batch_container=batch_container,
+ file_container=file_container,
+ log_container=log_container,
)
- return client
-
-
-# --- Test Cases ---
@pytest.mark.asyncio
-async def test_initialization_success(cosmosdb_client):
- assert cosmosdb_client.client is not None
- assert cosmosdb_client.batch_container is not None
- assert cosmosdb_client.file_container is not None
- assert cosmosdb_client.log_container is not None
+async def test_initialize_cosmos(cosmos_db_client, mocker):
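+    """initialize_cosmos creates and stores the batch, file, and log containers."""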
+ # Mocking CosmosClient and its methods
+ mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock())
+ mock_database = mock_client.return_value
+
+    # Container objects that the mocked create_container calls will return
+ mock_batch_container = mock.MagicMock()
+ mock_file_container = mock.MagicMock()
+ mock_log_container = mock.MagicMock()
+
+ # Use AsyncMock to mock asynchronous container creation
+ mock_database.create_container = AsyncMock(side_effect=[
+ mock_batch_container,
+ mock_file_container,
+ mock_log_container
+ ])
+
+ # Call the initialize_cosmos method
+ await cosmos_db_client.initialize_cosmos()
+
+ # Assert that the containers were created or fetched successfully
+ mock_database.create_container.assert_any_call(id=batch_container, partition_key=mock.ANY)
+ mock_database.create_container.assert_any_call(id=file_container, partition_key=mock.ANY)
+ mock_database.create_container.assert_any_call(id=log_container, partition_key=mock.ANY)
+
+ # Check the client and containers were set
+ assert cosmos_db_client.client is not None
+ assert cosmos_db_client.batch_container == mock_batch_container
+ assert cosmos_db_client.file_container == mock_file_container
+ assert cosmos_db_client.log_container == mock_log_container
@pytest.mark.asyncio
-async def test_init_error(monkeypatch):
- async def fake_async_init(*args, **kwargs):
- raise Exception("client error")
+async def test_initialize_cosmos_with_error(cosmos_db_client, mocker):
+ # Mocking CosmosClient and its methods
+ mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock())
+ mock_database = mock_client.return_value
- monkeypatch.setattr(FakeCosmosDBClient, "_async_init", fake_async_init)
+ # Simulate a general exception during container creation
+ mock_database.create_container = AsyncMock(side_effect=Exception("Failed to create container"))
+
+ # Call the initialize_cosmos method and expect it to raise an error
with pytest.raises(Exception) as exc_info:
- await FakeCosmosDBClient.create("dummy", "dummy", "dummy", "a", "b", "c")
- assert "client error" in str(exc_info.value)
+ await cosmos_db_client.initialize_cosmos()
+
+ # Assert that the exception message matches the expected message
+ assert str(exc_info.value) == "Failed to create container"
@pytest.mark.asyncio
-async def test_get_or_create_container_existing(monkeypatch, cosmosdb_client):
- db = DummyDatabase("dummy_db")
- existing = DummyContainerClient("existing")
- db.containers["existing"] = existing
+async def test_initialize_cosmos_container_exists_error(cosmos_db_client, mocker):
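+    """initialize_cosmos reuses existing containers on CosmosResourceExistsError."""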
+ # Mocking CosmosClient and its methods
+ mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock())
+ mock_database = mock_client.return_value
+
+    # Simulate CosmosResourceExistsError on every create_container call
+    mock_database.create_container = AsyncMock(side_effect=CosmosResourceExistsError)
+
+    # Existing containers returned by the get_container_client fallback
+    mock_batch_container = mock.MagicMock()
+    mock_file_container = mock.MagicMock()
+    mock_log_container = mock.MagicMock()
+    mock_database.get_container_client = mock.MagicMock(side_effect=[
+        mock_batch_container,
+        mock_file_container,
+        mock_log_container
+    ])
- async def fake_create_container(id, partition_key):
- raise exceptions.CosmosResourceExistsError(404, "Container exists")
+ # Call the initialize_cosmos method
+ await cosmos_db_client.initialize_cosmos()
- monkeypatch.setattr(db, "create_container", fake_create_container)
- monkeypatch.setattr(db, "get_container_client", lambda name: existing)
+ # Assert that the container creation method was called with the correct arguments
+ mock_database.create_container.assert_any_call(id='batch_container', partition_key=mock.ANY)
+ mock_database.create_container.assert_any_call(id='file_container', partition_key=mock.ANY)
+ mock_database.create_container.assert_any_call(id='log_container', partition_key=mock.ANY)
- # Directly call _get_or_create_container on a new instance.
- instance = FakeCosmosDBClient.__new__(FakeCosmosDBClient)
- instance.logger = AppLogger("CosmosDB")
- result = await instance._get_or_create_container(db, "existing", "/id")
- assert result is existing
+ # Check that existing containers are returned (mocked containers)
+ assert cosmos_db_client.batch_container == mock_batch_container
+ assert cosmos_db_client.file_container == mock_file_container
+ assert cosmos_db_client.log_container == mock_log_container
@pytest.mark.asyncio
-async def test_create_batch_success(monkeypatch, cosmosdb_client):
- called = False
+async def test_create_batch_new(cosmos_db_client, mocker):
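+    """create_batch inserts a new batch record and returns it."""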
+ user_id = "user_1"
+ batch_id = uuid4()
- async def fake_create_item(body):
- nonlocal called
- called = True
+ # Mock container creation
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "create_item", fake_create_item
- )
- bid = uuid.uuid4()
- batch = await cosmosdb_client.create_batch("user1", bid)
- assert batch.batch_id == bid
- assert batch.user_id == "user1"
- assert called
+ # Mock the method to return the batch
+ mock_batch_container.create_item = AsyncMock(return_value=None)
+
+ # Call the method
+ batch = await cosmos_db_client.create_batch(user_id, batch_id)
+
+ # Assert that the batch is created
+ assert batch.batch_id == batch_id
+ assert batch.user_id == user_id
+ assert batch.status == ProcessStatus.READY_TO_PROCESS
+
+ mock_batch_container.create_item.assert_called_once_with(body=batch.dict())
@pytest.mark.asyncio
-async def test_create_batch_error(monkeypatch, cosmosdb_client):
- async def fake_create_item(body):
- raise Exception("Batch creation error")
+async def test_create_batch_exists(cosmos_db_client, mocker):
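+    """create_batch returns the existing batch when the record already exists."""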
+ user_id = "user_1"
+ batch_id = uuid4()
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "create_item", fake_create_item
- )
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.create_batch("user1", uuid.uuid4())
- assert "Batch creation error" in str(exc_info.value)
+ # Mock container creation and get_batch
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
+
+ # Mock the get_batch method
+ mock_get_batch = AsyncMock(return_value=BatchRecord(
+ batch_id=batch_id,
+ user_id=user_id,
+ file_count=0,
+ created_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(timezone.utc),
+ status=ProcessStatus.READY_TO_PROCESS
+ ))
+ mocker.patch.object(cosmos_db_client, 'get_batch', mock_get_batch)
+
+ # Call the method
+ batch = await cosmos_db_client.create_batch(user_id, batch_id)
+
+    # Assert that the batch was fetched (not created) because it already exists
+ assert batch.batch_id == batch_id
+ assert batch.user_id == user_id
+ assert batch.status == ProcessStatus.READY_TO_PROCESS
+
+ mock_get_batch.assert_called_once_with(user_id, str(batch_id))
@pytest.mark.asyncio
-async def test_add_file_success(monkeypatch, cosmosdb_client):
- called = False
+async def test_create_batch_exception(cosmos_db_client, mocker):
+ user_id = "user_1"
+ batch_id = uuid4()
- async def fake_create_item(body):
- nonlocal called
- called = True
+ # Mock the batch_container and make create_item raise a general Exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.create_item = AsyncMock(side_effect=Exception("Unexpected Error"))
- monkeypatch.setattr(cosmosdb_client.file_container, "create_item", fake_create_item)
- bid = uuid.uuid4()
- fid = uuid.uuid4()
- fs = await cosmosdb_client.add_file(bid, fid, "test.txt", "path/to/blob")
- assert fs.file_id == fid
- assert fs.original_name == "test.txt"
- assert fs.blob_path == "path/to/blob"
- assert called
+ # Mock the logger to verify logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and assert it raises the exception
+ with pytest.raises(Exception, match="Unexpected Error"):
+ await cosmos_db_client.create_batch(user_id, batch_id)
+
+ # Ensure logger.error was called with expected message and error
+ mock_logger.error.assert_called_once()
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to create batch"
+ assert "error" in called_kwargs
+ assert "Unexpected Error" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_add_file_error(monkeypatch, cosmosdb_client):
- async def fake_create_item(body):
- raise Exception("Add file error")
+async def test_add_file(cosmos_db_client, mocker):
+ batch_id = uuid4()
+ file_id = uuid4()
+ file_name = "file.txt"
+ storage_path = "/path/to/storage"
- monkeypatch.setattr(
- cosmosdb_client.file_container,
- "create_item",
- lambda *args, **kwargs: fake_create_item(*args, **kwargs),
- )
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.add_file(
- uuid.uuid4(), uuid.uuid4(), "test.txt", "path/to/blob"
- )
- assert "Add file error" in str(exc_info.value)
+ # Mock file container creation
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+
+ # Mock the create_item method
+ mock_file_container.create_item = AsyncMock(return_value=None)
+
+ # Call the method
+ file_record = await cosmos_db_client.add_file(batch_id, file_id, file_name, storage_path)
+
+ # Assert that the file record is created
+ assert file_record.file_id == file_id
+ assert file_record.batch_id == batch_id
+ assert file_record.original_name == file_name
+ assert file_record.blob_path == storage_path
+ assert file_record.status == ProcessStatus.READY_TO_PROCESS
+
+ mock_file_container.create_item.assert_called_once_with(body=file_record.dict())
@pytest.mark.asyncio
-async def test_get_batch_success(monkeypatch, cosmosdb_client):
- batch_item = {
- "id": "batch1",
- "user_id": "user1",
- "created_at": datetime.utcnow().isoformat(),
- }
- file_item = {"file_id": "file1", "batch_id": "batch1"}
+async def test_add_file_exception(cosmos_db_client, mocker):
+ batch_id = uuid4()
+ file_id = uuid4()
+ file_name = "document.pdf"
+ storage_path = "/files/document.pdf"
+
+ # Mock file_container.create_item to raise a general exception
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mock_file_container.create_item = AsyncMock(side_effect=Exception("Insert failed"))
- async def fake_query_items_batch(*args, **kwargs):
- for item in [batch_item]:
- yield item
+ # Mock logger to capture error logs
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
- async def fake_query_items_files(*args, **kwargs):
- for item in [file_item]:
- yield item
+ # Expect an exception when calling add_file
+ with pytest.raises(Exception, match="Insert failed"):
+ await cosmos_db_client.add_file(batch_id, file_id, file_name, storage_path)
- cosmosdb_client.batch_container.set_query_items(fake_query_items_batch)
- cosmosdb_client.file_container.set_query_items(fake_query_items_files)
- result = await cosmosdb_client.get_batch("user1", "batch1")
- assert result is not None
- assert result.get("id") == "batch1"
+ # Check that logger.error was called properly
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to add file"
+ assert "error" in called_kwargs
+ assert "Insert failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_get_batch_not_found(monkeypatch, cosmosdb_client):
- async def fake_query_items(*args, **kwargs):
- if False:
- yield
+async def test_update_file(cosmos_db_client, mocker):
+ file_id = uuid4()
+ file_record = FileRecord(
+ file_id=file_id,
+ batch_id=uuid4(),
+ original_name="file.txt",
+ blob_path="/path/to/storage",
+ translated_path="",
+ status=ProcessStatus.READY_TO_PROCESS,
+ error_count=0,
+ syntax_count=0,
+ created_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(timezone.utc)
+ )
+
+ # Mock file container replace_item method
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mock_file_container.replace_item = AsyncMock(return_value=None)
+
+ # Call the method
+ updated_file_record = await cosmos_db_client.update_file(file_record)
- cosmosdb_client.batch_container.set_query_items(fake_query_items)
- result = await cosmosdb_client.get_batch("user1", "nonexistent")
- assert result is None
+ # Assert that the file record is updated
+ assert updated_file_record.file_id == file_id
+
+ mock_file_container.replace_item.assert_called_once_with(item=str(file_id), body=file_record.dict())
@pytest.mark.asyncio
-async def test_get_batch_error(monkeypatch, cosmosdb_client):
- async def fake_query_items(*args, **kwargs):
- raise Exception("Query batch error")
- if False:
- yield
+async def test_update_file_exception(cosmos_db_client, mocker):
+ # Create a sample FileRecord
+ file_record = FileRecord(
+ file_id=uuid4(),
+ batch_id=uuid4(),
+ original_name="file.txt",
+ blob_path="/storage/file.txt",
+ translated_path="",
+ status=ProcessStatus.READY_TO_PROCESS,
+ error_count=0,
+ syntax_count=0,
+ created_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(timezone.utc),
+ )
- monkeypatch.setattr(
- cosmosdb_client.batch_container,
- "query_items",
- lambda *args, **kwargs: fake_query_items(*args, **kwargs),
+ # Mock file_container.replace_item to raise an exception
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mock_file_container.replace_item = AsyncMock(side_effect=Exception("Update failed"))
+
+ # Mock logger
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Expect an exception when update_file is called
+ with pytest.raises(Exception, match="Update failed"):
+ await cosmos_db_client.update_file(file_record)
+
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to update file"
+ assert "error" in called_kwargs
+ assert "Update failed" in called_kwargs["error"]
+
+
+@pytest.mark.asyncio
+async def test_update_batch(cosmos_db_client, mocker):
+ batch_record = BatchRecord(
+ batch_id=uuid4(),
+ user_id="user_1",
+ file_count=0,
+ created_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(timezone.utc),
+ status=ProcessStatus.READY_TO_PROCESS
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.get_batch("user1", "batch1")
- assert "Query batch error" in str(exc_info.value)
+
+ # Mock batch container replace_item method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.replace_item = AsyncMock(return_value=None)
+
+ # Call the method
+ updated_batch_record = await cosmos_db_client.update_batch(batch_record)
+
+ # Assert that the batch record is updated
+ assert updated_batch_record.batch_id == batch_record.batch_id
+
+ mock_batch_container.replace_item.assert_called_once_with(item=str(batch_record.batch_id), body=batch_record.dict())
@pytest.mark.asyncio
-async def test_get_file_success(monkeypatch, cosmosdb_client):
- file_item = {"file_id": "file1", "original_name": "test.txt"}
+async def test_update_batch_exception(cosmos_db_client, mocker):
+ # Create a sample BatchRecord
+ batch_record = BatchRecord(
+ batch_id=uuid4(),
+ user_id="user_1",
+ file_count=3,
+ created_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(timezone.utc),
+ status=ProcessStatus.READY_TO_PROCESS,
+ )
+
+ # Mock batch_container.replace_item to raise an exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.replace_item = AsyncMock(side_effect=Exception("Update batch failed"))
- async def fake_query_items(*args, **kwargs):
- for item in [file_item]:
- yield item
+ # Mock logger to verify logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
- cosmosdb_client.file_container.set_query_items(fake_query_items)
- result = await cosmosdb_client.get_file("file1")
- assert result == file_item
+ # Expect an exception when update_batch is called
+ with pytest.raises(Exception, match="Update batch failed"):
+ await cosmos_db_client.update_batch(batch_record)
+
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to update batch"
+ assert "error" in called_kwargs
+ assert "Update batch failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_get_file_error(monkeypatch, cosmosdb_client):
- async def fake_query_items(*args, **kwargs):
- raise Exception("Query file error")
- if False:
- yield
+async def test_get_batch(cosmos_db_client, mocker):
+ user_id = "user_1"
+ batch_id = str(uuid4())
+
+ # Mock batch container query_items method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, "batch_container", mock_batch_container)
+
+ # Simulate the query result
+ expected_batch = {
+ "batch_id": batch_id,
+ "user_id": user_id,
+ "file_count": 0,
+ "status": ProcessStatus.READY_TO_PROCESS,
+ }
- monkeypatch.setattr(
- cosmosdb_client.file_container,
- "query_items",
- lambda *args, **kwargs: fake_query_items(*args, **kwargs),
+ # We define the async generator function that will yield the expected batch
+ async def mock_query_items(query, parameters):
+ yield expected_batch
+
+ # Assign the async generator to query_items mock
+ mock_batch_container.query_items.side_effect = mock_query_items
+ # Call the method
+ batch = await cosmos_db_client.get_batch(user_id, batch_id)
+
+ # Assert the batch is returned correctly
+ assert batch["batch_id"] == batch_id
+ assert batch["user_id"] == user_id
+
+ mock_batch_container.query_items.assert_called_once_with(
+ query="SELECT * FROM c WHERE c.batch_id = @batch_id and c.user_id = @user_id",
+ parameters=[
+ {"name": "@batch_id", "value": batch_id},
+ {"name": "@user_id", "value": user_id},
+ ],
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.get_file("file1")
- assert "Query file error" in str(exc_info.value)
@pytest.mark.asyncio
-async def test_get_batch_files_success(monkeypatch, cosmosdb_client):
- file_item = {"file_id": "file1", "batch_id": "batch1"}
+async def test_get_batch_exception(cosmos_db_client, mocker):
+ user_id = "user_1"
+ batch_id = str(uuid4())
+
+ # Mock batch_container.query_items to raise an exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get batch failed")
+ )
+
+ # Patch logger
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call get_batch and expect it to raise an exception
+ with pytest.raises(Exception, match="Get batch failed"):
+ await cosmos_db_client.get_batch(user_id, batch_id)
+
+ # Ensure logger.error was called with the expected error message
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get batch"
+ assert "error" in called_kwargs
+ assert "Get batch failed" in called_kwargs["error"]
+
+
+@pytest.mark.asyncio
+async def test_get_file(cosmos_db_client, mocker):
+ file_id = str(uuid4())
+
+ # Mock file container query_items method
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+
+ # Simulate the query result
+ expected_file = {
+ "file_id": file_id,
+ "status": ProcessStatus.READY_TO_PROCESS,
+ "original_name": "file.txt",
+ "blob_path": "/path/to/file"
+ }
- async def fake_query_items(*args, **kwargs):
- for item in [file_item]:
- yield item
+    # Define the async generator function that will yield the expected file
+ async def mock_query_items(query, parameters):
+ yield expected_file
- cosmosdb_client.file_container.set_query_items(fake_query_items)
- files = await cosmosdb_client.get_batch_files("user1", "batch1")
- assert files == [file_item]
+ # Assign the async generator to query_items mock
+ mock_file_container.query_items.side_effect = mock_query_items
+
+ # Call the method
+ file = await cosmos_db_client.get_file(file_id)
+
+ # Assert the file is returned correctly
+ assert file["file_id"] == file_id
+ assert file["status"] == ProcessStatus.READY_TO_PROCESS
+
+ mock_file_container.query_items.assert_called_once()
@pytest.mark.asyncio
-async def test_get_user_batches_success(monkeypatch, cosmosdb_client):
- batch_item1 = {"id": "batch1", "user_id": "user1"}
- batch_item2 = {"id": "batch2", "user_id": "user1"}
+async def test_get_file_exception(cosmos_db_client, mocker):
+ file_id = str(uuid4())
+
+ # Mock file_container.query_items to raise an exception
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mock_file_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get file failed")
+ )
+
+ # Mock logger to verify logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
- async def fake_query_items(*args, **kwargs):
- for item in [batch_item1, batch_item2]:
- yield item
+ # Call get_file and expect an exception
+ with pytest.raises(Exception, match="Get file failed"):
+ await cosmos_db_client.get_file(file_id)
- cosmosdb_client.batch_container.set_query_items(fake_query_items)
- result = await cosmosdb_client.get_user_batches("user1")
- assert result == [batch_item1, batch_item2]
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get file"
+ assert "error" in called_kwargs
+ assert "Get file failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_get_user_batches_error(monkeypatch, cosmosdb_client):
- async def fake_query_items(*args, **kwargs):
- raise Exception("User batches error")
- if False:
- yield
+async def test_get_batch_files(cosmos_db_client, mocker):
+ batch_id = str(uuid4())
+
+ # Mock file container query_items method
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+
+ # Simulate the query result for multiple files
+ expected_files = [
+ {
+ "file_id": str(uuid4()),
+ "status": ProcessStatus.READY_TO_PROCESS,
+ "original_name": "file1.txt",
+ "blob_path": "/path/to/file1"
+ },
+ {
+ "file_id": str(uuid4()),
+ "status": ProcessStatus.IN_PROGRESS,
+ "original_name": "file2.txt",
+ "blob_path": "/path/to/file2"
+ }
+ ]
+
+ # Define the async generator function to yield the expected files
+ async def mock_query_items(query, parameters):
+ for file in expected_files:
+ yield file
+
+ # Set the side_effect of query_items to simulate async iteration
+ mock_file_container.query_items.side_effect = mock_query_items
+
+ # Call the method
+ files = await cosmos_db_client.get_batch_files(batch_id)
+
+ # Assert the files list contains the correct files
+ assert len(files) == len(expected_files)
+ assert files[0]["file_id"] == expected_files[0]["file_id"]
+ assert files[1]["file_id"] == expected_files[1]["file_id"]
+
+ mock_file_container.query_items.assert_called_once()
- monkeypatch.setattr(
- cosmosdb_client.batch_container,
- "query_items",
- lambda *args, **kwargs: fake_query_items(*args, **kwargs),
+
+@pytest.mark.asyncio
+async def test_get_batch_files_exception(cosmos_db_client, mocker):
+ batch_id = str(uuid4())
+
+ # Mock file_container.query_items to raise an exception
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mock_file_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get batch file failed")
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.get_user_batches("user1")
- assert "User batches error" in str(exc_info.value)
+
+ # Mock logger to verify logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Expect the exception to be raised
+ with pytest.raises(Exception, match="Get batch file failed"):
+ await cosmos_db_client.get_batch_files(batch_id)
+
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get files"
+ assert "error" in called_kwargs
+ assert "Get batch file failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_get_file_logs_success(monkeypatch, cosmosdb_client):
- log_item = {
- "file_id": "file1",
- "description": "log",
- "timestamp": datetime.utcnow().isoformat(),
+async def test_get_batch_from_id(cosmos_db_client, mocker):
+ batch_id = str(uuid4())
+
+ # Mock batch container query_items method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+
+ # Simulate the query result
+ expected_batch = {
+ "batch_id": batch_id,
+ "status": ProcessStatus.READY_TO_PROCESS,
+ "user_id": "user_123",
}
- async def fake_query_items(*args, **kwargs):
- for item in [log_item]:
- yield item
+ # Define the async generator function that will yield the expected batch
+ async def mock_query_items(query, parameters):
+ yield expected_batch
- cosmosdb_client.log_container.set_query_items(fake_query_items)
- result = await cosmosdb_client.get_file_logs("file1")
- assert result == [log_item]
+ # Assign the async generator to query_items mock
+ mock_batch_container.query_items.side_effect = mock_query_items
+ # Call the method
+ batch = await cosmos_db_client.get_batch_from_id(batch_id)
-@pytest.mark.asyncio
-async def test_get_file_logs_error(monkeypatch, cosmosdb_client):
- async def fake_query_items(*args, **kwargs):
- raise Exception("Log query error")
- if False:
- yield
+ # Assert the batch is returned correctly
+ assert batch["batch_id"] == batch_id
+ assert batch["status"] == ProcessStatus.READY_TO_PROCESS
+
+ mock_batch_container.query_items.assert_called_once()
- monkeypatch.setattr(
- cosmosdb_client.log_container,
- "query_items",
- lambda *args, **kwargs: fake_query_items(*args, **kwargs),
+
+@pytest.mark.asyncio
+async def test_get_batch_from_id_exception(cosmos_db_client, mocker):
+ batch_id = str(uuid4())
+
+ # Mock batch_container.query_items to raise an exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get batch from id failed")
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.get_file_logs("file1")
- assert "Log query error" in str(exc_info.value)
+
+ # Mock logger to verify logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and expect it to raise an exception
+ with pytest.raises(Exception, match="Get batch from id failed"):
+ await cosmos_db_client.get_batch_from_id(batch_id)
+
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get batch from ID"
+ assert "error" in called_kwargs
+ assert "Get batch from id failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_delete_all_success(monkeypatch, cosmosdb_client):
- async def fake_delete_items(key):
- return
+async def test_get_user_batches(cosmos_db_client, mocker):
+ user_id = "user_123"
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "delete_items", fake_delete_items
- )
- monkeypatch.setattr(
- cosmosdb_client.file_container, "delete_items", fake_delete_items
- )
- monkeypatch.setattr(
- cosmosdb_client.log_container, "delete_items", fake_delete_items
+ # Mock batch container query_items method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+
+ # Simulate the query result
+ expected_batches = [
+ {"batch_id": str(uuid4()), "status": ProcessStatus.READY_TO_PROCESS, "user_id": user_id},
+ {"batch_id": str(uuid4()), "status": ProcessStatus.IN_PROGRESS, "user_id": user_id}
+ ]
+
+ # Define the async generator function that will yield the expected batches
+ async def mock_query_items(query, parameters):
+ for batch in expected_batches:
+ yield batch
+
+ # Assign the async generator to query_items mock
+ mock_batch_container.query_items.side_effect = mock_query_items
+
+ # Call the method
+ batches = await cosmos_db_client.get_user_batches(user_id)
+
+ # Assert the batches are returned correctly
+ assert len(batches) == 2
+ assert batches[0]["status"] == ProcessStatus.READY_TO_PROCESS
+ assert batches[1]["status"] == ProcessStatus.IN_PROGRESS
+
+ mock_batch_container.query_items.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_get_user_batches_exception(cosmos_db_client, mocker):
+ user_id = "user_" + str(uuid4())
+
+ # Mock batch_container.query_items to raise an exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get user batch failed")
)
- await cosmosdb_client.delete_all("user1")
+
+ # Mock logger to capture the error
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and expect it to raise the exception
+ with pytest.raises(Exception, match="Get user batch failed"):
+ await cosmos_db_client.get_user_batches(user_id)
+
+ # Ensure logger.error was called with the expected message and error
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get user batches"
+ assert "error" in called_kwargs
+ assert "Get user batch failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_delete_all_error(monkeypatch, cosmosdb_client):
- async def fake_delete_items(key):
- raise Exception("Delete all error")
+async def test_get_file_logs(cosmos_db_client, mocker):
+ file_id = str(uuid4())
+
+ # Mock log container query_items method
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+
+ # Simulate the query result with new log structure
+ expected_logs = [
+ {
+ "log_id": str(uuid4()),
+ "file_id": file_id,
+ "description": "Log entry 1",
+ "last_candidate": "candidate_1",
+ "log_type": LogType.INFO,
+ "agent_type": AgentType.FIXER,
+ "author_role": AuthorRole.ASSISTANT,
+ "timestamp": datetime(2025, 4, 7, 12, 0, 0)
+ },
+ {
+ "log_id": str(uuid4()),
+ "file_id": file_id,
+ "description": "Log entry 2",
+ "last_candidate": "candidate_2",
+ "log_type": LogType.ERROR,
+ "agent_type": AgentType.HUMAN,
+ "author_role": AuthorRole.USER,
+ "timestamp": datetime(2025, 4, 7, 12, 5, 0)
+ }
+ ]
+
+ # Define the async generator function that will yield the expected logs
+ async def mock_query_items(query, parameters):
+ for log in expected_logs:
+ yield log
+
+ # Assign the async generator to query_items mock
+ mock_log_container.query_items.side_effect = mock_query_items
+
+ # Call the method
+ logs = await cosmos_db_client.get_file_logs(file_id)
+
+ # Assert the logs are returned correctly
+ assert len(logs) == 2
+ assert logs[0]["description"] == "Log entry 1"
+ assert logs[1]["description"] == "Log entry 2"
+ assert logs[0]["log_type"] == LogType.INFO
+ assert logs[1]["log_type"] == LogType.ERROR
+ assert logs[0]["timestamp"] == datetime(2025, 4, 7, 12, 0, 0)
+ assert logs[1]["timestamp"] == datetime(2025, 4, 7, 12, 5, 0)
+
+ mock_log_container.query_items.assert_called_once()
+
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "delete_items", fake_delete_items
+@pytest.mark.asyncio
+async def test_get_file_logs_exception(cosmos_db_client, mocker):
+ file_id = str(uuid4())
+
+ # Mock log_container.query_items to raise an exception
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+ mock_log_container.query_items = mock.MagicMock(
+ side_effect=Exception("Get file log failed")
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.delete_all("user1")
- assert "Delete all error" in str(exc_info.value)
+
+ # Mock logger to verify error logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and expect it to raise the exception
+ with pytest.raises(Exception, match="Get file log failed"):
+ await cosmos_db_client.get_file_logs(file_id)
+
+ # Assert logger.error was called with correct arguments
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to get file logs"
+ assert "error" in called_kwargs
+ assert "Get file log failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_delete_logs_success(monkeypatch, cosmosdb_client):
- async def fake_delete_items(key):
- return
+async def test_delete_all(cosmos_db_client, mocker):
+ user_id = str(uuid4())
+
+ # Mock containers with AsyncMock
+ mock_batch_container = AsyncMock()
+ mock_file_container = AsyncMock()
+ mock_log_container = AsyncMock()
+
+ # Patching the containers with mock objects
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
- monkeypatch.setattr(
- cosmosdb_client.log_container, "delete_items", fake_delete_items
+ # Mock the delete_item method for all containers
+ mock_batch_container.delete_item = AsyncMock(return_value=None)
+ mock_file_container.delete_item = AsyncMock(return_value=None)
+ mock_log_container.delete_item = AsyncMock(return_value=None)
+
+ # Call the delete_all method
+ await cosmos_db_client.delete_all(user_id)
+
+ mock_batch_container.delete_item.assert_called_once()
+ mock_file_container.delete_item.assert_called_once()
+ mock_log_container.delete_item.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_delete_all_exception(cosmos_db_client, mocker):
+ user_id = f"user_{uuid4()}"
+
+ # Mock batch_container to raise an exception on delete
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.delete_item = mock.AsyncMock(
+ side_effect=Exception("Delete failed")
)
- await cosmosdb_client.delete_logs("file1")
+
+ # Also mock file_container and log_container to avoid accidental execution
+ mock_file_container = mock.MagicMock()
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+
+ # Mock logger to verify error handling
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and expect it to raise the exception
+ with pytest.raises(Exception, match="Delete failed"):
+ await cosmos_db_client.delete_all(user_id)
+
+ # Check that logger.error was called with expected error message
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to delete all user data"
+ assert "error" in called_kwargs
+ assert "Delete failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_delete_batch_success(monkeypatch, cosmosdb_client):
- delete_calls = []
+async def test_delete_logs(cosmos_db_client, mocker):
+ file_id = str(uuid4())
- async def fake_delete_items(key):
- delete_calls.append(key)
+ # Mock the log container with AsyncMock
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
- async def fake_delete_item(item, partition_key):
- delete_calls.append((item, partition_key))
+ # Simulate the query result for logs
+ log_ids = [str(uuid4()), str(uuid4())]
- monkeypatch.setattr(
- cosmosdb_client.file_container, "delete_items", fake_delete_items
- )
- monkeypatch.setattr(
- cosmosdb_client.log_container, "delete_items", fake_delete_items
- )
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "delete_item", fake_delete_item
+ # Define the async generator function to simulate query result
+ async def mock_query_items(query, parameters):
+ for log_id in log_ids:
+ yield {"id": log_id}
+
+ # Assign the async generator to query_items mock
+ mock_log_container.query_items.side_effect = mock_query_items
+
+ # Mock delete_item method for log_container
+ mock_log_container.delete_item = AsyncMock(return_value=None)
+
+ # Call the delete_logs method
+ await cosmos_db_client.delete_logs(file_id)
+
+ # Assert delete_item is called for each log id
+ for log_id in log_ids:
+ mock_log_container.delete_item.assert_any_call(log_id, partition_key=log_id)
+
+ mock_log_container.query_items.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_delete_logs_exception(cosmos_db_client, mocker):
+ file_id = str(uuid4())
+
+ # Mock log_container.query_items to raise an exception
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+ mock_log_container.query_items = mock.MagicMock(
+ side_effect=Exception("Query failed")
)
- await cosmosdb_client.delete_batch("user1", "batch1")
- assert len(delete_calls) == 3
+
+ # Mock logger to verify error handling
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Call the method and expect it to raise the exception
+ with pytest.raises(Exception, match="Query failed"):
+ await cosmos_db_client.delete_logs(file_id)
+
+ # Check that logger.error was called with expected error message
+ called_args, called_kwargs = mock_logger.error.call_args
+ assert called_args[0] == "Failed to delete all user data"
+ assert "error" in called_kwargs
+ assert "Query failed" in called_kwargs["error"]
@pytest.mark.asyncio
-async def test_delete_file_success(monkeypatch, cosmosdb_client):
- calls = []
+async def test_delete_batch(cosmos_db_client, mocker):
+ user_id = str(uuid4())
+ batch_id = str(uuid4())
+
+ # Mock the batch container with AsyncMock
+ mock_batch_container = AsyncMock()
+ mocker.patch.object(cosmos_db_client, "batch_container", mock_batch_container)
- async def fake_delete_items(key):
- calls.append(("log_delete", key))
+ # Call the delete_batch method
+ await cosmos_db_client.delete_batch(user_id, batch_id)
- async def fake_delete_item(file_id):
- calls.append(("file_delete", file_id))
+ mock_batch_container.delete_item.assert_called_once()
- monkeypatch.setattr(
- cosmosdb_client.log_container, "delete_items", fake_delete_items
+
+@pytest.mark.asyncio
+async def test_delete_batch_exception(cosmos_db_client, mocker):
+ user_id = f"user_{uuid4()}"
+ batch_id = str(uuid4())
+
+ # Mock batch_container.delete_item to raise an exception
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+ mock_batch_container.delete_item = mock.AsyncMock(
+ side_effect=Exception("Delete failed")
)
- monkeypatch.setattr(cosmosdb_client.file_container, "delete_item", fake_delete_item)
- await cosmosdb_client.delete_file("user1", "batch1", "file1")
- assert ("log_delete", "file1") in calls
- assert ("file_delete", "file1") in calls
+
+ # Mock logger to verify error logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Expect the exception to be raised from the inner try block
+ with pytest.raises(Exception, match="Delete failed"):
+ await cosmos_db_client.delete_batch(user_id, batch_id)
+
+ # Check that both error logs were triggered
+ assert mock_logger.error.call_count == 2
+
+ # First log: failed to delete the specific batch
+ first_call_args, first_call_kwargs = mock_logger.error.call_args_list[0]
+ assert f"Failed to delete batch with ID: {batch_id}" in first_call_args[0]
+ assert "error" in first_call_kwargs
+ assert "Delete failed" in first_call_kwargs["error"]
+
+ # Second log: higher-level operation failed
+ second_call_args, second_call_kwargs = mock_logger.error.call_args_list[1]
+ assert second_call_args[0] == "Failed to perform delete batch operation"
+ assert "error" in second_call_kwargs
+ assert "Delete failed" in second_call_kwargs["error"]
@pytest.mark.asyncio
-async def test_log_file_status_success(monkeypatch, cosmosdb_client):
- called = False
+async def test_delete_file(cosmos_db_client, mocker):
+ user_id = str(uuid4())
+ file_id = str(uuid4())
- async def fake_create_item(body):
- nonlocal called
- called = True
+ # Mock containers with AsyncMock
+ mock_file_container = AsyncMock()
+ mock_log_container = AsyncMock()
- monkeypatch.setattr(cosmosdb_client.log_container, "create_item", fake_create_item)
- await cosmosdb_client.log_file_status(
- "file1", DummyProcessStatus.READY_TO_PROCESS, "desc", DummyLogType.INFO
- )
- assert called
+ # Patching the containers with mock objects
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+
+ # Mock the delete_logs method (since it's called in delete_file)
+ mocker.patch.object(cosmos_db_client, 'delete_logs', return_value=None)
+
+ # Call the delete_file method
+ await cosmos_db_client.delete_file(user_id, file_id)
+
+ cosmos_db_client.delete_logs.assert_called_once_with(file_id)
+
+ mock_file_container.delete_item.assert_called_once_with(file_id, partition_key=file_id)
@pytest.mark.asyncio
-async def test_log_file_status_error(monkeypatch, cosmosdb_client):
- async def fake_create_item(body):
- raise Exception("Log error")
+async def test_delete_file_exception(cosmos_db_client, mocker):
+ user_id = f"user_{uuid4()}"
+ file_id = str(uuid4())
+
+ # Mock delete_logs to raise an exception
+ mocker.patch.object(
+ cosmos_db_client,
+ 'delete_logs',
+ mock.AsyncMock(side_effect=Exception("Delete file failed"))
+ )
+
+ # Mock file_container to ensure delete_item is not accidentally called
+ mock_file_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container)
+
+ # Mock logger to verify error logging
+ mock_logger = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'logger', mock_logger)
+
+ # Expect an exception to be raised from delete_logs
+ with pytest.raises(Exception, match="Delete file failed"):
+ await cosmos_db_client.delete_file(user_id, file_id)
+
+ mock_logger.error.assert_called_once()
+ called_args, _ = mock_logger.error.call_args
+ assert f"Failed to delete file and logs for file_id {file_id}" in called_args[0]
- monkeypatch.setattr(
- cosmosdb_client.log_container,
- "create_item",
- lambda *args, **kwargs: fake_create_item(*args, **kwargs),
+
+@pytest.mark.asyncio
+async def test_add_file_log(cosmos_db_client, mocker):
+ file_id = uuid4()
+ description = "File processing started"
+ last_candidate = "candidate_123"
+ log_type = LogType.INFO
+ agent_type = AgentType.MIGRATOR
+ author_role = AuthorRole.ASSISTANT
+
+ # Mock log container create_item method
+ mock_log_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container)
+
+ # Mock the create_item method
+ mock_log_container.create_item = AsyncMock(return_value=None)
+
+ # Call the method
+ await cosmos_db_client.add_file_log(
+ file_id, description, last_candidate, log_type, agent_type, author_role
)
- with pytest.raises(Exception) as exc_info:
- await cosmosdb_client.log_file_status(
- "file1", DummyProcessStatus.READY_TO_PROCESS, "desc", DummyLogType.INFO
- )
- assert "Log error" in str(exc_info.value)
+
+ mock_log_container.create_item.assert_called_once()
@pytest.mark.asyncio
-async def test_update_batch_entry_success(monkeypatch, cosmosdb_client):
- dummy_batch = {
- "id": "batch1",
- "user_id": "user1",
- "status": DummyProcessStatus.READY_TO_PROCESS,
- "updated_at": datetime.utcnow().isoformat(),
+async def test_update_batch_entry(cosmos_db_client, mocker):
+ batch_id = "batch_123"
+ user_id = "user_123"
+ status = ProcessStatus.IN_PROGRESS
+ file_count = 5
+
+ # Mock batch container replace_item method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+
+ # Mock the get_batch method
+ mocker.patch.object(cosmos_db_client, 'get_batch', return_value={
+ "batch_id": batch_id,
+ "status": ProcessStatus.READY_TO_PROCESS.value,
+ "user_id": user_id,
"file_count": 0,
- }
+ "updated_at": "2025-04-07T00:00:00Z"
+ })
- async def fake_get_batch(user_id, batch_id):
- return dummy_batch.copy()
+ # Mock the replace_item method
+ mock_batch_container.replace_item = AsyncMock(return_value=None)
- monkeypatch.setattr(cosmosdb_client, "get_batch", fake_get_batch)
- updated_body = None
+ # Call the method
+ updated_batch = await cosmos_db_client.update_batch_entry(batch_id, user_id, status, file_count)
- async def fake_replace_item(item, body):
- nonlocal updated_body
- updated_body = body
- return body
+ # Assert that replace_item was called with the correct arguments
+ mock_batch_container.replace_item.assert_called_once_with(item=batch_id, body={
+ "batch_id": batch_id,
+ "status": status.value,
+ "user_id": user_id,
+ "file_count": file_count,
+ "updated_at": updated_batch["updated_at"]
+ })
- monkeypatch.setattr(
- cosmosdb_client.batch_container, "replace_item", fake_replace_item
- )
- new_status = DummyProcessStatus.PROCESSING
- file_count = 5
- result = await cosmosdb_client.update_batch_entry(
- "batch1", "user1", new_status, file_count
- )
- assert result["file_count"] == file_count
- assert result["status"] == new_status.value
- assert updated_body is not None
+ # Assert the returned batch matches expected values
+ assert updated_batch["batch_id"] == batch_id
+ assert updated_batch["status"] == status.value
+ assert updated_batch["file_count"] == file_count
@pytest.mark.asyncio
-async def test_update_batch_entry_not_found(monkeypatch, cosmosdb_client):
- monkeypatch.setattr(
- cosmosdb_client, "get_batch", lambda u, b: asyncio.sleep(0, result=None)
- )
- with pytest.raises(ValueError, match="Batch not found"):
- await cosmosdb_client.update_batch_entry(
- "nonexistent", "user1", DummyProcessStatus.READY_TO_PROCESS, 0
- )
+async def test_close(cosmos_db_client, mocker):
+ # Mock the client and logger
+ mock_client = mock.MagicMock()
+ mock_logger = mock.MagicMock()
+ cosmos_db_client.client = mock_client
+ cosmos_db_client.logger = mock_logger
+ # Call the method
+ await cosmos_db_client.close()
-@pytest.mark.asyncio
-async def test_close(monkeypatch, cosmosdb_client):
- closed = False
+ # Assert that the client was closed
+ mock_client.close.assert_called_once()
- def fake_close():
- nonlocal closed
- closed = True
+ # Assert that logger's info method was called
+ mock_logger.info.assert_called_once_with("Closed Cosmos DB connection")
- monkeypatch.setattr(cosmosdb_client.client, "close", fake_close)
- await cosmosdb_client.close()
- assert closed
+
+@pytest.mark.asyncio
+async def test_get_batch_history(cosmos_db_client, mocker):
+ user_id = "user_123"
+ limit = 5
+ offset = 0
+ sort_order = "DESC"
+
+ # Mock batch container query_items method
+ mock_batch_container = mock.MagicMock()
+ mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
+
+ # Simulate the query result for batches
+ expected_batches = [
+ {"batch_id": "batch_1", "status": ProcessStatus.IN_PROGRESS.value, "user_id": user_id, "file_count": 5},
+ {"batch_id": "batch_2", "status": ProcessStatus.COMPLETED.value, "user_id": user_id, "file_count": 3},
+ ]
+
+ # Define the async generator function to simulate query result
+ async def mock_query_items(query, parameters):
+ for batch in expected_batches:
+ yield batch
+
+ # Assign the async generator to query_items mock
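+    # side_effect makes each call return a fresh async iterator, similar to how the async SDK's query_items yields results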
+ mock_batch_container.query_items.side_effect = mock_query_items
+
+ # Call the method
+ batches = await cosmos_db_client.get_batch_history(user_id, limit, sort_order, offset)
+
+ # Assert the returned batches are correct
+ assert len(batches) == len(expected_batches)
+ assert batches[0]["batch_id"] == expected_batches[0]["batch_id"]
+
+ mock_batch_container.query_items.assert_called_once()
diff --git a/src/tests/backend/common/database/database_base_test.py b/src/tests/backend/common/database/database_base_test.py
index 6000d86..325cf7e 100644
--- a/src/tests/backend/common/database/database_base_test.py
+++ b/src/tests/backend/common/database/database_base_test.py
@@ -1,60 +1,61 @@
-import asyncio
import uuid
-import pytest
-from datetime import datetime
from enum import Enum
-# Import the abstract base class and related models/enums.
+
from common.database.database_base import DatabaseBase
-from common.models.api import BatchRecord, FileRecord, ProcessStatus
+from common.models.api import ProcessStatus
+
+import pytest
+
+# Allow instantiation of the abstract base class by clearing its abstract methods.
DatabaseBase.__abstractmethods__ = set()
@pytest.fixture
def db_instance():
- # Instantiate the DatabaseBase directly.
+ # Create a concrete implementation of DatabaseBase using async methods.
class ConcreteDatabase(DatabaseBase):
- def create_batch(self, user_id, batch_id):
+ async def create_batch(self, user_id, batch_id):
pass
- def get_file_logs(self, file_id):
+ async def get_file_logs(self, file_id):
pass
- def get_batch_files(self, user_id, batch_id):
+ async def get_batch_files(self, user_id, batch_id):
pass
- def delete_file_logs(self, file_id):
+ async def delete_file_logs(self, file_id):
pass
- def get_user_batches(self, user_id):
+ async def get_user_batches(self, user_id):
pass
- def add_file(self, batch_id, file_id, file_name, file_path):
+ async def add_file(self, batch_id, file_id, file_name, file_path):
pass
- def get_batch(self, user_id, batch_id):
+ async def get_batch(self, user_id, batch_id):
pass
- def get_file(self, file_id):
+ async def get_file(self, file_id):
pass
- def log_file_status(self, file_id, status, description, log_type):
+ async def log_file_status(self, file_id, status, description, log_type):
pass
- def log_batch_status(self, batch_id, status, file_count):
+ async def log_batch_status(self, batch_id, status, file_count):
pass
- def delete_all(self, user_id):
+ async def delete_all(self, user_id):
pass
- def delete_batch(self, user_id, batch_id):
+ async def delete_batch(self, user_id, batch_id):
pass
- def delete_file(self, user_id, batch_id, file_id):
+ async def delete_file(self, user_id, batch_id, file_id):
pass
- def close(self):
+ async def close(self):
pass
return ConcreteDatabase()
@@ -71,7 +72,7 @@ def get_dummy_status():
members = list(ProcessStatus)
if members:
return members[0]
- # If the enum is empty, create a dummy one
+ # If the enum is empty, create a dummy one.
DummyStatus = Enum("DummyStatus", {"DUMMY": "dummy"})
return DummyStatus.DUMMY
@@ -79,7 +80,7 @@ def get_dummy_status():
@pytest.mark.asyncio
async def test_create_batch(db_instance):
result = await db_instance.create_batch("user1", uuid.uuid4())
- # Since the method is abstract (and implemented as pass), result is None.
+ # Since the method is implemented as pass, result is None.
assert result is None
@@ -109,9 +110,7 @@ async def test_get_user_batches(db_instance):
@pytest.mark.asyncio
async def test_add_file(db_instance):
- result = await db_instance.add_file(
- uuid.uuid4(), uuid.uuid4(), "test.txt", "/dummy/path"
- )
+ result = await db_instance.add_file(uuid.uuid4(), uuid.uuid4(), "test.txt", "/dummy/path")
assert result is None
@@ -129,10 +128,8 @@ async def test_get_file(db_instance):
@pytest.mark.asyncio
async def test_log_file_status(db_instance):
- # Use an existing member for file status—here we use COMPLETED.
- result = await db_instance.log_file_status(
- "file1", ProcessStatus.COMPLETED, "desc", "log_type"
- )
+ # Using ProcessStatus.COMPLETED as an example.
+ result = await db_instance.log_file_status("file1", ProcessStatus.COMPLETED, "desc", "log_type")
assert result is None
diff --git a/src/tests/backend/common/database/database_factory_test.py b/src/tests/backend/common/database/database_factory_test.py
index b597e56..27d9810 100644
--- a/src/tests/backend/common/database/database_factory_test.py
+++ b/src/tests/backend/common/database/database_factory_test.py
@@ -1,57 +1,79 @@
+from unittest.mock import AsyncMock, patch
+
+
import pytest
-from common.config.config import Config
-from common.database.database_factory import DatabaseFactory
-
-
-class DummyConfig:
- cosmosdb_endpoint = "dummy_endpoint"
- cosmosdb_database = "dummy_database"
- cosmosdb_batch_container = "dummy_batch"
- cosmosdb_file_container = "dummy_file"
- cosmosdb_log_container = "dummy_log"
-
-
-class DummyCosmosDBClient:
- def __init__(self, endpoint, credential, database_name, batch_container, file_container, log_container):
- self.endpoint = endpoint
- self.credential = credential
- self.database_name = database_name
- self.batch_container = batch_container
- self.file_container = file_container
- self.log_container = log_container
-
-def dummy_config_init(self):
- self.cosmosdb_endpoint = DummyConfig.cosmosdb_endpoint
- self.cosmosdb_database = DummyConfig.cosmosdb_database
- self.cosmosdb_batch_container = DummyConfig.cosmosdb_batch_container
- self.cosmosdb_file_container = DummyConfig.cosmosdb_file_container
- self.cosmosdb_log_container = DummyConfig.cosmosdb_log_container
- # Provide a dummy method for credentials.
- self.get_azure_credentials = lambda: "dummy_credential"
+
@pytest.fixture(autouse=True)
def patch_config(monkeypatch):
- # Patch the __init__ of Config so that an instance will have the required attributes.
- monkeypatch.setattr(Config, "__init__", dummy_config_init)
+ """Patch Config class to use dummy values."""
+ from common.config.config import Config
+
+ def dummy_init(self):
+ """Mocked __init__ method for Config to set dummy values."""
+ self.cosmosdb_endpoint = "dummy_endpoint"
+ self.cosmosdb_database = "dummy_database"
+ self.cosmosdb_batch_container = "dummy_batch"
+ self.cosmosdb_file_container = "dummy_file"
+ self.cosmosdb_log_container = "dummy_log"
+ self.get_azure_credentials = lambda: "dummy_credential"
+
+ monkeypatch.setattr(Config, "__init__", dummy_init) # Replace the init method
+
@pytest.fixture(autouse=True)
def patch_cosmosdb_client(monkeypatch):
- # Patch CosmosDBClient in the module under test to use our dummy client.
+ """Patch CosmosDBClient to use a dummy implementation."""
+
+ class DummyCosmosDBClient:
+ def __init__(self, endpoint, credential, database_name, batch_container, file_container, log_container):
+ self.endpoint = endpoint
+ self.credential = credential
+ self.database_name = database_name
+ self.batch_container = batch_container
+ self.file_container = file_container
+ self.log_container = log_container
+
+ async def initialize_cosmos(self):
+ pass
+
+ async def create_batch(self, *args, **kwargs):
+ pass
+
+ async def add_file(self, *args, **kwargs):
+ pass
+
+ async def get_batch(self, *args, **kwargs):
+ return "mock_batch"
+
+ async def close(self):
+ pass
+
monkeypatch.setattr("common.database.database_factory.CosmosDBClient", DummyCosmosDBClient)
-def test_get_database():
- """
- Test that DatabaseFactory.get_database() correctly returns an instance of the
- dummy CosmosDB client with the expected configuration values.
- """
- # When get_database() is called, it creates a new Config() instance.
- db_instance = DatabaseFactory.get_database()
-
- # Verify that the returned instance is our dummy client with the expected attributes.
- assert isinstance(db_instance, DummyCosmosDBClient)
- assert db_instance.endpoint == DummyConfig.cosmosdb_endpoint
+
+@pytest.mark.asyncio
+async def test_get_database():
+ """Test database retrieval using the factory."""
+ from common.database.database_factory import DatabaseFactory
+
+ db_instance = await DatabaseFactory.get_database()
+
+ assert db_instance.endpoint == "dummy_endpoint"
assert db_instance.credential == "dummy_credential"
- assert db_instance.database_name == DummyConfig.cosmosdb_database
- assert db_instance.batch_container == DummyConfig.cosmosdb_batch_container
- assert db_instance.file_container == DummyConfig.cosmosdb_file_container
- assert db_instance.log_container == DummyConfig.cosmosdb_log_container
+ assert db_instance.database_name == "dummy_database"
+ assert db_instance.batch_container == "dummy_batch"
+ assert db_instance.file_container == "dummy_file"
+ assert db_instance.log_container == "dummy_log"
+
+
+@pytest.mark.asyncio
+async def test_main_function():
+ """Test the main function in database factory."""
+    with patch("common.database.database_factory.DatabaseFactory.get_database",
+               new_callable=AsyncMock, return_value=AsyncMock()) as mock_get_database, \
+            patch("builtins.print") as mock_print:
+ from common.database.database_factory import main
+ await main()
+
+ mock_get_database.assert_called_once()
+ mock_print.assert_called() # Ensures print is executed
diff --git a/src/tests/backend/common/logger/app_logger_test.py b/src/tests/backend/common/logger/app_logger_test.py
new file mode 100644
index 0000000..9301eb3
--- /dev/null
+++ b/src/tests/backend/common/logger/app_logger_test.py
@@ -0,0 +1,94 @@
+import json
+import logging
+from unittest.mock import MagicMock, patch
+
+from common.logger.app_logger import AppLogger, LogLevel
+
+import pytest
+
+
+@pytest.fixture
+def logger_name():
+ return "test_logger"
+
+
+@pytest.fixture
+def logger_instance(logger_name):
+ """Fixture to return AppLogger with mocked handler"""
+ with patch("common.logger.app_logger.logging.getLogger") as mock_get_logger:
+ mock_logger = MagicMock()
+ mock_get_logger.return_value = mock_logger
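+        # AppLogger is constructed while getLogger is patched, so its internal logger is this MagicMock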
+ yield AppLogger(logger_name)
+
+
+def test_log_levels():
+ """Ensure log levels are set correctly"""
+ assert LogLevel.NONE == logging.NOTSET
+ assert LogLevel.DEBUG == logging.DEBUG
+ assert LogLevel.INFO == logging.INFO
+ assert LogLevel.WARNING == logging.WARNING
+ assert LogLevel.ERROR == logging.ERROR
+ assert LogLevel.CRITICAL == logging.CRITICAL
+
+
+def test_format_message_basic(logger_instance):
+ result = logger_instance._format_message("Test message")
+ parsed = json.loads(result)
+ assert parsed["message"] == "Test message"
+ assert "context" not in parsed
+
+
+def test_format_message_with_context(logger_instance):
+ result = logger_instance._format_message("Contextual message", key1="value1", key2="value2")
+ parsed = json.loads(result)
+ assert parsed["message"] == "Contextual message"
+ assert parsed["context"] == {"key1": "value1", "key2": "value2"}
+
+
+def test_debug_log(logger_instance):
+ with patch.object(logger_instance.logger, "debug") as mock_debug:
+ logger_instance.debug("Debug log", user="tester")
+ mock_debug.assert_called_once()
+ log_json = json.loads(mock_debug.call_args[0][0])
+ assert log_json["message"] == "Debug log"
+ assert log_json["context"]["user"] == "tester"
+
+
+def test_info_log(logger_instance):
+ with patch.object(logger_instance.logger, "info") as mock_info:
+ logger_instance.info("Info log", module="log_module")
+ mock_info.assert_called_once()
+ log_json = json.loads(mock_info.call_args[0][0])
+ assert log_json["message"] == "Info log"
+ assert log_json["context"]["module"] == "log_module"
+
+
+def test_warning_log(logger_instance):
+ with patch.object(logger_instance.logger, "warning") as mock_warning:
+ logger_instance.warning("Warning log")
+ mock_warning.assert_called_once()
+
+
+def test_error_log(logger_instance):
+ with patch.object(logger_instance.logger, "error") as mock_error:
+ logger_instance.error("Error log", error_code=500)
+ mock_error.assert_called_once()
+ log_json = json.loads(mock_error.call_args[0][0])
+ assert log_json["message"] == "Error log"
+ assert log_json["context"]["error_code"] == 500
+
+
+def test_critical_log(logger_instance):
+ with patch.object(logger_instance.logger, "critical") as mock_critical:
+ logger_instance.critical("Critical log")
+ mock_critical.assert_called_once()
+
+
+def test_set_min_log_level():
+ with patch("common.logger.app_logger.logging.getLogger") as mock_get_logger:
+ mock_logger = MagicMock()
+ mock_get_logger.return_value = mock_logger
+
+ AppLogger.set_min_log_level(LogLevel.ERROR)
+
+ mock_logger.setLevel.assert_called_once_with(LogLevel.ERROR)
diff --git a/src/tests/backend/common/models/api_test.py b/src/tests/backend/common/models/api_test.py
new file mode 100644
index 0000000..b338efc
--- /dev/null
+++ b/src/tests/backend/common/models/api_test.py
@@ -0,0 +1,123 @@
+from datetime import datetime
+from uuid import uuid4
+
+from backend.common.models.api import AgentType, BatchRecord, FileLog, FileProcessUpdate, FileProcessUpdateJSONEncoder, FileRecord, FileResult, ProcessStatus, QueueBatch, TranslateType
+
+import pytest
+
+
+@pytest.fixture
+def common_datetime():
+ return datetime.now()
+
+
+@pytest.fixture
+def uuid_pair():
+ return str(uuid4()), str(uuid4())
+
+
+def test_filelog_fromdb_and_dict(uuid_pair, common_datetime):
+ log_id, file_id = uuid_pair
+ data = {
+ "log_id": log_id,
+ "file_id": file_id,
+ "description": "test log",
+ "last_candidate": "some_candidate",
+ "log_type": "SUCCESS",
+ "agent_type": "migrator",
+ "author_role": "user",
+ "timestamp": common_datetime.isoformat(),
+ }
+ log = FileLog.fromdb(data)
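+    # UUID.hex is the dash-free hex form, so compare against the dash-stripped id string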
+ assert log.log_id.hex == log_id.replace("-", "")
+ assert log.dict()["log_type"] == "info"
+
+ assert log.dict()["author_role"] == "user"
+
+
+def test_filerecord_fromdb_and_dict(uuid_pair, common_datetime):
+ file_id, batch_id = uuid_pair
+ data = {
+ "file_id": file_id,
+ "batch_id": batch_id,
+ "original_name": "file.sql",
+ "blob_path": "/blob/file.sql",
+ "translated_path": "/translated/file.sql",
+ "status": "in_progress",
+ "file_result": "warning",
+ "error_count": 2,
+ "syntax_count": 5,
+ "created_at": common_datetime.isoformat(),
+ "updated_at": common_datetime.isoformat(),
+ }
+ record = FileRecord.fromdb(data)
+ assert record.file_id.hex == file_id.replace("-", "")
+ assert record.dict()["status"] == "ready_to_process"
+ assert record.dict()["file_result"] == "warning"
+
+
+def test_fileprocessupdate_dict(uuid_pair):
+ file_id, batch_id = uuid_pair
+ update = FileProcessUpdate(
+ file_id=file_id,
+ batch_id=batch_id,
+ process_status=ProcessStatus.COMPLETED,
+ file_result=FileResult.SUCCESS,
+ agent_type=AgentType.FIXER,
+ agent_message="Translation done",
+ )
+ result = update.dict()
+ assert result["process_status"] == "completed"
+ assert result["file_result"] == "success"
+ assert result["agent_type"] == "fixer"
+ assert result["agent_message"] == "Translation done"
+
+
+def test_fileprocessupdate_json_encoder(uuid_pair):
+ file_id, batch_id = uuid_pair
+ update = FileProcessUpdate(
+ file_id=file_id,
+ batch_id=batch_id,
+ process_status=ProcessStatus.FAILED,
+ file_result=FileResult.ERROR,
+ agent_type=AgentType.HUMAN,
+ agent_message="Something failed",
+ )
+ json_string = FileProcessUpdateJSONEncoder().encode(update)
+ assert "failed" in json_string
+ assert "human" in json_string
+
+
+def test_queuebatch_dict(uuid_pair, common_datetime):
+ batch_id, _ = uuid_pair
+ batch = QueueBatch(
+ batch_id=batch_id,
+ user_id="user123",
+ translate_from="en",
+ translate_to="tsql",
+ created_at=common_datetime,
+ updated_at=common_datetime,
+ status=ProcessStatus.IN_PROGRESS,
+ )
+ result = batch.dict()
+ assert result["status"] == "in_process"
+ assert result["user_id"] == "user123"
+
+
+def test_batchrecord_fromdb_and_dict(uuid_pair, common_datetime):
+ batch_id, _ = uuid_pair
+ data = {
+ "batch_id": batch_id,
+ "user_id": "user123",
+ "file_count": 3,
+ "created_at": common_datetime.isoformat(),
+ "updated_at": common_datetime.isoformat(),
+ "status": "completed",
+ "from_language": "Informix",
+ "to_language": "T-SQL"
+ }
+ record = BatchRecord.fromdb(data)
+ assert record.status == ProcessStatus.COMPLETED
+ assert record.from_language == TranslateType.INFORMIX
+ assert record.to_language == TranslateType.TSQL
+ assert record.dict()["status"] == "completed"
diff --git a/src/tests/backend/common/services/__init__.py b/src/tests/backend/common/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/tests/backend/common/services/batch_service_test.py b/src/tests/backend/common/services/batch_service_test.py
new file mode 100644
index 0000000..21fd3a6
--- /dev/null
+++ b/src/tests/backend/common/services/batch_service_test.py
@@ -0,0 +1,785 @@
+from io import BytesIO
+from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import uuid4
+
+from common.models.api import AgentType, AuthorRole, BatchRecord, FileResult, LogType, ProcessStatus
+from common.services.batch_service import BatchService
+
+from fastapi import HTTPException, UploadFile
+
+import pytest
+
+import pytest_asyncio
+
+
+@pytest.fixture
+def mock_service(mocker):
+ service = BatchService()
+ service.logger = mocker.Mock()
+ service.database = MagicMock()
+
+ return service
+
+
+@pytest_asyncio.fixture
+async def service():
+ svc = BatchService()
+ svc.logger = MagicMock()
+ return svc
+
+
+@pytest.fixture
+def batch_service():
+    service = BatchService()
+    service.database = MagicMock()  # Inject a mock database
+    return service
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock)
+async def test_initialize_database(mock_get_db, service):
+ mock_db = AsyncMock()
+ mock_get_db.return_value = mock_db
+ await service.initialize_database()
+ assert service.database == mock_db
+
+
+@pytest.mark.asyncio
+async def test_get_batch_found(service):
+ service.database = AsyncMock()
+ batch_id = uuid4()
+ user_id = "user123"
+ service.database.get_batch.return_value = {"id": str(batch_id)}
+ service.database.get_batch_files.return_value = [{"file_id": "f1"}]
+ result = await service.get_batch(batch_id, user_id)
+ assert result["batch"] == {"id": str(batch_id)}
+ assert result["files"] == [{"file_id": "f1"}]
+
+
+@pytest.mark.asyncio
+async def test_get_batch_not_found(service):
+ service.database = AsyncMock()
+ batch_id = uuid4()
+ user_id = "user123"
+ service.database.get_batch.return_value = None
+ result = await service.get_batch(batch_id, user_id)
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_get_file_found(service):
+ service.database = AsyncMock()
+ service.database.get_file.return_value = {"file_id": "file123"}
+ result = await service.get_file("file123")
+ assert result == {"file": {"file_id": "file123"}}
+
+
+@pytest.mark.asyncio
+async def test_get_file_not_found(service):
+ service.database = AsyncMock()
+ service.database.get_file.return_value = None
+ result = await service.get_file("notfound")
+ assert result is None
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock)
+@patch("common.models.api.FileRecord.fromdb")
+@patch("common.models.api.BatchRecord.fromdb")
+async def test_get_file_report_success(mock_batch_fromdb, mock_file_fromdb, mock_get_storage, service):
+ service.database = AsyncMock()
+ file_id = "file123"
+ mock_file = {"batch_id": uuid4(), "translated_path": "some/path"}
+ mock_batch = {"batch_id": "batch123"}
+ mock_logs = [{"log": "log1"}]
+ mock_translated = "translated content"
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch_from_id.return_value = mock_batch
+ service.database.get_file_logs.return_value = mock_logs
+ mock_file_fromdb.return_value = MagicMock(dict=lambda: mock_file, batch_id=mock_file["batch_id"], translated_path="some/path")
+ mock_batch_fromdb.return_value = MagicMock(dict=lambda: mock_batch)
+ mock_storage = AsyncMock()
+ mock_storage.get_file.return_value = mock_translated
+ mock_get_storage.return_value = mock_storage
+ result = await service.get_file_report(file_id)
+ assert result["file"] == mock_file
+ assert result["batch"] == mock_batch
+ assert result["logs"] == mock_logs
+ assert result["translated_content"] == mock_translated
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock)
+async def test_get_file_translated_success(mock_get_storage, service):
+ file = {"translated_path": "some/path"}
+ mock_storage = AsyncMock()
+ mock_storage.get_file.return_value = "translated"
+ mock_get_storage.return_value = mock_storage
+ result = await service.get_file_translated(file)
+ assert result == "translated"
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock)
+async def test_get_file_translated_error(mock_get_storage, service):
+ file = {"translated_path": "some/path"}
+ mock_storage = AsyncMock()
+ mock_storage.get_file.side_effect = IOError("Failed to download")
+ mock_get_storage.return_value = mock_storage
+ result = await service.get_file_translated(file)
+ assert result == ""
+
+
+@pytest.mark.asyncio
+async def test_get_batch_for_zip(service):
+ service.database = AsyncMock()
+ service.get_file_translated = AsyncMock(return_value="file-content")
+ service.database.get_batch_files.return_value = [
+ {"original_name": "doc1.txt", "translated_path": "path1"},
+ {"original_name": "doc2.txt", "translated_path": "path2"},
+ ]
+ result = await service.get_batch_for_zip("batch1")
+ assert len(result) == 2
+ assert result[0][0] == "rslt_doc1.txt"
+ assert result[0][1] == "file-content"
+
+
+@pytest.mark.asyncio
+@patch("common.models.api.BatchRecord.fromdb")
+async def test_get_batch_summary_success(mock_batch_fromdb, service):
+ service.database = AsyncMock()
+ mock_batch = {"batch_id": "batch1"}
+ mock_batch_record = MagicMock(dict=lambda: {"batch_id": "batch1"})
+ mock_batch_fromdb.return_value = mock_batch_record
+ service.database.get_batch.return_value = mock_batch
+ service.database.get_batch_files.return_value = [
+ {"file_id": "file1", "translated_path": "path1"},
+ {"file_id": "file2", "translated_path": None},
+ ]
+ service.database.get_file_logs.return_value = ["log1"]
+ service.get_file_translated = AsyncMock(return_value="translated")
+ result = await service.get_batch_summary("batch1", "user1")
+ assert "files" in result
+ assert "batch" in result
+ assert result["files"][0]["logs"] == ["log1"]
+ assert result["files"][0]["translated_content"] == "translated"
+
+
+@pytest.mark.asyncio
+async def test_batch_zip_with_no_files(service):
+ service.database = AsyncMock()
+ service.database.get_batch_files.return_value = []
+ service.get_file_translated = AsyncMock()
+ result = await service.get_batch_for_zip("batch_empty")
+ assert result == []
+
+
+def test_is_valid_uuid():
+ service = BatchService()
+ valid = str(uuid4())
+ invalid = "not-a-uuid"
+ assert service.is_valid_uuid(valid)
+ assert not service.is_valid_uuid(invalid)
+
+
+def test_generate_file_path():
+ service = BatchService()
+ path = service.generate_file_path("batch1", "user1", "file1", "test@file.pdf")
+ assert path == "user1/batch1/file1/test_file.pdf"
+
+
+@pytest.mark.asyncio
+async def test_delete_batch_existing():
+ service = BatchService()
+ service.database = AsyncMock()
+ batch_id = uuid4()
+ service.database.get_batch.return_value = {"id": str(batch_id)}
+ service.database.delete_batch.return_value = None
+ result = await service.delete_batch(batch_id, "user1")
+ assert result["message"] == "Batch deleted successfully"
+ assert result["batch_id"] == str(batch_id)
+
+
+@pytest.mark.asyncio
+async def test_delete_file_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = uuid4()
+ batch_id = uuid4()
+ mock_file = MagicMock()
+ mock_file.batch_id = batch_id
+ mock_file.blob_path = "some/path/file.pdf"
+ mock_file.translated_path = "some/path/file_translated.pdf"
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage:
+ mock_storage.return_value.delete_file.return_value = True
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch.return_value = {"id": str(batch_id)}
+ service.database.get_batch_files.return_value = [1, 2]
+ with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \
+ patch("common.models.api.BatchRecord.fromdb") as mock_batch_record:
+ mock_record = MagicMock()
+ mock_record.file_count = 1
+ service.database.update_batch.return_value = None
+ mock_batch_record.return_value = mock_record
+ result = await service.delete_file(file_id, "user1")
+ assert result["message"] == "File deleted successfully"
+ assert result["file_id"] == str(file_id)
+
+
+@pytest.mark.asyncio
+async def test_upload_file_to_batch_dict_batch():
+ service = BatchService()
+ service.database = AsyncMock()
+ file = UploadFile(filename="hello@file.txt", file=BytesIO(b"test content"))
+ batch_id = str(uuid4())
+ file_id = str(uuid4())
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \
+ patch("uuid.uuid4", return_value=file_id), \
+ patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "path"}):
+
+ mock_storage.return_value.upload_file.return_value = None
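+        # First get_batch call returns None, the second returns the stored batch dict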
+ service.database.get_batch.side_effect = [None, {"file_count": 0}]
+ service.database.create_batch.return_value = {}
+ service.database.get_batch_files.return_value = ["file1", "file2"]
+ service.database.get_file.return_value = {"filename": file.filename}
+ service.database.update_batch_entry.return_value = {"batch_id": batch_id, "file_count": 2}
+ result = await service.upload_file_to_batch(batch_id, "user1", file)
+ assert "batch" in result
+ assert "file" in result
+
+
+@pytest.mark.asyncio
+async def test_upload_file_to_batch_invalid_storage():
+ service = BatchService()
+ service.database = AsyncMock()
+ file = UploadFile(filename="file.txt", file=BytesIO(b"data"))
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", return_value=None):
+ with pytest.raises(RuntimeError) as exc_info:
+ await service.upload_file_to_batch(str(uuid4()), "user1", file)
+ # Check outer exception message
+ assert str(exc_info.value) == "File upload failed"
+
+ # Check original cause of the exception
+ assert isinstance(exc_info.value.__cause__, RuntimeError)
+ assert str(exc_info.value.__cause__) == "Storage service not initialized"
+
+
+def test_generate_file_path_only_filename():
+ service = BatchService()
+ path = service.generate_file_path(None, None, None, "weird@name!.txt")
+ assert path.endswith("weird_name_.txt")
+
+
+def test_is_valid_uuid_empty_string():
+ service = BatchService()
+ assert not service.is_valid_uuid("")
+
+
+def test_is_valid_uuid_partial_uuid():
+ service = BatchService()
+ assert not service.is_valid_uuid("1234abcd")
+
+
+@pytest.mark.asyncio
+async def test_delete_file_file_not_found():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+
+ service.database.get_file.return_value = None
+ result = await service.delete_file(file_id, "user1")
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_upload_file_to_batch_storage_upload_fails():
+ service = BatchService()
+ service.database = AsyncMock()
+ file = UploadFile(filename="test.txt", file=BytesIO(b"abc"))
+ file_id = str(uuid4())
+
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage") as mock_get_storage, \
+ patch("uuid.uuid4", return_value=file_id):
+ mock_storage = AsyncMock()
+ mock_storage.upload_file.side_effect = RuntimeError("upload failed")
+ mock_get_storage.return_value = mock_storage
+
+ service.database.get_batch.side_effect = [None, {"file_count": 0}]
+ service.database.create_batch.return_value = {}
+ service.database.get_batch_files.return_value = []
+ service.database.update_batch_entry.return_value = {}
+
+ with pytest.raises(RuntimeError, match="File upload failed"):
+ await service.upload_file_to_batch("batch123", "user1", file)
+
+
+@pytest.mark.asyncio
+async def test_update_file_counts_success(service):
+    service.database = AsyncMock()
+    file_id = str(uuid4())
+    mock_file = {"file_id": file_id}
+    mock_logs = [
+        {"log_type": LogType.ERROR.value},
+        {"log_type": LogType.WARNING.value},
+        {"log_type": LogType.WARNING.value},
+    ]
+    service.database.get_file.return_value = mock_file
+    service.database.get_file_logs.return_value = mock_logs
+    with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock()) as mock_file_record:
+        await service.update_file_counts(file_id)
+        mock_file_record.assert_called_once()
+        service.database.update_file.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_update_file_counts_no_logs(service):
+    service.database = AsyncMock()
+    file_id = str(uuid4())
+    mock_file = {"file_id": file_id}
+    service.database.get_file.return_value = mock_file
+    service.database.get_file_logs.return_value = []
+    with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock()) as mock_file_record:
+        await service.update_file_counts(file_id)
+        mock_file_record.assert_called_once()
+        service.database.update_file.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_get_file_counts_success(service):
+    service.database = AsyncMock()
+    file_id = str(uuid4())
+    mock_logs = [
+        {"log_type": LogType.ERROR.value},
+        {"log_type": LogType.WARNING.value},
+        {"log_type": LogType.WARNING.value},
+    ]
+    service.database.get_file_logs.return_value = mock_logs
+    error_count, syntax_count = await service.get_file_counts(file_id)
+    assert error_count == 1
+    assert syntax_count == 2
+
+
+@pytest.mark.asyncio
+async def test_get_file_counts_no_logs(service):
+    service.database = AsyncMock()
+    file_id = str(uuid4())
+    service.database.get_file_logs.return_value = []
+    error_count, syntax_count = await service.get_file_counts(file_id)
+    assert error_count == 0
+    assert syntax_count == 0
+
+
+@pytest.mark.asyncio
+async def test_get_batch_history_success(service):
+    service.database = AsyncMock()
+    user_id = "user123"
+    mock_history = [{"batch_id": "batch1"}, {"batch_id": "batch2"}]
+    service.database.get_batch_history.return_value = mock_history
+    result = await service.get_batch_history(user_id, limit=10, offset=0)
+    assert result == mock_history
+
+
+@pytest.mark.asyncio
+async def test_get_batch_history_no_history(service):
+    service.database = AsyncMock()
+    user_id = "user123"
+    service.database.get_batch_history.return_value = []
+    result = await service.get_batch_history(user_id, limit=10, offset=0)
+    assert result == []
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock)
+async def test_initialize_database_success(mock_get_database, service):
+    # Arrange
+    mock_database = AsyncMock()
+    mock_get_database.return_value = mock_database
+
+    # Act
+    await service.initialize_database()
+
+    # Assert
+    assert service.database == mock_database
+    mock_get_database.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock)
+async def test_initialize_database_failure(mock_get_database, service):
+    # Arrange
+    mock_get_database.side_effect = RuntimeError("Database initialization failed")
+
+    # Act & Assert
+    with pytest.raises(RuntimeError, match="Database initialization failed"):
+        await service.initialize_database()
+    mock_get_database.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_update_file_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ mock_file = {"file_id": file_id}
+ mock_record = MagicMock()
+ mock_record.error_count = 0
+ mock_record.syntax_count = 0
+
+ service.database.get_file.return_value = mock_file
+ with patch("common.models.api.FileRecord.fromdb", return_value=mock_record):
+ await service.update_file(file_id, ProcessStatus.COMPLETED, FileResult.SUCCESS, 1, 2)
+ assert mock_record.error_count == 1
+ assert mock_record.syntax_count == 2
+ service.database.update_file.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_update_file_record():
+ service = BatchService()
+ service.database = AsyncMock()
+ mock_file_record = MagicMock()
+ await service.update_file_record(mock_file_record)
+ service.database.update_file.assert_called_once_with(mock_file_record)
+
+
+@pytest.mark.asyncio
+async def test_create_file_log():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ await service.create_file_log(
+ file_id=file_id,
+ description="test log",
+ last_candidate="candidate",
+ log_type=LogType.SUCCESS,
+ agent_type=AgentType.HUMAN,
+ author_role=AuthorRole.USER
+ )
+ service.database.add_file_log.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_update_batch_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ batch_id = str(uuid4())
+ mock_batch = {"batch_id": batch_id}
+ mock_batch_record = MagicMock()
+ service.database.get_batch_from_id.return_value = mock_batch
+ with patch("common.models.api.BatchRecord.fromdb", return_value=mock_batch_record):
+ await service.update_batch(batch_id, ProcessStatus.COMPLETED)
+ service.database.update_batch.assert_called_once_with(mock_batch_record)
+
+
+@pytest.mark.asyncio
+async def test_delete_batch_and_files_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ batch_id = str(uuid4())
+ user_id = "user"
+ mock_file = MagicMock()
+ mock_file.file_id = uuid4()
+ mock_file.blob_path = "blob/file"
+ mock_file.translated_path = "blob/translated"
+ service.database.get_batch.return_value = {"batch_id": batch_id}
+ service.database.get_batch_files.return_value = [mock_file]
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage:
+ mock_storage.return_value.delete_file.return_value = True
+ result = await service.delete_batch_and_files(batch_id, user_id)
+ assert result["message"] == "Files deleted successfully"
+
+
+@pytest.mark.asyncio
+async def test_batch_files_final_update():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ file = {
+ "file_id": file_id,
+ "translated_path": "",
+ "status": "IN_PROGRESS"
+ }
+ service.database.get_batch_files.return_value = [file]
+ with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(file_id=file_id, translated_path="", status=None)), \
+ patch.object(service, "get_file_counts", return_value=(1, 1)), \
+ patch.object(service, "create_file_log", new_callable=AsyncMock), \
+ patch.object(service, "update_file_record", new_callable=AsyncMock):
+ await service.batch_files_final_update("batch1")
+
+
+@pytest.mark.asyncio
+async def test_delete_all_from_storage_cosmos_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ user_id = "user123"
+ file_id = str(uuid4())
+ batch_id = str(uuid4())
+ mock_file = {
+ "translated_path": "translated/path"
+ }
+
+ service.get_all_batches = AsyncMock(return_value=[{"batch_id": batch_id}])
+ service.database.get_file.return_value = mock_file
+ service.database.list_files = AsyncMock(return_value=[{"name": f"user/{batch_id}/{file_id}/file.txt"}])
+
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage:
+ mock_storage.return_value.list_files.return_value = [{"name": f"user/{batch_id}/{file_id}/file.txt"}]
+ mock_storage.return_value.delete_file.return_value = True
+ result = await service.delete_all_from_storage_cosmos(user_id)
+ assert result["message"] == "All user data deleted successfully"
+
+
+@pytest.mark.asyncio
+async def test_create_candidate_success():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ batch_id = str(uuid4())
+ user_id = "user123"
+ mock_file = {"batch_id": batch_id, "original_name": "doc.txt"}
+ mock_batch = {"user_id": user_id}
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="doc.txt", batch_id=batch_id)), \
+ patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id=user_id)), \
+ patch.object(service, "get_file_counts", return_value=(0, 1)), \
+ patch.object(service, "update_file_record", new_callable=AsyncMock), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage:
+
+ mock_storage.return_value.upload_file.return_value = None
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch_from_id.return_value = mock_batch
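+        # The test passes if create_candidate completes without raising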
+ await service.create_candidate(file_id, "Some content")
+
+
+@pytest.mark.asyncio
+async def test_batch_files_final_update_success_path():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ file = {
+ "file_id": file_id,
+ "translated_path": "some/path",
+ "status": "IN_PROGRESS"
+ }
+
+ mock_file_record = MagicMock(translated_path="some/path", file_id=file_id)
+ service.database.get_batch_files.return_value = [file]
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=mock_file_record), \
+ patch.object(service, "update_file_record", new_callable=AsyncMock):
+ await service.batch_files_final_update("batch123")
+
+
+@pytest.mark.asyncio
+async def test_get_file_counts_logs_none():
+ service = BatchService()
+ service.database = AsyncMock()
+ service.database.get_file_logs.return_value = None
+ error_count, syntax_count = await service.get_file_counts("file_id")
+ assert error_count == 0
+ assert syntax_count == 0
+
+
+@pytest.mark.asyncio
+async def test_create_candidate_upload_error():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ mock_file = {"batch_id": str(uuid4()), "original_name": "doc.txt"}
+ mock_batch = {"user_id": "user1"}
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="doc.txt", batch_id=mock_file["batch_id"])), \
+ patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id="user1")), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \
+ patch.object(service, "get_file_counts", return_value=(1, 1)), \
+ patch.object(service, "update_file_record", new_callable=AsyncMock):
+
+ mock_storage.return_value.upload_file.side_effect = Exception("Upload fail")
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch_from_id.return_value = mock_batch
+
+ await service.create_candidate(file_id, "candidate content")
+
+
+@pytest.mark.asyncio
+async def test_get_batch_history_failure():
+ service = BatchService()
+ service.logger = MagicMock()
+ service.database = AsyncMock()
+
+ service.database.get_batch_history.side_effect = RuntimeError("DB failure")
+
+ with pytest.raises(RuntimeError, match="Error retrieving batch history"):
+ await service.get_batch_history("user1", limit=5, offset=0)
+
+
+@pytest.mark.asyncio
+async def test_delete_file_logs_exception():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ batch_id = str(uuid4())
+ mock_file = MagicMock()
+ mock_file.batch_id = batch_id
+ mock_file.blob_path = "blob"
+ mock_file.translated_path = "translated"
+ with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage:
+ mock_storage.return_value.delete_file.return_value = True
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch.return_value = {"id": str(batch_id)}
+ service.database.get_batch_files.return_value = [1, 2]
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \
+ patch("common.models.api.BatchRecord.fromdb") as mock_batch_record:
+ mock_record = MagicMock()
+ mock_record.file_count = 2
+ mock_batch_record.return_value = mock_record
+ service.database.update_batch.side_effect = Exception("Update failed")
+
+ result = await service.delete_file(file_id, "user1")
+ assert result["message"] == "File deleted successfully"
+
+
+@pytest.mark.asyncio
+async def test_upload_file_to_batch_batchrecord():
+ service = BatchService()
+ service.database = AsyncMock()
+ file = UploadFile(filename="test.txt", file=BytesIO(b"test content"))
+ batch_id = str(uuid4())
+ file_id = str(uuid4())
+
+ # Create a mock BatchRecord instance
+ mock_batch_record = MagicMock(spec=BatchRecord)
+ mock_batch_record.file_count = 0
+ mock_batch_record.updated_at = None
+
+ with patch("uuid.uuid4", return_value=file_id), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \
+ patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "blob/path"}), \
+ patch("common.models.api.BatchRecord.fromdb", return_value=mock_batch_record):
+
+ mock_storage.return_value.upload_file.return_value = None
+ # This will trigger the BatchRecord path
+ service.database.get_batch.side_effect = [mock_batch_record]
+ service.database.get_batch_files.return_value = ["file1", "file2"]
+ service.database.get_file.return_value = {"file_id": file_id}
+ service.database.update_batch_entry.return_value = mock_batch_record
+
+ result = await service.upload_file_to_batch(batch_id, "user1", file)
+ assert "batch" in result
+ assert "file" in result
+
+
+@pytest.mark.asyncio
+async def test_upload_file_to_batch_unknown_type():
+ service = BatchService()
+ service.database = AsyncMock()
+ file = UploadFile(filename="file.txt", file=BytesIO(b"data"))
+ file_id = str(uuid4())
+
+ with patch("uuid.uuid4", return_value=file_id), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \
+ patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "path"}):
+
+ mock_storage.return_value.upload_file.return_value = None
+ service.database.get_batch.side_effect = [object()] # Unknown type
+ service.database.get_batch_files.return_value = []
+ service.database.get_file.return_value = {"file_id": file_id}
+
+ with pytest.raises(RuntimeError, match="File upload failed"):
+ await service.upload_file_to_batch("batch123", "user1", file)
+
+
+@pytest.mark.asyncio
+@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock)
+@patch("common.models.api.FileRecord.fromdb")
+@patch("common.models.api.BatchRecord.fromdb")
+async def test_get_file_report_ioerror(mock_batch_fromdb, mock_file_fromdb, mock_get_storage):
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = "file123"
+ mock_file = {"batch_id": uuid4(), "translated_path": "some/path"}
+ mock_batch = {"batch_id": "batch123"}
+ mock_logs = [{"log": "log1"}]
+
+ mock_file_fromdb.return_value = MagicMock(dict=lambda: mock_file, batch_id=mock_file["batch_id"], translated_path="some/path")
+ mock_batch_fromdb.return_value = MagicMock(dict=lambda: mock_batch)
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch_from_id.return_value = mock_batch
+ service.database.get_file_logs.return_value = mock_logs
+
+ mock_storage = AsyncMock()
+ mock_storage.get_file.side_effect = IOError("Boom")
+ mock_get_storage.return_value = mock_storage
+
+ result = await service.get_file_report(file_id)
+ assert result["translated_content"] == ""
+
+
+@pytest.mark.asyncio
+@patch("common.models.api.BatchRecord.fromdb")
+async def test_get_batch_summary_log_exception(mock_batch_fromdb):
+ service = BatchService()
+ service.database = AsyncMock()
+ mock_batch = {"batch_id": "batch1"}
+ mock_batch_record = MagicMock(dict=lambda: {"batch_id": "batch1"})
+ mock_batch_fromdb.return_value = mock_batch_record
+
+ service.database.get_batch.return_value = mock_batch
+ service.database.get_batch_files.return_value = [{"file_id": "file1", "translated_path": None}]
+ service.database.get_file_logs.side_effect = Exception("DB log fail")
+
+ result = await service.get_batch_summary("batch1", "user1")
+ assert result["files"][0]["logs"] == []
+
+
+@pytest.mark.asyncio
+async def test_update_file_not_found():
+ service = BatchService()
+ service.database = AsyncMock()
+ service.database.get_file.return_value = None
+ with pytest.raises(HTTPException) as exc:
+ await service.update_file("invalid_id", ProcessStatus.COMPLETED, FileResult.SUCCESS, 0, 0)
+ assert exc.value.status_code == 404
+
+
+@pytest.mark.asyncio
+async def test_create_candidate_success_flow():
+ service = BatchService()
+ service.database = AsyncMock()
+ file_id = str(uuid4())
+ batch_id = str(uuid4())
+ user_id = "user1"
+
+ mock_file = {"batch_id": batch_id, "original_name": "test.txt"}
+ mock_batch = {"user_id": user_id}
+
+ with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="test.txt", batch_id=batch_id)), \
+ patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id=user_id)), \
+ patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \
+ patch.object(service, "get_file_counts", return_value=(0, 0)), \
+ patch.object(service, "update_file_record", new_callable=AsyncMock):
+
+ service.database.get_file.return_value = mock_file
+ service.database.get_batch_from_id.return_value = mock_batch
+ mock_storage.return_value.upload_file.return_value = None
+
+ await service.create_candidate(file_id, "candidate content")
diff --git a/src/tests/backend/common/storage/blob_azure_test.py b/src/tests/backend/common/storage/blob_azure_test.py
index 2abb8c8..68e5ad0 100644
--- a/src/tests/backend/common/storage/blob_azure_test.py
+++ b/src/tests/backend/common/storage/blob_azure_test.py
@@ -1,204 +1,225 @@
-# blob_azure_test.py
+import json
+from io import BytesIO
+from unittest.mock import MagicMock, patch
-import asyncio
-from datetime import datetime
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-# Import the class under test
from common.storage.blob_azure import AzureBlobStorage
-from azure.core.exceptions import ResourceExistsError
-
-
-class DummyBlob:
- """A dummy blob item returned by list_blobs."""
- def __init__(self, name, size, creation_time, content_type, metadata):
- self.name = name
- self.size = size
- self.creation_time = creation_time
- self.content_settings = MagicMock(content_type=content_type)
- self.metadata = metadata
-
-class DummyAsyncIterator:
- """A dummy async iterator that yields the given items."""
- def __init__(self, items):
- self.items = items
- self.index = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- if self.index >= len(self.items):
- raise StopAsyncIteration
- item = self.items[self.index]
- self.index += 1
- return item
-class DummyDownloadStream:
- """A dummy download stream whose content_as_bytes method returns a fixed byte string."""
- async def content_as_bytes(self):
- return b"file content"
-# --- Fixtures ---
-
-@pytest.fixture
-def dummy_storage():
- # Create an instance with dummy connection string and container name.
- return AzureBlobStorage("dummy_connection_string", "dummy_container")
+import pytest
-@pytest.fixture
-def dummy_container_client():
- container = MagicMock()
- container.create_container = AsyncMock()
- container.list_blobs = MagicMock() # Will be overridden per test.
- container.get_blob_client = MagicMock()
- return container
@pytest.fixture
-def dummy_service_client(dummy_container_client):
- service = MagicMock()
- service.get_container_client.return_value = dummy_container_client
- return service
+def mock_blob_service():
+ """Fixture to mock Azure Blob Storage service client"""
+ with patch("common.storage.blob_azure.BlobServiceClient") as mock_service:
+ mock_service_instance = MagicMock()
+ mock_container_client = MagicMock()
+ mock_blob_client = MagicMock()
-@pytest.fixture
-def dummy_blob_client():
- blob_client = MagicMock()
- blob_client.upload_blob = AsyncMock()
- blob_client.get_blob_properties = AsyncMock()
- blob_client.download_blob = AsyncMock()
- blob_client.delete_blob = AsyncMock()
- blob_client.url = "https://dummy.blob.core.windows.net/dummy_container/dummy_blob"
- return blob_client
+ # Set up mock methods
+ mock_service.return_value = mock_service_instance
+ mock_service_instance.get_container_client.return_value = mock_container_client
+ mock_container_client.get_blob_client.return_value = mock_blob_client
-# --- Tests for AzureBlobStorage methods ---
+ yield mock_service_instance, mock_container_client, mock_blob_client
-@pytest.mark.asyncio
-async def test_initialize_creates_container(dummy_storage, dummy_service_client, dummy_container_client):
- with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client) as mock_from_conn:
- # Simulate normal container creation.
- dummy_container_client.create_container = AsyncMock()
- await dummy_storage.initialize()
- mock_from_conn.assert_called_once_with("dummy_connection_string")
- dummy_service_client.get_container_client.assert_called_once_with("dummy_container")
- dummy_container_client.create_container.assert_awaited_once()
-@pytest.mark.asyncio
-async def test_initialize_container_already_exists(dummy_storage, dummy_service_client, dummy_container_client):
- with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client):
- # Simulate container already existing.
- dummy_container_client.create_container = AsyncMock(side_effect=ResourceExistsError("Container exists"))
- with patch.object(dummy_storage.logger, "debug") as mock_debug:
- await dummy_storage.initialize()
- dummy_container_client.create_container.assert_awaited_once()
- mock_debug.assert_called_with("Container dummy_container already exists")
+@pytest.fixture
+def blob_storage(mock_blob_service):
+ """Fixture to initialize AzureBlobStorage with mocked dependencies"""
+ service_client, container_client, blob_client = mock_blob_service
+ return AzureBlobStorage(account_name="test_account", container_name="test_container")
-@pytest.mark.asyncio
-async def test_initialize_failure(dummy_storage):
- # Simulate failure during initialization.
- with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", side_effect=Exception("Init error")):
- with patch.object(dummy_storage.logger, "error") as mock_error:
- with pytest.raises(Exception, match="Init error"):
- await dummy_storage.initialize()
- mock_error.assert_called()
@pytest.mark.asyncio
-async def test_upload_file_success(dummy_storage, dummy_blob_client):
- # Patch get_blob_client to return our dummy blob client.
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
-
- # Create a dummy properties object.
- dummy_properties = MagicMock()
- dummy_properties.size = 1024
- dummy_properties.content_settings = MagicMock(content_type="text/plain")
- dummy_properties.creation_time = datetime(2023, 1, 1)
- dummy_properties.etag = "dummy_etag"
- dummy_blob_client.get_blob_properties = AsyncMock(return_value=dummy_properties)
-
- file_content = b"Hello, world!"
- result = await dummy_storage.upload_file(file_content, "dummy_blob.txt", "text/plain", {"key": "value"})
- dummy_storage.container_client.get_blob_client.assert_called_once_with("dummy_blob.txt")
- dummy_blob_client.upload_blob.assert_awaited_with(file_content, content_type="text/plain", metadata={"key": "value"}, overwrite=True)
- dummy_blob_client.get_blob_properties.assert_awaited()
- assert result["path"] == "dummy_blob.txt"
+async def test_upload_file(blob_storage, mock_blob_service):
+ """Test uploading a file"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.upload_blob.return_value = MagicMock()
+ mock_blob_client.get_blob_properties.return_value = MagicMock(
+ size=1024,
+ content_settings=MagicMock(content_type="text/plain"),
+ creation_time="2024-03-15T12:00:00Z",
+ etag="dummy_etag",
+ )
+
+ file_content = BytesIO(b"dummy data")
+
+ result = await blob_storage.upload_file(file_content, "test_blob.txt", "text/plain")
+
+ assert result["path"] == "test_blob.txt"
assert result["size"] == 1024
assert result["content_type"] == "text/plain"
- assert result["url"] == dummy_blob_client.url
+ assert result["created_at"] == "2024-03-15T12:00:00Z"
assert result["etag"] == "dummy_etag"
+ assert "url" in result
+
@pytest.mark.asyncio
-async def test_upload_file_error(dummy_storage, dummy_blob_client):
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
- dummy_blob_client.upload_blob = AsyncMock(side_effect=Exception("Upload failed"))
+async def test_upload_file_exception(blob_storage, mock_blob_service):
+ """Test upload_file when an exception occurs"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.upload_blob.side_effect = Exception("Upload failed")
+
with pytest.raises(Exception, match="Upload failed"):
- await dummy_storage.upload_file(b"data", "blob.txt", "text/plain", {})
+ await blob_storage.upload_file(BytesIO(b"dummy data"), "test_blob.txt")
+
@pytest.mark.asyncio
-async def test_get_file_success(dummy_storage, dummy_blob_client):
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
- # Make download_blob return a DummyDownloadStream (not wrapped in extra coroutine)
- dummy_blob_client.download_blob = AsyncMock(return_value=DummyDownloadStream())
- result = await dummy_storage.get_file("blob.txt")
- dummy_storage.container_client.get_blob_client.assert_called_once_with("blob.txt")
- dummy_blob_client.download_blob.assert_awaited()
- assert result == b"file content"
+async def test_get_file(blob_storage, mock_blob_service):
+ """Test downloading a file"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.download_blob.return_value.readall.return_value = b"dummy data"
+
+ result = await blob_storage.get_file("test_blob.txt")
+
+ assert result == "dummy data"
+
@pytest.mark.asyncio
-async def test_get_file_error(dummy_storage, dummy_blob_client):
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
- dummy_blob_client.download_blob = AsyncMock(side_effect=Exception("Download error"))
- with pytest.raises(Exception, match="Download error"):
- await dummy_storage.get_file("nonexistent.txt")
+async def test_get_file_exception(blob_storage, mock_blob_service):
+ """Test get_file when an exception occurs"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.download_blob.side_effect = Exception("Download failed")
+
+ with pytest.raises(Exception, match="Download failed"):
+ await blob_storage.get_file("test_blob.txt")
+
@pytest.mark.asyncio
-async def test_delete_file_success(dummy_storage, dummy_blob_client):
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
- dummy_blob_client.delete_blob = AsyncMock()
- result = await dummy_storage.delete_file("blob.txt")
- dummy_storage.container_client.get_blob_client.assert_called_once_with("blob.txt")
- dummy_blob_client.delete_blob.assert_awaited()
+async def test_delete_file(blob_storage, mock_blob_service):
+ """Test deleting a file"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.delete_blob.return_value = None
+
+ result = await blob_storage.delete_file("test_blob.txt")
+
assert result is True
+
@pytest.mark.asyncio
-async def test_delete_file_error(dummy_storage, dummy_blob_client):
- dummy_storage.container_client = MagicMock()
- dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client
- dummy_blob_client.delete_blob = AsyncMock(side_effect=Exception("Delete error"))
- result = await dummy_storage.delete_file("blob.txt")
+async def test_delete_file_exception(blob_storage, mock_blob_service):
+ """Test delete_file when an exception occurs"""
+ _, _, mock_blob_client = mock_blob_service
+ mock_blob_client.delete_blob.side_effect = Exception("Delete failed")
+
+ result = await blob_storage.delete_file("test_blob.txt")
+
assert result is False
+
@pytest.mark.asyncio
-async def test_list_files_success(dummy_storage):
- dummy_storage.container_client = MagicMock()
- # Create two dummy blobs.
- blob1 = DummyBlob("file1.txt", 100, datetime(2023, 1, 1), "text/plain", {"a": "1"})
- blob2 = DummyBlob("file2.txt", 200, datetime(2023, 1, 2), "text/plain", {"b": "2"})
- async_iterator = DummyAsyncIterator([blob1, blob2])
- dummy_storage.container_client.list_blobs.return_value = async_iterator
- result = await dummy_storage.list_files("file")
+async def test_list_files(blob_storage, mock_blob_service):
+ """Test listing files in a container"""
+ _, mock_container_client, _ = mock_blob_service
+
+ class AsyncIterator:
+ """Helper class to create an async iterator"""
+
+ def __init__(self, items):
+ self._items = items
+
+ def __aiter__(self):
+ self._iter = iter(self._items)
+ return self
+
+ async def __anext__(self):
+ try:
+ return next(self._iter)
+ except StopIteration:
+ raise StopAsyncIteration
+
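+    # list_files is assumed to consume list_blobs with `async for`, so the mocked
+    # return value must implement the async iterator protocol (the helper above).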
+ mock_blobs = [
+ MagicMock(name="file1.txt"),
+ MagicMock(name="file2.txt"),
+ ]
+
+    # MagicMock(name=...) only names the mock itself, it does not set a .name
+    # attribute, so set the blob attributes explicitly for the assertions below
+ mock_blobs[0].name = "file1.txt"
+ mock_blobs[0].size = 123
+ mock_blobs[0].creation_time = "2024-03-15T12:00:00Z"
+ mock_blobs[0].content_settings = MagicMock(content_type="text/plain")
+ mock_blobs[0].metadata = {}
+
+ mock_blobs[1].name = "file2.txt"
+ mock_blobs[1].size = 456
+ mock_blobs[1].creation_time = "2024-03-16T12:00:00Z"
+ mock_blobs[1].content_settings = MagicMock(content_type="application/json")
+ mock_blobs[1].metadata = {}
+
+ mock_container_client.list_blobs = MagicMock(return_value=AsyncIterator(mock_blobs))
+
+ result = await blob_storage.list_files()
+
assert len(result) == 2
- names = {item["name"] for item in result}
- assert names == {"file1.txt", "file2.txt"}
+ assert result[0]["name"] == "file1.txt"
+ assert result[0]["size"] == 123
+ assert result[0]["created_at"] == "2024-03-15T12:00:00Z"
+ assert result[0]["content_type"] == "text/plain"
+ assert result[0]["metadata"] == {}
+
+ assert result[1]["name"] == "file2.txt"
+ assert result[1]["size"] == 456
+ assert result[1]["created_at"] == "2024-03-16T12:00:00Z"
+ assert result[1]["content_type"] == "application/json"
+ assert result[1]["metadata"] == {}
+
+
+@pytest.mark.asyncio
+async def test_list_files_exception(blob_storage, mock_blob_service):
+ """Test list_files when an exception occurs"""
+ _, mock_container_client, _ = mock_blob_service
+ mock_container_client.list_blobs.side_effect = Exception("List failed")
+
+ with pytest.raises(Exception, match="List failed"):
+ await blob_storage.list_files()
+
@pytest.mark.asyncio
-async def test_list_files_failure(dummy_storage):
- dummy_storage.container_client = MagicMock()
- # Define list_blobs to return an invalid object (simulate error)
- async def invalid_list_blobs(*args, **kwargs):
- # Return a plain string (which does not implement __aiter__)
- return "invalid"
- dummy_storage.container_client.list_blobs = invalid_list_blobs
- with pytest.raises(Exception):
- await dummy_storage.list_files("")
+async def test_close(blob_storage, mock_blob_service):
+ """Test closing the storage client"""
+ service_client, _, _ = mock_blob_service
+
+ await blob_storage.close()
+
+ service_client.close.assert_called_once()
+
@pytest.mark.asyncio
-async def test_close(dummy_storage):
- dummy_storage.service_client = MagicMock()
- dummy_storage.service_client.close = AsyncMock()
- await dummy_storage.close()
- dummy_storage.service_client.close.assert_awaited()
+async def test_blob_storage_init_exception():
+ """Test that an exception during initialization logs the error message"""
+ with patch("common.storage.blob_azure.BlobServiceClient") as mock_service, \
+ patch("logging.getLogger") as mock_logger: # Patch logging globally
+
+ # Mock logger instance
+ mock_logger_instance = MagicMock()
+ mock_logger.return_value = mock_logger_instance
+
+ # Simulate an exception when creating BlobServiceClient
+ mock_service.side_effect = Exception("Connection failed")
+
+ # Try to initialize AzureBlobStorage
+ try:
+ AzureBlobStorage(account_name="test_account", container_name="test_container")
+ except Exception:
+ pass # Prevent test failure due to the exception
+
+ # Construct the expected JSON log format
+ expected_error_log = json.dumps({
+ "message": "Failed to initialize Azure Blob Storage",
+ "context": {
+ "error": "Connection failed",
+ "account_name": "test_account"
+ }
+ })
+
+        # Assert that error logging happened with the expected JSON string
+        mock_logger_instance.error.assert_called_once_with(expected_error_log)
+
+        # The constructor raised before any container work, so the
+        # "Container ... already exists" debug log must not have been emitted
+        mock_logger_instance.debug.assert_not_called()
diff --git a/src/tests/backend/common/storage/blob_base_test.py b/src/tests/backend/common/storage/blob_base_test.py
index b4b0361..d7e2383 100644
--- a/src/tests/backend/common/storage/blob_base_test.py
+++ b/src/tests/backend/common/storage/blob_base_test.py
@@ -1,129 +1,86 @@
-import pytest
-import asyncio
-import uuid
-from datetime import datetime
-from typing import BinaryIO, Dict, Any
+from io import BytesIO
+from typing import Any, BinaryIO, Dict, Optional
+
+
+from common.storage.blob_base import BlobStorageBase # Adjust import path as needed
-# Import the abstract base class from the production code.
-from common.storage.blob_base import BlobStorageBase
+
+import pytest
-# Create a dummy concrete subclass of BlobStorageBase that calls the parent's abstract methods.
-class DummyBlobStorage(BlobStorageBase):
- async def initialize(self) -> None:
- # Call the parent (which is just a pass)
- await super().initialize()
- # Return a dummy value so we can verify our override is called.
- return "initialized"
+class MockBlobStorage(BlobStorageBase):
+ """Mock implementation of BlobStorageBase for testing"""
async def upload_file(
self,
file_content: BinaryIO,
blob_path: str,
- content_type: str = None,
- metadata: Dict[str, str] = None,
+ content_type: Optional[str] = None,
+ metadata: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
- await super().upload_file(file_content, blob_path, content_type, metadata)
- # Return a dummy dictionary that simulates upload details.
return {
- "url": "https://dummy.blob.core.windows.net/dummy_container/" + blob_path,
- "size": len(file_content),
- "etag": "dummy_etag",
+ "path": blob_path,
+ "size": len(file_content.read()),
+ "content_type": content_type or "application/octet-stream",
+ "metadata": metadata or {},
+ "url": f"https://mockstorage.com/{blob_path}",
}
async def get_file(self, blob_path: str) -> BinaryIO:
- await super().get_file(blob_path)
- # Return dummy binary content.
- return b"dummy content"
+ return BytesIO(b"mock data")
async def delete_file(self, blob_path: str) -> bool:
- await super().delete_file(blob_path)
- # Simulate a successful deletion.
return True
- async def list_files(self, prefix: str = None) -> list[Dict[str, Any]]:
- await super().list_files(prefix)
+ async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]]:
return [
- {
- "name": "dummy.txt",
- "size": 123,
- "created_at": datetime.now(),
- "content_type": "text/plain",
- "metadata": {"dummy": "value"},
- }
+ {"name": "file1.txt", "size": 100, "content_type": "text/plain"},
+ {"name": "file2.jpg", "size": 200, "content_type": "image/jpeg"},
]
-# tests cases with each method.
+@pytest.fixture
+def mock_blob_storage():
+ """Fixture to provide a MockBlobStorage instance"""
+ return MockBlobStorage()
@pytest.mark.asyncio
-async def test_initialize():
- storage = DummyBlobStorage()
- result = await storage.initialize()
- # Since the dummy override returns "initialized" after calling super(),
- # we assert that the result equals that string.
- assert result == "initialized"
-
+async def test_upload_file(mock_blob_storage):
+ """Test upload_file method"""
+ file_content = BytesIO(b"dummy data")
+ result = await mock_blob_storage.upload_file(file_content, "test_blob.txt", "text/plain")
-@pytest.mark.asyncio
-async def test_upload_file():
- storage = DummyBlobStorage()
- content = b"hello world"
- blob_path = "folder/hello.txt"
- content_type = "text/plain"
- metadata = {"key": "value"}
- result = await storage.upload_file(content, blob_path, content_type, metadata)
- # Verify that our dummy return value is as expected.
- assert (
- result["url"]
- == "https://dummy.blob.core.windows.net/dummy_container/" + blob_path
- )
- assert result["size"] == len(content)
- assert result["etag"] == "dummy_etag"
+ assert result["path"] == "test_blob.txt"
+ assert result["size"] == len(b"dummy data")
+ assert result["content_type"] == "text/plain"
+ assert "url" in result
@pytest.mark.asyncio
-async def test_get_file():
- storage = DummyBlobStorage()
- result = await storage.get_file("folder/hello.txt")
- # Verify that we get the dummy binary content.
- assert result == b"dummy content"
-
+async def test_get_file(mock_blob_storage):
+ """Test get_file method"""
+ result = await mock_blob_storage.get_file("test_blob.txt")
-@pytest.mark.asyncio
-async def test_delete_file():
- storage = DummyBlobStorage()
- result = await storage.delete_file("folder/hello.txt")
- # Verify that deletion returns True.
- assert result is True
+ assert isinstance(result, BytesIO)
+ assert result.read() == b"mock data"
@pytest.mark.asyncio
-async def test_list_files():
- storage = DummyBlobStorage()
- result = await storage.list_files("dummy")
- # Verify that we receive a list with one item having a 'name' key.
- assert isinstance(result, list)
- assert len(result) == 1
- assert "dummy.txt" in result[0]["name"]
- assert result[0]["size"] == 123
- assert result[0]["content_type"] == "text/plain"
- assert result[0]["metadata"] == {"dummy": "value"}
+async def test_delete_file(mock_blob_storage):
+ """Test delete_file method"""
+ result = await mock_blob_storage.delete_file("test_blob.txt")
+
+ assert result is True
@pytest.mark.asyncio
-async def test_smoke_all_methods():
- storage = DummyBlobStorage()
- init_val = await storage.initialize()
- assert init_val == "initialized"
- upload_val = await storage.upload_file(
- b"data", "file.txt", "text/plain", {"a": "b"}
- )
- assert upload_val["size"] == 4
- file_val = await storage.get_file("file.txt")
- assert file_val == b"dummy content"
- delete_val = await storage.delete_file("file.txt")
- assert delete_val is True
- list_val = await storage.list_files("file")
- assert isinstance(list_val, list)
+async def test_list_files(mock_blob_storage):
+ """Test list_files method"""
+ result = await mock_blob_storage.list_files()
+
+ assert len(result) == 2
+ assert result[0]["name"] == "file1.txt"
+ assert result[1]["name"] == "file2.jpg"
+ assert result[0]["size"] == 100
+ assert result[1]["size"] == 200
diff --git a/src/tests/backend/common/storage/blob_factory_test.py b/src/tests/backend/common/storage/blob_factory_test.py
index e19af49..70ed7ec 100644
--- a/src/tests/backend/common/storage/blob_factory_test.py
+++ b/src/tests/backend/common/storage/blob_factory_test.py
@@ -1,262 +1,78 @@
-# blob_factory_test.py
-import asyncio
-import json
-import os
-import sys
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-
-# Adjust sys.path so that the project root is found.
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
+from unittest.mock import MagicMock, patch
-# Set required environment variables (dummy values)
-os.environ["COSMOSDB_ENDPOINT"] = "https://dummy-endpoint"
-os.environ["COSMOSDB_KEY"] = "dummy-key"
-os.environ["COSMOSDB_DATABASE"] = "dummy-database"
-os.environ["COSMOSDB_CONTAINER"] = "dummy-container"
-os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "dummy-deployment"
-os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01"
-os.environ["AZURE_OPENAI_ENDPOINT"] = "https://dummy-openai-endpoint"
-# Patch missing azure module so that event_utils imports without error.
-sys.modules["azure.monitor.events.extension"] = MagicMock()
-
-# --- Import the module under test ---
from common.storage.blob_factory import BlobStorageFactory
-from common.storage.blob_base import BlobStorageBase
-from common.storage.blob_azure import AzureBlobStorage
-
-# --- Dummy configuration for testing ---
-class DummyConfig:
- azure_blob_connection_string = "dummy_connection_string"
- azure_blob_container_name = "dummy_container"
-
-# --- Fixture to patch Config in our tests ---
-@pytest.fixture(autouse=True)
-def patch_config(monkeypatch):
- # Import the real Config from your project.
- from common.config.config import Config
-
- def dummy_init(self):
- self.azure_blob_connection_string = DummyConfig.azure_blob_connection_string
- self.azure_blob_container_name = DummyConfig.azure_blob_container_name
- monkeypatch.setattr(Config, "__init__", dummy_init)
- # Reset the BlobStorageFactory singleton before each test.
- BlobStorageFactory._instance = None
-
-
-class DummyAzureBlobStorage(BlobStorageBase):
- def __init__(self, connection_string: str, container_name: str):
- self.connection_string = connection_string
- self.container_name = container_name
- self.initialized = False
- self.files = {} # maps blob_path to tuple(file_content, content_type, metadata)
- async def initialize(self):
- self.initialized = True
- async def upload_file(self, file_content: bytes, blob_path: str, content_type: str, metadata: dict):
- self.files[blob_path] = (file_content, content_type, metadata)
- return {
- "url": f"https://dummy.blob.core.windows.net/{self.container_name}/{blob_path}",
- "size": len(file_content),
- "etag": "dummy_etag"
- }
-
- async def get_file(self, blob_path: str):
- if blob_path in self.files:
- return self.files[blob_path][0]
- else:
- raise FileNotFoundError(f"File {blob_path} not found")
-
- async def delete_file(self, blob_path: str):
- if blob_path in self.files:
- del self.files[blob_path]
- # No error if file does not exist.
-
- async def list_files(self, prefix: str = ""):
- return [path for path in self.files if path.startswith(prefix)]
+import pytest
- async def close(self):
- self.initialized = False
-# --- Fixture to patch AzureBlobStorage ---
-@pytest.fixture(autouse=True)
-def patch_azure_blob_storage(monkeypatch):
- monkeypatch.setattr("common.storage.blob_factory.AzureBlobStorage", DummyAzureBlobStorage)
+@pytest.mark.asyncio
+async def test_get_storage_logs_on_init():
+ """Test that logger logs on initialization"""
+ # Force reset the singleton before test
BlobStorageFactory._instance = None
-# -------------------- Tests for BlobStorageFactory --------------------
+ mock_storage_instance = MagicMock()
-@pytest.mark.asyncio
-async def test_get_storage_success():
- """Test that get_storage returns an initialized DummyAzureBlobStorage instance and is a singleton."""
- storage = await BlobStorageFactory.get_storage()
- assert isinstance(storage, DummyAzureBlobStorage)
- assert storage.initialized is True
+ with patch("common.storage.blob_factory.AzureBlobStorage", return_value=mock_storage_instance), \
+ patch("common.storage.blob_factory.Config") as mock_config, \
+ patch.object(BlobStorageFactory, "_logger") as mock_logger:
- # Call get_storage again; it should return the same instance.
- storage2 = await BlobStorageFactory.get_storage()
- assert storage is storage2
+ mock_config_instance = MagicMock()
+ mock_config_instance.azure_blob_account_name = "account"
+ mock_config_instance.azure_blob_container_name = "container"
+ mock_config.return_value = mock_config_instance
-@pytest.mark.asyncio
-async def test_get_storage_missing_config(monkeypatch):
- """
- Test that get_storage raises a ValueError when configuration is missing.
- We simulate missing connection string and container name.
- """
- from common.config.config import Config
- def dummy_init_missing(self):
- self.azure_blob_connection_string = ""
- self.azure_blob_container_name = ""
- monkeypatch.setattr(Config, "__init__", dummy_init_missing)
- with pytest.raises(ValueError, match="Azure Blob Storage configuration is missing"):
await BlobStorageFactory.get_storage()
-@pytest.mark.asyncio
-async def test_close_storage_success():
- """Test that close_storage calls close() on the storage instance and resets the singleton."""
- storage = await BlobStorageFactory.get_storage()
- # Patch close() method with an async mock.
- storage.close = AsyncMock()
- await BlobStorageFactory.close_storage()
- storage.close.assert_called_once()
- assert BlobStorageFactory._instance is None
-
-# -------------------- File Upload Tests --------------------
+ mock_logger.info.assert_called_once_with("Initialized Azure Blob Storage: container")
-@pytest.mark.asyncio
-async def test_upload_file_success():
- """Test that upload_file successfully uploads a file and returns metadata."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- file_content = b"Hello, Blob!"
- blob_path = "folder/blob.txt"
- content_type = "text/plain"
- metadata = {"meta": "data"}
- result = await storage.upload_file(file_content, blob_path, content_type, metadata)
- assert "url" in result
- assert result["size"] == len(file_content)
- assert blob_path in storage.files
@pytest.mark.asyncio
-async def test_upload_file_error(monkeypatch):
- """Test that an exception during file upload is propagated."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- monkeypatch.setattr(storage, "upload_file", AsyncMock(side_effect=Exception("Upload failed")))
- with pytest.raises(Exception, match="Upload failed"):
- await storage.upload_file(b"data", "file.txt", "text/plain", {})
-
-# -------------------- File Retrieval Tests --------------------
+async def test_close_storage_resets_instance():
+ """Test that close_storage resets the singleton instance"""
+ # Setup instance first
+ mock_storage_instance = MagicMock()
-@pytest.mark.asyncio
-async def test_get_file_success():
- """Test that get_file retrieves the correct file content."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- blob_path = "folder/data.bin"
- file_content = b"BinaryData"
- storage.files[blob_path] = (file_content, "application/octet-stream", {})
- result = await storage.get_file(blob_path)
- assert result == file_content
+ with patch("common.storage.blob_factory.AzureBlobStorage", return_value=mock_storage_instance), \
+ patch("common.storage.blob_factory.Config") as mock_config:
-@pytest.mark.asyncio
-async def test_get_file_not_found():
- """Test that get_file raises FileNotFoundError when file does not exist."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- with pytest.raises(FileNotFoundError):
- await storage.get_file("nonexistent.file")
+ mock_config_instance = MagicMock()
+ mock_config_instance.azure_blob_account_name = "account"
+ mock_config_instance.azure_blob_container_name = "container"
+ mock_config.return_value = mock_config_instance
-# -------------------- File Deletion Tests --------------------
+ instance = await BlobStorageFactory.get_storage()
+ assert instance is not None
-@pytest.mark.asyncio
-async def test_delete_file_success():
- """Test that delete_file removes an existing file."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- blob_path = "folder/remove.txt"
- storage.files[blob_path] = (b"To remove", "text/plain", {})
- await storage.delete_file(blob_path)
- assert blob_path not in storage.files
+ await BlobStorageFactory.close_storage()
-@pytest.mark.asyncio
-async def test_delete_file_nonexistent():
- """Test that deleting a non-existent file does not raise an error."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- # Should not raise any exception.
- await storage.delete_file("nonexistent.file")
- assert True
+ assert BlobStorageFactory._instance is None
-# -------------------- File Listing Tests --------------------
@pytest.mark.asyncio
-async def test_list_files_with_prefix():
- """Test that list_files returns files that match the given prefix."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- storage.files = {
- "folder/a.txt": (b"A", "text/plain", {}),
- "folder/b.txt": (b"B", "text/plain", {}),
- "other/c.txt": (b"C", "text/plain", {}),
- }
- result = await storage.list_files("folder/")
- assert set(result) == {"folder/a.txt", "folder/b.txt"}
-
-@pytest.mark.asyncio
-async def test_list_files_no_files():
- """Test that list_files returns an empty list when no files match the prefix."""
- storage = DummyAzureBlobStorage("dummy", "container")
- await storage.initialize()
- storage.files = {}
- result = await storage.list_files("prefix/")
- assert result == []
+async def test_get_storage_after_close_reinitializes():
+ """Test that get_storage reinitializes after close_storage is called"""
+ # Force reset before test
+ BlobStorageFactory._instance = None
-# -------------------- Additional Basic Tests --------------------
+ with patch("common.storage.blob_factory.AzureBlobStorage") as mock_storage, \
+ patch("common.storage.blob_factory.Config") as mock_config:
-@pytest.mark.asyncio
-async def test_dummy_azure_blob_storage_initialize():
- """Test that initializing DummyAzureBlobStorage sets the initialized flag."""
- storage = DummyAzureBlobStorage("dummy_conn", "dummy_container")
- assert storage.initialized is False
- await storage.initialize()
- assert storage.initialized is True
+ mock_storage.side_effect = [MagicMock(name="instance1"), MagicMock(name="instance2")]
-@pytest.mark.asyncio
-async def test_dummy_azure_blob_storage_upload_and_retrieve():
- """Test that a file uploaded to DummyAzureBlobStorage can be retrieved."""
- storage = DummyAzureBlobStorage("dummy_conn", "dummy_container")
- await storage.initialize()
- content = b"Sample file content"
- blob_path = "folder/sample.txt"
- metadata = {"author": "tester"}
- result = await storage.upload_file(content, blob_path, "text/plain", metadata)
- assert "url" in result
- assert result["size"] == len(content)
- retrieved = await storage.get_file(blob_path)
- assert retrieved == content
-
-@pytest.mark.asyncio
-async def test_dummy_azure_blob_storage_close():
- """Test that close() sets initialized to False."""
- storage = DummyAzureBlobStorage("dummy_conn", "dummy_container")
- await storage.initialize()
- await storage.close()
- assert storage.initialized is False
+ mock_config_instance = MagicMock()
+ mock_config_instance.azure_blob_account_name = "account"
+ mock_config_instance.azure_blob_container_name = "container"
+ mock_config.return_value = mock_config_instance
-# -------------------- Test for BlobStorageFactory Singleton Usage --------------------
+ # First init
+ instance1 = await BlobStorageFactory.get_storage()
+ await BlobStorageFactory.close_storage()
-def test_common_usage_of_blob_factory():
- """Test that manually setting the singleton in BlobStorageFactory works as expected."""
- # Create a dummy storage instance.
- dummy_storage = DummyAzureBlobStorage("dummy", "container")
- dummy_storage.initialized = True
- BlobStorageFactory._instance = dummy_storage
- storage = asyncio.run(BlobStorageFactory.get_storage())
- assert storage is dummy_storage
+ # Re-init
+ instance2 = await BlobStorageFactory.get_storage()
-if __name__ == "__main__":
- # Run tests when this file is executed directly.
- asyncio.run(pytest.main())
+ assert instance1 is not instance2
+ assert mock_storage.call_count == 2
diff --git a/src/tests/backend/sql_agents/agents/agent_config_test.py b/src/tests/backend/sql_agents/agents/agent_config_test.py
new file mode 100644
index 0000000..8250a23
--- /dev/null
+++ b/src/tests/backend/sql_agents/agents/agent_config_test.py
@@ -0,0 +1,42 @@
+import importlib
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+
+@pytest.fixture
+def mock_project_client():
+ return AsyncMock()
+
+
+@patch.dict("os.environ", {
+ "MIGRATOR_AGENT_MODEL_DEPLOY": "migrator-model",
+ "PICKER_AGENT_MODEL_DEPLOY": "picker-model",
+ "FIXER_AGENT_MODEL_DEPLOY": "fixer-model",
+ "SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY": "semantic-verifier-model",
+ "SYNTAX_CHECKER_AGENT_MODEL_DEPLOY": "syntax-checker-model",
+ "SELECTION_MODEL_DEPLOY": "selection-model",
+ "TERMINATION_MODEL_DEPLOY": "termination-model",
+})
+def test_agent_model_type_mapping_and_instance(mock_project_client):
+ # Re-import to re-evaluate class variable with patched env
+ from sql_agents.agents import agent_config
+ importlib.reload(agent_config)
+
+ AgentType = agent_config.AgentType
+ AgentBaseConfig = agent_config.AgentBaseConfig
+
+ # Test model_type mapping
+ assert AgentBaseConfig.model_type[AgentType.MIGRATOR] == "migrator-model"
+ assert AgentBaseConfig.model_type[AgentType.PICKER] == "picker-model"
+ assert AgentBaseConfig.model_type[AgentType.FIXER] == "fixer-model"
+ assert AgentBaseConfig.model_type[AgentType.SEMANTIC_VERIFIER] == "semantic-verifier-model"
+ assert AgentBaseConfig.model_type[AgentType.SYNTAX_CHECKER] == "syntax-checker-model"
+ assert AgentBaseConfig.model_type[AgentType.SELECTION] == "selection-model"
+ assert AgentBaseConfig.model_type[AgentType.TERMINATION] == "termination-model"
+
+ # Test __init__ stores params correctly
+ config = AgentBaseConfig(mock_project_client, sql_from="sql1", sql_to="sql2")
+ assert config.ai_project_client == mock_project_client
+ assert config.sql_from == "sql1"
+ assert config.sql_to == "sql2"
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
new file mode 100644
index 0000000..cad4e26
--- /dev/null
+++ b/src/tests/conftest.py
@@ -0,0 +1,12 @@
+import os
+import sys
+
+# Determine the project root relative to this conftest.py file.
+# This file is at: /src/tests/conftest.py
+# We want to add: /src/backend to sys.path.
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, "..")) # Goes from tests to src
+backend_path = os.path.join(project_root, "backend")
+sys.path.insert(0, backend_path)
+
+print("Adjusted sys.path:", sys.path)